From c1cf0b75eb603dfae4560df1ff4ae333678947bb Mon Sep 17 00:00:00 2001 From: Nicolas Barthes <nicolas.barthes@cnrs.fr> Date: Wed, 7 Aug 2024 11:57:26 +0200 Subject: [PATCH] reformated code in model creation --- src/pages/2-model_creation.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/pages/2-model_creation.py b/src/pages/2-model_creation.py index 2db808e..c2265ad 100644 --- a/src/pages/2-model_creation.py +++ b/src/pages/2-model_creation.py @@ -361,19 +361,20 @@ if Reg: with open("data/params/Preprocessing.json", "w") as outfile: json.dump(Reg.best_hyperparams_, outfile) - + yc = Reg.pred_data_[0] yt = Reg.pred_data_[1] # ########## + M1.write("-- Model performance --") if regression_algo != reg_algo[2]: M1.dataframe(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_) else: M1.dataframe(metrics(t = [y_test, yt], method='regression').scores_) - model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_) - M1.write("-- Model performance --") - M1.dataframe(model_per) + + model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_) + # M1.dataframe(model_per) # duplicate with line 371 @st.cache_data def prep_important(change,regression_algo): @@ -383,7 +384,7 @@ if Reg: ax2.plot(colnames, np.mean(Reg.pretreated_spectra_ , axis = 0), color = 'black', label = 'Average spectrum (Pretreated)') ax2.set_xlabel('Wavelenghts') plt.tight_layout() - + for i in range(2): eval(f'ax{i+1}').grid(color='grey', linestyle=':', linewidth=0.2) eval(f'ax{i+1}').margins(x = 0) @@ -396,9 +397,9 @@ if Reg: min, max = intervalls_with_cols['from'][j], intervalls_with_cols['to'][j] else: min, max = intervalls['from'][j], intervalls['to'][j] - + eval(f'ax{i+1}').axvspan(min, max, color='#00ff00', alpha=0.5, lw=0) - + if regression_algo == 'PLS': ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0)[np.array(Reg.sel_ratio_.index)], color = '#7ab0c7', label = 'Important variables') @@ -416,7 +417,7 @@ if Reg: M2.write('-- Visualization of the spectral regions used for model creation --') fig.savefig("./Report/figures/Variable_importance.png") - M2.pyplot(fig) + M2.pyplot(fig) @@ -425,7 +426,7 @@ if Reg: if Reg: # fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 6)) # fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02) - + st.header("Cross-Validation results") cv1, cv2 = st.columns([2,2]) ############ @@ -434,7 +435,7 @@ if Reg: cv_results=pd.DataFrame(Reg.CV_results_) cv2.write('-- Out-of-Fold Predictions Visualization (All in one) --') - fig1 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline='ols', color='Folds', symbol="Folds", + fig1 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline='ols', color='Folds', symbol="Folds", color_discrete_sequence=px.colors.qualitative.G10) fig1.add_shape(type='line', x0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), x1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), y0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), y1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), line = dict(color='black', dash = "dash")) -- GitLab