From fd8ae76ba115ed6a27382aa83b84e72f695b9688 Mon Sep 17 00:00:00 2001
From: DIANE <abderrahim.diane@cefe.cnrs.fr>
Date: Thu, 29 Aug 2024 16:16:39 +0200
Subject: [PATCH] Download report, app performance enhancements

---
 src/pages/2-model_creation.py |  10 +-
 src/pages/3-prediction.py     | 418 +++++++++++++++++++++-------------
 2 files changed, 270 insertions(+), 158 deletions(-)

diff --git a/src/pages/2-model_creation.py b/src/pages/2-model_creation.py
index a4a415f..4aefb6e 100644
--- a/src/pages/2-model_creation.py
+++ b/src/pages/2-model_creation.py
@@ -516,7 +516,7 @@ if Reg:
     numbers_dict = {1: "One", 2: "Two",3: "Three",4: "Four",5: "Five",
                     6: "Six",7: "Seven",8: "Eight",9: "Nine",10: "Ten"}
     st.header(f" {numbers_dict[nb_folds]}-Fold Cross-Validation results")
-
+    
     @st.cache_data(show_spinner =False)
     def cv_display(change):
         fig1 = px.scatter(Reg.cv_data_[0], x = 'Measured', y = 'Predicted' , trendline = 'ols', color = 'Folds', symbol = 'Folds',
@@ -530,7 +530,7 @@ if Reg:
         fig0.update_traces(marker_size = 8, showlegend = False)
         return fig0, fig1
     fig0, fig1 = cv_display(change= Reg.cv_data_)
-
+    
     cv1, cv2 = st.columns([2, 2])
     with cv2:
         cv_results = pd.DataFrame(Reg.CV_results_).round(4)# CV table
@@ -545,12 +545,11 @@ if Reg:
         st.plotly_chart(fig0, use_container_width=True)
     
 
-
     ###################################################    BEGIN : Model Diagnosis    ####################################################
 st.header("III - Model Diagnosis", divider='blue')
 if Reg:
     # signal preprocessing results preparation for latex report
-    prep_para = Reg.best_hyperparams_
+    prep_para = Reg.best_hyperparams_.copy()
     if model_type != reg_algo[2]:
         prep_para.pop('n_components')
         for i in ['deriv','polyorder']:
@@ -560,7 +559,7 @@ if Reg:
                 prep_para[i] = '1st'
             elif Reg.best_hyperparams_[i] > 1:
                 prep_para[i] = f"{Reg.best_hyperparams_[i]}nd"
-
+    
     # reg plot and residuals plot
     if model_type != reg_algo[2]:
         measured_vs_predicted = reg_plot([y_train, y_test],[yc, yt], train_idx = train_index, test_idx = test_index)
@@ -581,7 +580,6 @@ if Reg:
         # residual_plot.savefig('./Report/figures/residuals_plot.png')
 
 ###################################################      END : Model Diagnosis   #######################################################
-
     
 ###################################################    BEGIN : Download results    #######################################################
 ##########################################################################################################################################
diff --git a/src/pages/3-prediction.py b/src/pages/3-prediction.py
index f73aa24..bed4a0e 100644
--- a/src/pages/3-prediction.py
+++ b/src/pages/3-prediction.py
@@ -14,174 +14,288 @@ add_header()
 add_sidebar(pages_folder)
 
 local_css(css_file / "style_model.css")
+import shutil
+hash_ = ''
+def p_hash(add):
+    global hash_
+    hash_ = hash_data(hash_+str(add))
+    return hash_
 
-st.title("Prediction making using a previously developed model")
-M10, M20= st.columns([2, 1])
-M10.image("./images/prediction making.png", use_column_width=True)
+dirpath = Path('Report/out/model')
+if dirpath.exists() and dirpath.is_dir():
+    shutil.rmtree(dirpath)
 
-# M1, M2= st.columns([2, 1])
+# ####################################  Methods ##############################################
+# empty temp figures
+def delete_files(keep):
+    supp = []
+    # Walk through the directory
+    for root, dirs, files in os.walk('Report/', topdown=False):
+        for file in files:
+            if file != 'logo_cefe.png' and not any(file.endswith(ext) for ext in keep):
+                os.remove(os.path.join(root, file))
+###################################################################
 
 
 
-# st.header("Prediction making", divider='blue')
-# M5, M6 = st.columns([2, 0.01])
+st.title("Prediction making using a previously developed model")
+M10, M20= st.columns([2, 1])
+M10.image("./images/prediction making.png", use_column_width=True)
+def preparespecdf(df):
+    other = df.select_dtypes(exclude = 'float')
+    rownames = other.iloc[:,0]
+    spec = df.select_dtypes(include='float')
+    spec.index = rownames
+    return spec, other, rownames
 
+def check_exist(var):
+    out = var in globals()
+    return out
 
 files_format = ['.csv', '.dx']
-file = M20.file_uploader("Select NIRS Data to predict", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
 export_folder = './data/predictions/'
 export_name = 'Predictions_of_'
 reg_algo = ["Interval-PLS"]
-pred_data = pd.DataFrame()
-loaded_model = None
-
-if not file:
-    M20.warning('Insert your spectral data file here!')
-else:
-    test = file.name[file.name.find('.'):]
-    export_name += file.name[:file.name.find('.')]
-
-    if test == files_format[0]:
-        #
-        qsep = M20.selectbox("Select csv separator - _detected_: " + str(find_delimiter('data/'+file.name)), options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+file.name))), key=2)
-        qhdr = M20.selectbox("indexes column in csv? - _detected_: " + str(find_col_index('data/'+file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+file.name))), key=3)
-        if qhdr == 'yes':
-            col = 0
-        else:
-            col = False
-        pred_data = pd.read_csv(file, sep=qsep, index_col=col)
-
-    elif test == files_format[1]:
-        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
-            tmp.write(file.read())
-            tmp_path = tmp.name
-            chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
-            M20.success("The data have been loaded successfully", icon="✅")
-            if chem_data.to_numpy().shape[1]>0:
-                yname = M20.selectbox('Select target', options=chem_data.columns)
-                measured = chem_data.loc[:,yname] == 0
-                y = chem_data.loc[:,yname].loc[measured]
-                pred_data = spectra.loc[measured]
-            
-            else:
-                pred_data = spectra
-        os.unlink(tmp_path)
+
+
+with M20:
+    file = st.file_uploader("Load NIRS Data for prediction making:", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
+    
+    if not file:
+        st.info('Info: Insert your spectral data file above!')
+    else:
+        p_hash(file.name)
+        test = file.name[file.name.find('.'):]
+        export_name += file.name[:file.name.find('.')]
+
+        if test == files_format[0]:
+            qsep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/'+file.name)), options=[";", ","],index=[";", ","].index(str(find_delimiter('data/'+file.name))), key=2, horizontal= True)
+            qhdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/'+file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+file.name))), key=3, horizontal= True)
+            col = 0 if qhdr == 'yes' else None
+            p_hash([qsep,qhdr])
+
+            df = pd.read_csv(file, sep=qsep, header= col)
+            pred_data, cat, rownames =  preparespecdf(df)
+
+        elif test == files_format[1]:
+            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
+                tmp.write(file.read())
+                tmp_path = tmp.name
+                with open(tmp.name, 'r') as dd:
+                    dxdata = file.read()
+                    p_hash(str(dxdata)+str(file.name))
+
+                ## load and parse the temp dx file
+                @st.cache_data
+                def dx_loader(change):
+                    chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
+                    return chem_data, spectra, meta_data, _
+                chem_data, spectra, meta_data, _ = dx_loader(change = hash_)
+                st.success("The data have been loaded successfully", icon="✅")
+                if chem_data.to_numpy().shape[1]>0:
+                    yname = st.selectbox('Select target', options=chem_data.columns)
+                    measured = chem_data.loc[:,yname] == 0
+                    y = chem_data.loc[:,yname].loc[measured]
+                    pred_data = spectra.loc[measured]
+                
+                else:
+                    pred_data = spectra
+            os.unlink(tmp_path)
 
 
 # Load parameters
 st.header("I - Spectral data preprocessing & visualization", divider='blue')
-if not pred_data.empty:# Load the model with joblib
-    M1, M2= st.columns([2, 1])
-    M1.write('Raw spectra')
-    fig = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
-    M1.pyplot(fig)
-
-### preprocessing
-preprocessed = pd.DataFrame
-if not pred_data.empty:
-    params = M2.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
-    if params:
-        prep = json.load(params)
-        # M4.write(ProcessLookupError)
-
-        if prep['normalization'] == 'Snv':
-            x1 = Snv(pred_data)
-            norm = 'Standard Normal Variate'
-        else:
-            norm = 'No Normalization was applied'
-            x1 = pred_data
-        x2 = savgol_filter(x1,
-                            window_length = prep["window_length"],
-                            polyorder = prep["polyorder"],
-                            deriv=prep["deriv"],
-                                delta=1.0, axis=-1, mode="interp", cval=0.0)
-        preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)
-
-################################################################################################
-## plot preprocessed spectra
-if not preprocessed.empty:
-    M3, M4= st.columns([2, 1])
-    M3.write('Preprocessed spectra')
-    fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
-    M3.pyplot(fig2)
-    SG = f'- Savitzky-Golay derivative parameters \:(Window_length:{prep['window_length']};  polynomial order: {prep['polyorder']};  Derivative order : {prep['deriv']})'
-    Norm = f'- Spectral Normalization \: {norm}'
-    M4.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)
-
-################### Predictions making  ##########################
-st.header("II - Prediction making", divider='blue')
-if not pred_data.empty and params:# Load the model with joblib
-    M5, M6 = st.columns([2, 1])
-    #dir = os.listdir('data/models/')[1:]
-    dir = os.listdir('data/models/')
-    dir.insert(0,'')
-    model_name = M6.selectbox("Select your model from the dropdown list:", options = dir, key = 21, format_func=lambda x: x if x else "<Select>")
-
-    if model_name:
-        export_name += '_with_' + model_name[:model_name.find('.')]
-        with open('data/models/'+ model_name,'rb') as f:
-            loaded_model = joblib.load(f)
-            ncols = loaded_model.n_features_in_
-            
-        if loaded_model:
-            M6.success("The model has been loaded successfully", icon="✅")
-            s = M6.checkbox('the model is of ipls type?')
-            if s:
-                index = M6.file_uploader("select wavelengths index file", type="csv")
-                if index:
-                    intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
-                    idx = []
-                    for i in range(intervalls.shape[0]):
-                        idx.extend(np.arange(intervalls[i,0], intervalls[i,1]+1))
-                    if max(idx) <= preprocessed.shape[1]:
-                        preprocessed = preprocessed.iloc[:,idx] ### get predictors
+try:
+    if check_exist("pred_data"):# Load the model with joblib
+        @st.cache_data
+        def specplot_raw(change):
+            fig2 = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
+            return fig2
+        rawspectraplot = specplot_raw(change = hash_)
+        M1, M2= st.columns([2, 1])
+        with M1:
+            st.write('Raw spectra')
+            st.pyplot(rawspectraplot)
+
+        with M2:
+            params = st.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
+            if params:
+                prep = json.load(params)
+                p_hash(prep)
+                
+                @st.cache_data
+                def preprocess_spectra(change):
+                    # M4.write(ProcessLookupError)
+                    
+                    if prep['normalization'] == 'Snv':
+                        x1 = Snv(pred_data)
+                        norm = 'Standard Normal Variate'
                     else:
