# 3-prediction.py — NIRS prediction page (Streamlit)
  • from Packages import *
    st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
    from Modules import *
    from Class_Mod.DATA_HANDLING import *
    
# HTML for the "CEFE - CNRS" banner (kept commented for reference;
# add_header() below renders the banner instead).

# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML
# st.markdown(bandeau_html, unsafe_allow_html=True)
add_header()
    
local_css(css_file / "style_model.css")

import shutil
# Running hash of everything that affects cached computations on this page.
# p_hash() folds each new item into it, so st.cache_data keys (which take
# `change=hash_`) are invalidated whenever any upstream input changes.
hash_ = ''
def p_hash(add):
    """Fold *add* into the global running hash and return the new value."""
    global hash_
    hash_ = hash_data(hash_+str(add))
    return hash_
    
    
    
    dirpath = Path('Report/out/model')
    if dirpath.exists() and dirpath.is_dir():
        shutil.rmtree(dirpath)
    
    
    
    # ####################################  Methods ##############################################
    # empty temp figures
    def delete_files(keep):
        supp = []
        # Walk through the directory
        for root, dirs, files in os.walk('Report/', topdown=False):
            for file in files:
                if file != 'logo_cefe.png' and not any(file.endswith(ext) for ext in keep):
                    os.remove(os.path.join(root, file))
    ###################################################################
    
# Page banner: title, a 2:1 column layout (M10 left / M20 right — M20 is
# reused below for the upload widgets), and an illustration image.
st.title("Prediction making using a previously developed model")
M10, M20= st.columns([2, 1])
M10.image("./images/prediction making.png", use_column_width=True)
    def preparespecdf(df):
        other = df.select_dtypes(exclude = 'float')
        rownames = other.iloc[:,0]
        spec = df.select_dtypes(include='float')
        spec.index = rownames
        return spec, other, rownames
    
    
    
    def check_exist(var):
        out = var in globals()
        return out
    
    
    
# Accepted spectra upload formats: csv matrix or JCAMP-DX.
files_format = ['.csv', '.dx']
    
