From c1cf0b75eb603dfae4560df1ff4ae333678947bb Mon Sep 17 00:00:00 2001
From: Nicolas Barthes <nicolas.barthes@cnrs.fr>
Date: Wed, 7 Aug 2024 11:57:26 +0200
Subject: [PATCH] reformated code in model creation

---
 src/pages/2-model_creation.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/pages/2-model_creation.py b/src/pages/2-model_creation.py
index 2db808e..c2265ad 100644
--- a/src/pages/2-model_creation.py
+++ b/src/pages/2-model_creation.py
@@ -361,19 +361,20 @@ if Reg:
     with open("data/params/Preprocessing.json", "w") as outfile:
         json.dump(Reg.best_hyperparams_, outfile)
 
-  
+
 
     yc = Reg.pred_data_[0]
     yt = Reg.pred_data_[1]
     # ##########
+    M1.write("-- Model performance --")
     if regression_algo != reg_algo[2]:
         M1.dataframe(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
     else:
         M1.dataframe(metrics(t = [y_test, yt], method='regression').scores_)
-    model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
 
-    M1.write("-- Model performance --")
-    M1.dataframe(model_per)
+
+    model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
+    # M1.dataframe(model_per) # duplicate with line 371
 
     @st.cache_data
     def prep_important(change,regression_algo):
@@ -383,7 +384,7 @@ if Reg:
         ax2.plot(colnames, np.mean(Reg.pretreated_spectra_ , axis = 0), color = 'black', label = 'Average spectrum (Pretreated)')
         ax2.set_xlabel('Wavelenghts')
         plt.tight_layout()
-            
+
         for i in range(2):
             eval(f'ax{i+1}').grid(color='grey', linestyle=':', linewidth=0.2)
             eval(f'ax{i+1}').margins(x = 0)
@@ -396,9 +397,9 @@ if Reg:
                         min, max = intervalls_with_cols['from'][j], intervalls_with_cols['to'][j]
                     else:
                         min, max = intervalls['from'][j], intervalls['to'][j]
-                    
+
                     eval(f'ax{i+1}').axvspan(min, max, color='#00ff00', alpha=0.5, lw=0)
-        
+
         if regression_algo == 'PLS':
             ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0)[np.array(Reg.sel_ratio_.index)],
                             color = '#7ab0c7', label = 'Important variables')
@@ -416,7 +417,7 @@ if Reg:
 
     M2.write('-- Visualization of the spectral regions used for model creation --')
     fig.savefig("./Report/figures/Variable_importance.png")
-    M2.pyplot(fig) 
+    M2.pyplot(fig)
 
 
 
@@ -425,7 +426,7 @@ if Reg:
     if Reg:
         # fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 6))
         # fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02)
-        
+
         st.header("Cross-Validation results")
         cv1, cv2 = st.columns([2,2])
         ############
@@ -434,7 +435,7 @@ if Reg:
         cv_results=pd.DataFrame(Reg.CV_results_)
         cv2.write('-- Out-of-Fold Predictions Visualization (All in one) --')
 
-        fig1 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline='ols', color='Folds', symbol="Folds", 
+        fig1 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline='ols', color='Folds', symbol="Folds",
                  color_discrete_sequence=px.colors.qualitative.G10)
         fig1.add_shape(type='line', x0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), x1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']),
                         y0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), y1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), line = dict(color='black', dash = "dash"))
-- 
GitLab