-                        M6.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")
+                        norm = 'No Normalization was applied'
+                        x1 = pred_data
+                    x2 = savgol_filter(x1,
+                                        window_length = int(prep["window_length"]),
+                                        polyorder = int(prep["polyorder"]),
+                                        deriv = int(prep["deriv"]),
+                                            delta=1.0, axis=-1, mode="interp", cval=0.0)
+                    preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)
+                    return norm, prep, preprocessed
+                norm, prep, preprocessed = preprocess_spectra(change= hash_)
+
+    ################################################################################################
+    ## plot preprocessed spectra
+    if check_exist("preprocessed"):
+        p_hash(preprocessed)
+        M3, M4= st.columns([2, 1])
+        with M3:
+            st.write('Preprocessed spectra')
+            def specplot_prep(change):
+                fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
+                return fig2
+            prepspectraplot = specplot_prep(change = hash_)
+            st.pyplot(prepspectraplot)
+
+        with M4:
+            @st.cache_data
+            def prep_info(change):
+                SG = f'- Savitzky-Golay derivative parameters \:(Window_length:{prep['window_length']};  polynomial order: {prep['polyorder']};  Derivative order : {prep['deriv']})'
+                Norm = f'- Spectral Normalization \: {norm}'
+                return SG, Norm
+            SG, Norm = prep_info(change = hash_)
+            st.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)
+
+    ################### Predictions making  ##########################
+    st.header("II - Prediction making", divider='blue')
+    if check_exist("pred_data") and params:# Load the model with joblib
+        M5, M6 = st.columns([2, 1])
+        model_file = M6.file_uploader("Load your model", type = '.pkl', help=" .pkl file")
+        if model_file:
+            with M6:
+                try:
+                    model = joblib.load(model_file)
+                    st.success("The model has been loaded successfully", icon="✅")
+                    nvar = model.n_features_in_
+
+                except:
+                    st.error("Error: Something went wrong, the model was not loaded !", icon="❌")
+        
+        with M6:
+            s = st.checkbox('Check this box if your model is of ipls type!', disabled = False if 'model' in globals() else True)
+            index = st.file_uploader("select wavelengths index file", type="csv", disabled = [False if s else True][0])
+            if check_exist('preprocessed'):
+                if s:
+                    if index:
+                        intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
+                        idx = []
+                        for i in range(intervalls.shape[0]):
+                            idx.extend(np.arange(intervalls[i,2], intervalls[i,3]+1))
+                        if max(idx) <= preprocessed.shape[1]:
+                            preprocesseddf = preprocessed.iloc[:,idx] ### get predictors    
+                        else:
+                            st.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")
+                else:
+                    preprocesseddf = preprocessed
+                
+
+                
+            if check_exist("model") == False:
+                disable = True
+            elif check_exist("model") == True:
+                if s and not index :
+                    disable = True
+                elif s and index:
+                    disable  = False
+                elif not s and not index:
+                    disable  = False
+                elif not s and index:
+                    disable  = True
+
+                
+            pred_button = M6.button('Predict', type='primary', disabled= disable)
+
+            if check_exist("preprocesseddf"):
+                if pred_button and nvar == preprocesseddf.shape[1]:
+                    try:
+                        result = pd.DataFrame(model.predict(preprocesseddf), index = rownames, columns = ['Results'])
+                    except:
+                        st.error(f'''Error: Length mismatch: the number of samples indices is {len(rownames)}, while the model produced 
+                                {len(model.predict(preprocesseddf))} values. correct the "indexes column in csv?" parameter''')
+                    with M5:
+                        if preprocesseddf.shape[1]>1 and check_exist('result'):
+                            st.write('Predicted values distribution')
+                            # Creating histogram
+                            hist, axs = plt.subplots(1, 1, figsize =(15, 3), 
+                                                    tight_layout = True)
+                            
+                            # Add x, y gridlines 
+                            axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6) 
+                            # Remove axes splines 
+                            for s in ['top', 'bottom', 'left', 'right']: 
+                                axs.spines[s].set_visible(False) 
+                            # Remove x, y ticks
+                            axs.xaxis.set_ticks_position('none') 
+                            axs.yaxis.set_ticks_position('none') 
+                            # Add padding between axes and labels 
+                            axs.xaxis.set_tick_params(pad = 5) 
+                            axs.yaxis.set_tick_params(pad = 10) 
+                            # Creating histogram
+                            N, bins, patches = axs.hist(result, bins = 12)
+                            # Setting color
+                            fracs = ((N**(1 / 5)) / N.max())
+                            norm = colors.Normalize(fracs.min(), fracs.max())
+                            
+                            for thisfrac, thispatch in zip(fracs, patches):
+                                color = plt.cm.viridis(norm(thisfrac))
+                                thispatch.set_facecolor(color)
+
+                            st.pyplot(hist)
+                            st.write('Predicted values table')
+                            st.dataframe(result.T)
+                            #################################3
+                elif pred_button and nvar != preprocesseddf.shape[1]:
+                    M6.error(f'Error: The model was trained on {nvar} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match!')
+
+
+    if check_exist('result'):
+        @st.cache_data(show_spinner =False)
+        def preparing_results_for_downloading(change):
+            match test:
+                # load csv file
+                case '.csv':
+                    df.to_csv('Report/out/dataset/'+ file.name, sep = ';', encoding = 'utf-8', mode = 'a')
+                case '.dx':
+                    with open('Report/out/dataset/'+file.name, 'w') as dd:
+                        dd.write(dxdata)
+
+            prepspectraplot.savefig('./Report/out/figures/raw_spectra.png')
+            rawspectraplot.savefig('./Report/out/figures/preprocessed_spectra.png')
+            hist.savefig('./Report/out/figures/histogram.png')
+            result.round(4).to_csv('./Report/out/The analysis result.csv', sep = ";", index_col=0)
+
+            return change
+        preparing_results_for_downloading(change = hash_)
+
+        import tempfile
+        @st.cache_data(show_spinner =False)
+        def tempdir(change):
+            with  tempfile.TemporaryDirectory( prefix="results", dir="./Report") as temp_dir:# create a temp directory
+                tempdirname = os.path.split(temp_dir)[1]
+                if len(os.listdir('./Report/out/figures/'))==3:
+                    shutil.make_archive(base_name="./Report/Results", format="zip", base_dir="out", root_dir = "./Report")# create a zip file
+                    shutil.move("./Report/Results.zip", f"./Report/{tempdirname}/Results.zip")# put the inside the temp dir
+                    with open(f"./Report/{tempdirname}/Results.zip", "rb") as f:
+                        zip_data = f.read()
+            return tempdirname, zip_data
+
+        date_time = datetime.datetime.now().strftime('%y%m%d%H%M')
+        try :
+            tempdirname, zip_data = tempdir(change = hash_)
+            st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_Pred_.zip', mime ="application/zip",
+                        args = None, kwargs = None,type = "primary",use_container_width = True)
+        except:
+            st.write('rtt')
+except:
+    M20.error('''Error: Data loading failed. Please check your file. Consider fine-tuning the dialect settings or ensure the file isn't corrupted.''')
 
 
-if loaded_model:
-    if M6.button('Predict', type='primary'):
-            if ncols == preprocessed.shape[1]:
-                result = pd.DataFrame(loaded_model.predict(preprocessed), index = preprocessed.index)
 
