Skip to content
Snippets Groups Projects
3-prediction.py 8.24 KiB
Newer Older
  • Learn to ignore specific revisions
  • from Packages import *
    # Page-wide Streamlit settings (browser-tab title, icon shortcode, wide layout).
    # Issued right after the first star import and before the project modules are imported.
    st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
    from Modules import *
    from Class_Mod.DATA_HANDLING import *
    
    # HTML for the "CEFE - CNRS" banner (currently disabled, kept for reference)

    # bandeau_html = """
    # <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
    #   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
    # </div>
    # """
    # # Inject the banner HTML
    # st.markdown(bandeau_html, unsafe_allow_html=True)

    # Page header and page-specific stylesheet. `add_header`, `local_css` and
    # `css_file` are not defined in this file; presumably they come from the
    # star imports above.
    add_header()
    local_css(css_file / "style_model.css")

    # Title, followed by a 2:1 column row: M10 shows the illustration,
    # M20 is used further down for the upload widgets.
    st.title("Prediction making using a previously developed model")
    M10, M20 = st.columns([2, 1])
    M10.image("./images/prediction making.png", use_column_width=True)

    # st.header("Prediction making", divider='blue')
    # M5, M6 = st.columns([2, 0.01])

    # Accepted upload extensions. Order matters: the loader below compares the
    # uploaded file's extension against position 0 ('.csv') and position 1 ('.dx').
    files_format = ['.csv', '.dx']

    # Uploader lives in the narrow right-hand column; `file` is falsy until
    # the user provides something.
    file = M20.file_uploader(
        "Select NIRS Data to predict",
        type=files_format,
        help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns",
    )

    # Export folder and file-name prefix; the uploaded file's stem and the
    # selected model's stem are appended to `export_name` later in the script.
    export_folder = './data/predictions/'
    export_name = 'Predictions_of_'
    reg_algo = ["Interval-PLS"]

    # Placeholders that later sections test to know whether spectra and a
    # model have actually been loaded.
    pred_data = pd.DataFrame()
    loaded_model = None

    # ---- Load the uploaded spectra into `pred_data` -------------------------------
    if not file:
        M20.warning('Insert your spectral data file here!')
    else:
        # Split on the LAST dot so that names such as "run.2024.csv" keep their
        # real extension (cutting at the first dot yielded ".2024.csv", which
        # matches no supported format). Lower-cased so ".CSV" is accepted too.
        stem, extension = os.path.splitext(file.name)
        extension = extension.lower()
        export_name += stem

        if extension == files_format[0]:
            # NOTE(review): the detectors receive the path 'data/<name>', not the
            # uploaded buffer - confirm that a copy of the file is expected there.
            csv_path = 'data/' + file.name

            separators = [";", ","]
            detected_sep = str(find_delimiter(csv_path))
            qsep = M20.selectbox(
                "Select csv separator - _detected_: " + detected_sep,
                options=separators,
                # Pre-select the detected value; fall back to the first option
                # instead of raising ValueError when it is not one of the choices.
                index=separators.index(detected_sep) if detected_sep in separators else 0,
                key=2,
            )

            index_choices = ["no", "yes"]
            detected_idx = str(find_col_index(csv_path))
            qhdr = M20.selectbox(
                "indexes column in csv? - _detected_: " + detected_idx,
                options=index_choices,
                index=index_choices.index(detected_idx) if detected_idx in index_choices else 0,
                key=3,
            )

            # index_col=0 uses the first column as row labels; index_col=False
            # tells pandas not to use any column as the index.
            col = 0 if qhdr == 'yes' else False
            pred_data = pd.read_csv(file, sep=qsep, index_col=col)

        elif extension == files_format[1]:
            # read_dx is called with a file path, so the upload is spooled to a
            # temporary .dx file first.
            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                tmp.write(file.read())
                tmp_path = tmp.name

            # Parse only after the `with` block: the handle is closed by then, so
            # every buffered byte is on disk. `finally` removes the temp file even
            # when parsing fails.
            try:
                chem_data, spectra, meta_data, _ = read_dx(file=tmp_path)
            finally:
                os.unlink(tmp_path)

            # An empty string is not a valid emoji for `icon`; use an explicit one.
            M20.success("The data have been loaded successfully", icon="✅")

            if chem_data.to_numpy().shape[1] > 0:
                yname = M20.selectbox('Select target', options=chem_data.columns)
                # NOTE(review): rows whose target equals 0 are the ones kept for
                # prediction, although the mask is called `measured` - confirm intent.
                measured = chem_data.loc[:, yname] == 0
                y = chem_data.loc[:, yname].loc[measured]
                pred_data = spectra.loc[measured]
            else:
                # No chemical columns: predict on every spectrum.
                pred_data = spectra

    # ---- Section I: raw spectra overview ------------------------------------------
    st.header("I - Spectral data preprocessing & visualization", divider='blue')

    # Shown only once spectra have been loaded.
    # NOTE(review): `M1` is not assigned anywhere in the visible file - confirm it
    # is created (e.g. by an st.columns call) before this branch runs.
    if not pred_data.empty:
        M1.write('Raw spectra')
        # `yunits` receives the literal text below, not a value read from meta_data.
        fig = plot_spectra(
            pred_data,
            xunits='lab',
            yunits="meta_data.loc[:,'yunits'][0]",
        )
        M1.pyplot(fig)

    ### preprocessing
    # Start from an empty DataFrame *instance*. Binding the class itself
    # (`pd.DataFrame` without parentheses) makes `.empty` and `.shape` property
    # objects instead of values.
    preprocessed = pd.DataFrame()
    # Bound up front so the later `... and params` check can never hit an
    # undefined name when no spectra were loaded.
    params = None

    # NOTE(review): `M2` is not assigned anywhere in the visible file - confirm it
    # exists before this branch runs.
    if not pred_data.empty:
        params = M2.file_uploader("Load preprocessings params", type = '.json', help=" .json file")

        if params:
            # Expected keys (all read below): 'normalization', 'window_length',
            # 'polyorder', 'deriv'.
            prep = json.load(params)

            # Step 1 - optional normalization; `norm` is the label shown to the user later.
            if prep['normalization'] == 'Snv':
                x1 = Snv(pred_data)
                norm = 'Standard Normal Variate'
            else:
                norm = 'No Normalization was applied'
                x1 = pred_data

            # Step 2 - Savitzky-Golay filtering along the last axis.
            x2 = savgol_filter(
                x1,
                window_length=prep["window_length"],
                polyorder=prep["polyorder"],
                deriv=prep["deriv"],
                delta=1.0,
                axis=-1,
                mode="interp",
                cval=0.0,
            )
            # Re-attach the original row / column labels to the filtered array.
            preprocessed = pd.DataFrame(x2, index=pred_data.index, columns=pred_data.columns)

    ################################################################################################
    ## plot preprocessed spectra
    # NOTE(review): `M3` and `M4` are not assigned anywhere in the visible file -
    # confirm they exist before this branch runs.
    if not preprocessed.empty:
        M3.write('Preprocessed spectra')
        # `yunits` receives the literal text below, not a value read from meta_data.
        fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
        M3.pyplot(fig2)

        # Summary of the applied preprocessing.
        # - The f-strings are double-quoted so the single-quoted dict keys inside
        #   the braces parse on every Python 3 version (reusing the enclosing
        #   quote is only legal from 3.12 on).
        # - "\\:" yields the same backslash-colon text that the former invalid
        #   "\:" escape produced, without the invalid-escape warning.
        SG = (
            "- Savitzky-Golay derivative parameters \\:"
            f"(Window_length:{prep['window_length']};  polynomial order: {prep['polyorder']};  Derivative order : {prep['deriv']})"
        )
        Norm = f"- Spectral Normalization \\: {norm}"
        M4.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)

    ################### Predictions making  ##########################
    st.header("II - Prediction making", divider='blue')

    # Model selection is offered only once spectra and preprocessing params are loaded.
    if not pred_data.empty and params:
        M5, M6 = st.columns([2, 1])

        # Candidate model files; the leading '' entry stands for "nothing selected
        # yet" and is rendered as "<Select>". Sorted so the option order is stable.
        model_files = [''] + sorted(os.listdir('data/models/'))
        model_name = M6.selectbox(
            "Select your model from the dropdown list:",
            options=model_files,
            key=21,
            format_func=lambda x: x if x else "<Select>",
        )

        if model_name:
            # Strip only the final extension from the model file name.
            export_name += '_with_' + os.path.splitext(model_name)[0]

            # NOTE(security): joblib.load unpickles arbitrary objects - only
            # trusted files must be placed in data/models/.
            with open('data/models/' + model_name, 'rb') as f:
                loaded_model = joblib.load(f)
            # Number of predictors the model was fitted on; checked before predicting.
            ncols = loaded_model.n_features_in_

            # Identity test: some estimators define __len__, so plain truthiness
            # is not a reliable "was it loaded" check.
            if loaded_model is not None:
                # An empty string is not a valid emoji for `icon`; use an explicit one.
                M6.success("The model has been loaded successfully", icon="✅")
                s = M6.checkbox('the model is of ipls type?')

                if s:
                    index = M6.file_uploader("select wavelengths index file", type="csv")

                    if index:
                        # Each row holds an inclusive [start, end] pair of column positions.
                        intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
                        idx = []
                        for bounds in intervalls:
                            idx.extend(np.arange(bounds[0], bounds[1] + 1))

                        # Positions are 0-based, so the largest one must be strictly
                        # below the column count. An empty interval file is
                        # reported the same way instead of crashing max().
                        if idx and max(idx) < preprocessed.shape[1]:
                            preprocessed = preprocessed.iloc[:, idx]  ### get predictors
                        else:
                            M6.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")

    # ---- Run the prediction and present the results -------------------------------
    # Identity test: some estimators define __len__, so plain truthiness is not
    # a reliable "was it loaded" check.
    if loaded_model is not None:
        if M6.button('Predict', type='primary'):
            # The model can only be applied when the predictor count matches
            # what it was fitted on.
            if ncols == preprocessed.shape[1]:
                result = pd.DataFrame(loaded_model.predict(preprocessed), index = preprocessed.index)

                #############################
                if preprocessed.shape[1] > 1:
                    M5.write('Predicted values distribution')
                    # Creating histogram
                    fig, axs = plt.subplots(1, 1, figsize=(15, 3), tight_layout=True)

                    # Add x, y gridlines
                    axs.grid(color='grey', linestyle='-.', linewidth=0.5, alpha=0.6)
                    # Remove axes spines (dedicated loop name so the checkbox
                    # state stored in `s` is not overwritten)
                    for side in ['top', 'bottom', 'left', 'right']:
                        axs.spines[side].set_visible(False)
                    # Remove x, y ticks
                    axs.xaxis.set_ticks_position('none')
                    axs.yaxis.set_ticks_position('none')
                    # Add padding between axes and labels
                    axs.xaxis.set_tick_params(pad=5)
                    axs.yaxis.set_tick_params(pad=10)

                    # Draw the histogram, then colour each bar by its scaled count.
                    N, bins, patches = axs.hist(result, bins=12)
                    fracs = ((N**(1 / 5)) / N.max())
                    # Separate name so the normalization label kept in `norm`
                    # is not replaced by a matplotlib object.
                    color_norm = colors.Normalize(fracs.min(), fracs.max())
                    for thisfrac, thispatch in zip(fracs, patches):
                        thispatch.set_facecolor(plt.cm.viridis(color_norm(thisfrac)))

                    M5.pyplot(fig)
                    # Release the figure; otherwise every click on "Predict"
                    # leaves one more open figure registered in pyplot.
                    plt.close(fig)

                st.write('Predicted values table')
                st.dataframe(result.T)
                ##################################

                # result.to_csv(export_folder + export_name + '.csv', sep = ';')
                # export to local drive - Download
                # NOTE(review): with the to_csv call above disabled, nothing in this
                # file writes the path passed here - confirm download_results does
                # not expect that file to exist already.
                download_results(export_folder + export_name + '.csv', export_name + '.csv')
                # create a report with information on the prediction
                ## see https://stackoverflow.com/a/59578663
            else:
                M6.error(f'Error: The model was trained with {ncols} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match.')