# Destination folder and filename stem for exported prediction results;
# the uploaded file's base name is appended to export_name below.
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
# Regression algorithms recognized by this page.
reg_algo = ["Interval-PLS"]
    
    
    
    with M20:
        file = st.file_uploader("Load NIRS Data for prediction making:", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
        
        if not file:
            st.info('Info: Insert your spectral data file above!')
        else:
            p_hash(file.name)
            test = file.name[file.name.find('.'):]
            export_name += file.name[:file.name.find('.')]
    
            if test == files_format[0]:
                qsep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/'+file.name)), options=[";", ","],index=[";", ","].index(str(find_delimiter('data/'+file.name))), key=2, horizontal= True)
                qhdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/'+file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+file.name))), key=3, horizontal= True)
                col = 0 if qhdr == 'yes' else None
                p_hash([qsep,qhdr])
    
                df = pd.read_csv(file, sep=qsep, header= col)
                pred_data, cat, rownames =  preparespecdf(df)
    
            elif test == files_format[1]:
                with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                    tmp.write(file.read())
                    tmp_path = tmp.name
                    with open(tmp.name, 'r') as dd:
                        dxdata = file.read()
                        p_hash(str(dxdata)+str(file.name))
    
                    ## load and parse the temp dx file
                    @st.cache_data
                    def dx_loader(change):
                        chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
                        return chem_data, spectra, meta_data, _
                    chem_data, spectra, meta_data, _ = dx_loader(change = hash_)
                    st.success("The data have been loaded successfully", icon="")
                    if chem_data.to_numpy().shape[1]>0:
                        yname = st.selectbox('Select target', options=chem_data.columns)
                        measured = chem_data.loc[:,yname] == 0
                        y = chem_data.loc[:,yname].loc[measured]
                        pred_data = spectra.loc[measured]
                    
                    else:
                        pred_data = spectra
                os.unlink(tmp_path)
    
    
    # Load parameters
    
# Section I: visualize raw and preprocessed spectra before predicting.
st.header("I - Spectral data preprocessing & visualization", divider='blue')
    
    try:
        if check_exist("pred_data"):# Load the model with joblib
            @st.cache_data
            def specplot_raw(change):
                fig2 = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
                return fig2
            rawspectraplot = specplot_raw(change = hash_)
            M1, M2= st.columns([2, 1])
            with M1:
                st.write('Raw spectra')
                st.pyplot(rawspectraplot)
    
            with M2:
                params = st.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
                if params:
                    prep = json.load(params)
                    p_hash(prep)
                    
                    @st.cache_data
                    def preprocess_spectra(change):
                        # M4.write(ProcessLookupError)
                        
                        if prep['normalization'] == 'Snv':
                            x1 = Snv(pred_data)
                            norm = 'Standard Normal Variate'
    
    DIANE's avatar
    DIANE committed
                        else:
    
                            norm = 'No Normalization was applied'
                            x1 = pred_data
                        x2 = savgol_filter(x1,
                                            window_length = int(prep["window_length"]),
                                            polyorder = int(prep["polyorder"]),
                                            deriv = int(prep["deriv"]),
                                                delta=1.0, axis=-1, mode="interp", cval=0.0)
                        preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)
                        return norm, prep, preprocessed
                    norm, prep, preprocessed = preprocess_spectra(change= hash_)
    
        ################################################################################################
        ## plot preprocessed spectra
        if check_exist("preprocessed"):
            p_hash(preprocessed)
            M3, M4= st.columns([2, 1])
            with M3:
                st.write('Preprocessed spectra')
                def specplot_prep(change):
                    fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
                    return fig2
                prepspectraplot = specplot_prep(change = hash_)
                st.pyplot(prepspectraplot)
    
            with M4:
                @st.cache_data
                def prep_info(change):
                    SG = f'- Savitzky-Golay derivative parameters \:(Window_length:{prep['window_length']};  polynomial order: {prep['polyorder']};  Derivative order : {prep['deriv']})'
                    Norm = f'- Spectral Normalization \: {norm}'
                    return SG, Norm
                SG, Norm = prep_info(change = hash_)
                st.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)
    
        ################### Predictions making  ##########################
        st.header("II - Prediction making", divider='blue')
        if check_exist("pred_data") and params:# Load the model with joblib
            M5, M6 = st.columns([2, 1])
            model_file = M6.file_uploader("Load your model", type = '.pkl', help=" .pkl file")
            if model_file:
                with M6:
                    try:
                        model = joblib.load(model_file)
                        st.success("The model has been loaded successfully", icon="")
                        nvar = model.n_features_in_
    
                    except:
                        st.error("Error: Something went wrong, the model was not loaded !", icon="")
            
            with M6:
                s = st.checkbox('Check this box if your model is of ipls type!', disabled = False if 'model' in globals() else True)
                index = st.file_uploader("select wavelengths index file", type="csv", disabled = [False if s else True][0])
                if check_exist('preprocessed'):
                    if s:
                        if index:
                            intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
                            idx = []
                            for i in range(intervalls.shape[0]):
                                idx.extend(np.arange(intervalls[i,2], intervalls[i,3]+1))
                            if max(idx) <= preprocessed.shape[1]:
                                preprocesseddf = preprocessed.iloc[:,idx] ### get predictors    
                            else:
                                st.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")
                    else:
                        preprocesseddf = preprocessed
                    
    
                    
                if check_exist("model") == False:
                    disable = True
                elif check_exist("model") == True:
                    if s and not index :
                        disable = True
                    elif s and index:
                        disable  = False
                    elif not s and not index:
                        disable  = False
                    elif not s and index:
                        disable  = True
    
                    
                pred_button = M6.button('Predict', type='primary', disabled= disable)
    
                if check_exist("preprocesseddf"):
                    if pred_button and nvar == preprocesseddf.shape[1]:
                        try:
                            result = pd.DataFrame(model.predict(preprocesseddf), index = rownames, columns = ['Results'])
                        except:
                            st.error(f'''Error: Length mismatch: the number of samples indices is {len(rownames)}, while the model produced 
                                    {len(model.predict(preprocesseddf))} values. correct the "indexes column in csv?" parameter''')
                        with M5:
                            if preprocesseddf.shape[1]>1 and check_exist('result'):
                                st.write('Predicted values distribution')
                                # Creating histogram
                                hist, axs = plt.subplots(1, 1, figsize =(15, 3), 
                                                        tight_layout = True)
                                
                                # Add x, y gridlines 
                                axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6) 
                                # Remove axes splines 
                                for s in ['top', 'bottom', 'left', 'right']: 
                                    axs.spines[s].set_visible(False) 
                                # Remove x, y ticks
                                axs.xaxis.set_ticks_position('none') 
                                axs.yaxis.set_ticks_position('none') 
                                # Add padding between axes and labels 
                                axs.xaxis.set_tick_params(pad = 5) 
                                axs.yaxis.set_tick_params(pad = 10) 
                                # Creating histogram
                                N, bins, patches = axs.hist(result, bins = 12)
                                # Setting color
                                fracs = ((N**(1 / 5)) / N.max())
                                norm = colors.Normalize(fracs.min(), fracs.max())
                                
                                for thisfrac, thispatch in zip(fracs, patches):
                                    color = plt.cm.viridis(norm(thisfrac))
                                    thispatch.set_facecolor(color)
    
                                st.pyplot(hist)
                                st.write('Predicted values table')
                                st.dataframe(result.T)
                                #################################3
                    elif pred_button and nvar != preprocesseddf.shape[1]:
                        M6.error(f'Error: The model was trained on {nvar} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match!')
    
    
        if check_exist('result'):
            @st.cache_data(show_spinner =False)
            def preparing_results_for_downloading(change):
                match test:
                    # load csv file
                    case '.csv':
                        df.to_csv('Report/out/dataset/'+ file.name, sep = ';', encoding = 'utf-8', mode = 'a')
                    case '.dx':
                        with open('Report/out/dataset/'+file.name, 'w') as dd:
                            dd.write(dxdata)
    
                prepspectraplot.savefig('./Report/out/figures/raw_spectra.png')
                rawspectraplot.savefig('./Report/out/figures/preprocessed_spectra.png')
                hist.savefig('./Report/out/figures/histogram.png')
                result.round(4).to_csv('./Report/out/The analysis result.csv', sep = ";", index_col=0)
    
                return change
            preparing_results_for_downloading(change = hash_)
    
            import tempfile
            @st.cache_data(show_spinner =False)
            def tempdir(change):
                with  tempfile.TemporaryDirectory( prefix="results", dir="./Report") as temp_dir:# create a temp directory
                    tempdirname = os.path.split(temp_dir)[1]
                    if len(os.listdir('./Report/out/figures/'))==3:
                        shutil.make_archive(base_name="./Report/Results", format="zip", base_dir="out", root_dir = "./Report")# create a zip file
                        shutil.move("./Report/Results.zip", f"./Report/{tempdirname}/Results.zip")# put the inside the temp dir
                        with open(f"./Report/{tempdirname}/Results.zip", "rb") as f:
                            zip_data = f.read()
                return tempdirname, zip_data
    
            date_time = datetime.datetime.now().strftime('%y%m%d%H%M')
            try :
                tempdirname, zip_data = tempdir(change = hash_)
                st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_Pred_.zip', mime ="application/zip",
                            args = None, kwargs = None,type = "primary",use_container_width = True)
            except:
                st.write('rtt')
    except:
        M20.error('''Error: Data loading failed. Please check your file. Consider fine-tuning the dialect settings or ensure the file isn't corrupted.''')
    