-                #############################
-                if preprocessed.shape[1]>1:
-                    M5.write('Predicted values distribution')
-                    # Creating histogram
-                    fig, axs = plt.subplots(1, 1, figsize =(15, 3), 
-                                            tight_layout = True)
-                    
-                    # Add x, y gridlines 
-                    axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6) 
-                    # Remove axes splines 
-                    for s in ['top', 'bottom', 'left', 'right']: 
-                        axs.spines[s].set_visible(False) 
-                    # Remove x, y ticks
-                    axs.xaxis.set_ticks_position('none') 
-                    axs.yaxis.set_ticks_position('none') 
-                    # Add padding between axes and labels 
-                    axs.xaxis.set_tick_params(pad = 5) 
-                    axs.yaxis.set_tick_params(pad = 10) 
-                    # Creating histogram
-                    N, bins, patches = axs.hist(result, bins = 12)
-                    # Setting color
-                    fracs = ((N**(1 / 5)) / N.max())
-                    norm = colors.Normalize(fracs.min(), fracs.max())
-                    
-                    for thisfrac, thispatch in zip(fracs, patches):
-                        color = plt.cm.viridis(norm(thisfrac))
-                        thispatch.set_facecolor(color)
-
-                    M5.pyplot(fig)
-                st.write('Predicted values table')
-                st.dataframe(result.T)
-                ##################################
-
-                # result.to_csv(export_folder + export_name + '.csv', sep = ';')
-                # export to local drive - Download
-                download_results(export_folder + export_name + '.csv', export_name + '.csv')
-                # create a report with information on the prediction
-                ## see https://stackoverflow.com/a/59578663
-            else:
-                M6.error(f'Error: The model was trained with {ncols} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match.')
-            
-- 
GitLab