Skip to content
Snippets Groups Projects
3-prediction.py 6.91 KiB
Newer Older
  • Learn to ignore specific revisions
  • from Packages import *
    st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
    from Modules import *
    from Class_Mod.DATA_HANDLING import *

    # Draw the shared page header (add_header is provided by the star imports).
    # A hand-written "CEFE - CNRS" HTML banner used to be injected here through
    # st.markdown; it had been commented out and was dropped as dead code.
    add_header()

    # Re-store the 'interface' session value, defaulting it to None when the
    # key has never been set.
    st.session_state["interface"] = st.session_state.get('interface')

    # Page-specific stylesheet.
    local_css(css_file / "style_model.css")
    
    DIANE's avatar
    DIANE committed
    
    
    
    DIANE's avatar
    DIANE committed
    # Page layout: two titled sections, each split into a wide column
    # (plots / model selection), a spacer column and a narrow column (inputs).
    st.header("Data loading", divider='blue')
    model_column1, space1, file_column1= st.columns([2, 1, 1])
    st.header("Prediction making", divider='blue')
    model_column2, space2, file_column2= st.columns([2, 1, 1])
    # Centred column, used further down to display the prediction histogram.
    _, space3, _ = st.columns([1, 3, 1])

    # Accepted spectra file extensions: a csv matrix or a '.dx' file (parsed
    # with read_dx further down).
    files_format = ['.csv', '.dx']

    file = file_column1.file_uploader("Select NIRS Data to predict", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")

    # The prediction csv is written to export_folder; the uploaded file's stem
    # (and later the model's stem) is appended to export_name.
    export_folder = './data/predictions/'
    export_name = 'Predictions_of_'
    reg_algo = ["Interval-PLS"]

    # Placeholders that stay empty until spectra / a model are loaded.
    # An *instance* is needed here: the bare class `pd.DataFrame` exposes
    # `.empty` as a (truthy) property object rather than a boolean.
    pred_data = pd.DataFrame()
    loaded_model = None
    
    
    DIANE's avatar
    DIANE committed
    
    
    DIANE's avatar
    DIANE committed
    if file:
        # Split on the LAST dot so names such as "batch.01.csv" are still
        # recognised; compare the extension case-insensitively.
        stem, test = os.path.splitext(file.name)
        test = test.lower()
        export_name += stem

        if test == files_format[0]:
            # NOTE(review): separator / index detection reads 'data/<file name>'
            # from disk rather than the uploaded buffer, so it presumably only
            # works when a copy of the upload already exists there — confirm.
            csv_path = 'data/' + file.name
            detected_sep = str(find_delimiter(csv_path))
            detected_hdr = str(find_col_index(csv_path))
            qsep = file_column1.selectbox("Select csv separator - _detected_: " + detected_sep, options=[";", ","], index=[";", ","].index(detected_sep), key=2)
            qhdr = file_column1.selectbox("indexes column in csv? - _detected_: " + detected_hdr, options=["no", "yes"], index=["no", "yes"].index(detected_hdr), key=3)

            # index_col=0: first column holds the row labels; False: no index column.
            col = 0 if qhdr == 'yes' else False
            pred_data = pd.read_csv(file, sep=qsep, index_col=col)

        elif test == files_format[1]:
            # read_dx takes a path, so copy the upload into a temporary file.
            # The file is closed (hence flushed) before read_dx opens it, and it
            # is always removed afterwards, even when parsing fails.
            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                tmp.write(file.read())
                tmp_path = tmp.name
            try:
                chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
            finally:
                os.unlink(tmp_path)

            file_column1.success("The data have been loaded successfully", icon="")
            if chem_data.to_numpy().shape[1]>0:
                yname = file_column1.selectbox('Select target', options=chem_data.columns)
                # NOTE(review): this mask keeps the rows whose selected target
                # equals 0 (presumably samples without a reference value), despite
                # the name `measured` — confirm the intent.
                measured = chem_data.loc[:,yname] == 0
                y = chem_data.loc[:,yname].loc[measured]
                pred_data = spectra.loc[measured]
            else:
                pred_data = spectra
    
    
    DIANE's avatar
    DIANE committed
    
    # Load parameters
    
    DIANE's avatar
    DIANE committed
    if not pred_data.empty:
        # Show the spectra exactly as loaded.
        model_column1.write('Raw spectra')
        # NOTE(review): `yunits` receives the literal text
        # "meta_data.loc[:,'yunits'][0]", not its value. meta_data only exists
        # for '.dx' uploads, so a csv fallback is needed before evaluating it —
        # confirm what plot_spectra expects.
        fig = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
        model_column1.pyplot(fig)

    ### preprocessing
    # Stays an empty frame (an instance, so `.empty` is a real boolean) until a
    # preprocessing parameter file has been applied.
    preprocessed = pd.DataFrame()
    if not pred_data.empty:
        params = file_column1.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
        if params:
            prep = json.load(params)

            # Optional scatter correction, applied only when the file asks for 'SNV'.
            if prep['Scatter'] == 'SNV':
                x1 = Snv(pred_data)
            else:
                x1 = pred_data
            # Savitzky-Golay filter along the last axis (the columns, i.e. along
            # each sample's spectrum). The "Saitzky" spelling is the key used in
            # the parameter file and must be kept.
            sg = prep["Saitzky-Golay derivative parameters"]
            x2 = savgol_filter(x1,
                               window_length = sg["window_length"],
                               polyorder = sg["polyorder"],
                               deriv=sg["deriv"],
                               delta=1.0, axis=-1, mode="interp", cval=0.0)
            preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)

    ## plot preprocessed spectra
    if not preprocessed.empty:
        model_column1.write('Preprocessed spectra')
        # NOTE(review): same literal-string `yunits` label as the raw plot above.
        fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
        model_column1.pyplot(fig2)
    
    
    ################### Predictions making  ##########################
    if not pred_data.empty:
        # Offer every file in the models folder (sorted for a stable order),
        # preceded by an empty entry meaning "no model selected".
        models_dir = 'data/models/'
        model_files = [''] + sorted(os.listdir(models_dir))
        model_name = model_column2.selectbox("Select your model from the dropdown list:", options = model_files, key = 21)

        if model_name:
            # splitext keeps the whole stem even when the name has no extension.
            export_name += '_with_' + os.path.splitext(model_name)[0]
            # SECURITY: joblib.load unpickles the file and can run arbitrary code;
            # only trusted model files should be placed in data/models/.
            with open(models_dir + model_name,'rb') as f:
                loaded_model = joblib.load(f)

            if loaded_model is not None:
                model_column2.success("The model has been loaded successfully", icon="")
                s = model_column2.checkbox('the model is of ipls type?')

                if s:
                    # The first non-index column of this ';'-separated csv is read
                    # as the positional indices of the spectral columns the model uses.
                    index = model_column2.file_uploader("select wavelengths index file", type="csv")
                    if index:
                        idx = pd.read_csv(index, sep=';', index_col=0).iloc[:,0].to_numpy()
    
    
    if loaded_model is not None:

        if model_column2.button('Predict'):
                # Prediction needs the preprocessed spectra (x2 / preprocessed only
                # exist once a parameter file was loaded) and, for iPLS models, the
                # wavelength index file. Stop with a message instead of a NameError.
                if not params:
                    model_column2.error("Load the preprocessing parameters (.json) before predicting.")
                    st.stop()
                if s and not index:
                    model_column2.error("Load the wavelengths index file before predicting with an iPLS model.")
                    st.stop()

                if s:
                    # iPLS: restrict the preprocessed spectra to the selected columns.
                    result = loaded_model.predict(preprocessed.iloc[:,idx])
                else:
                    result = loaded_model.predict(x2)
                result = pd.DataFrame(result, index = pred_data.index)

                st.write('Predicted values')
                st.dataframe(result.T)

                #############################
                # Histogram of the predicted values, bars coloured along viridis.
                fig, axs = plt.subplots(1, 1,
                                        figsize =(12, 6),
                                        tight_layout = True)
                axs.grid(color ='grey',
                         linestyle ='-.', linewidth = 0.5,
                         alpha = 0.6)
                axs.set_title('Predicted values distribution')
                # Frameless look: hide spines and ticks, pad the tick labels.
                # (`spine` rather than `s`, which holds the iPLS checkbox flag.)
                for spine in ['top', 'bottom', 'left', 'right']:
                    axs.spines[spine].set_visible(False)
                axs.xaxis.set_ticks_position('none')
                axs.yaxis.set_ticks_position('none')
                axs.xaxis.set_tick_params(pad = 5)
                axs.yaxis.set_tick_params(pad = 10)

                N, bins, patches = axs.hist(result, bins = 12)
                # Colour each bar by a compressed (5th-root) scale of its count.
                fracs = ((N**(1 / 5)) / N.max())
                norm = colors.Normalize(fracs.min(), fracs.max())
                for thisfrac, thispatch in zip(fracs, patches):
                    color = plt.cm.viridis(norm(thisfrac))
                    thispatch.set_facecolor(color)

                space3.pyplot(fig)
                ##################################

                # Save the predictions (creating the export folder if needed), then
                # offer them for download.
                os.makedirs(export_folder, exist_ok=True)
                result.to_csv(export_folder + export_name + '.csv', sep = ';')
                download_results(export_folder + export_name + '.csv', export_name + '.csv')
                # create a report with information on the prediction
    
                ## see https://stackoverflow.com/a/59578663