Skip to content
Snippets Groups Projects
Commit c1cf0b75 authored by Nicolas Barthes's avatar Nicolas Barthes
Browse files

reformated code in model creation

parent 3d2259c0
No related branches found
No related tags found
No related merge requests found
...@@ -361,19 +361,20 @@ if Reg: ...@@ -361,19 +361,20 @@ if Reg:
with open("data/params/Preprocessing.json", "w") as outfile: with open("data/params/Preprocessing.json", "w") as outfile:
json.dump(Reg.best_hyperparams_, outfile) json.dump(Reg.best_hyperparams_, outfile)
yc = Reg.pred_data_[0] yc = Reg.pred_data_[0]
yt = Reg.pred_data_[1] yt = Reg.pred_data_[1]
# ########## # ##########
M1.write("-- Model performance --")
if regression_algo != reg_algo[2]: if regression_algo != reg_algo[2]:
M1.dataframe(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_) M1.dataframe(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
else: else:
M1.dataframe(metrics(t = [y_test, yt], method='regression').scores_) M1.dataframe(metrics(t = [y_test, yt], method='regression').scores_)
model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
M1.write("-- Model performance --")
M1.dataframe(model_per) model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
# M1.dataframe(model_per) # duplicate with line 371
@st.cache_data @st.cache_data
def prep_important(change,regression_algo): def prep_important(change,regression_algo):
...@@ -383,7 +384,7 @@ if Reg: ...@@ -383,7 +384,7 @@ if Reg:
ax2.plot(colnames, np.mean(Reg.pretreated_spectra_ , axis = 0), color = 'black', label = 'Average spectrum (Pretreated)') ax2.plot(colnames, np.mean(Reg.pretreated_spectra_ , axis = 0), color = 'black', label = 'Average spectrum (Pretreated)')
ax2.set_xlabel('Wavelenghts') ax2.set_xlabel('Wavelenghts')
plt.tight_layout() plt.tight_layout()
for i in range(2): for i in range(2):
eval(f'ax{i+1}').grid(color='grey', linestyle=':', linewidth=0.2) eval(f'ax{i+1}').grid(color='grey', linestyle=':', linewidth=0.2)
eval(f'ax{i+1}').margins(x = 0) eval(f'ax{i+1}').margins(x = 0)
...@@ -396,9 +397,9 @@ if Reg: ...@@ -396,9 +397,9 @@ if Reg:
min, max = intervalls_with_cols['from'][j], intervalls_with_cols['to'][j] min, max = intervalls_with_cols['from'][j], intervalls_with_cols['to'][j]
else: else:
min, max = intervalls['from'][j], intervalls['to'][j] min, max = intervalls['from'][j], intervalls['to'][j]
eval(f'ax{i+1}').axvspan(min, max, color='#00ff00', alpha=0.5, lw=0) eval(f'ax{i+1}').axvspan(min, max, color='#00ff00', alpha=0.5, lw=0)
if regression_algo == 'PLS': if regression_algo == 'PLS':
ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0)[np.array(Reg.sel_ratio_.index)], ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0)[np.array(Reg.sel_ratio_.index)],
color = '#7ab0c7', label = 'Important variables') color = '#7ab0c7', label = 'Important variables')
...@@ -416,7 +417,7 @@ if Reg: ...@@ -416,7 +417,7 @@ if Reg:
M2.write('-- Visualization of the spectral regions used for model creation --') M2.write('-- Visualization of the spectral regions used for model creation --')
fig.savefig("./Report/figures/Variable_importance.png") fig.savefig("./Report/figures/Variable_importance.png")
M2.pyplot(fig) M2.pyplot(fig)
...@@ -425,7 +426,7 @@ if Reg: ...@@ -425,7 +426,7 @@ if Reg:
if Reg: if Reg:
# fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 6)) # fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 6))
# fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02) # fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02)
st.header("Cross-Validation results") st.header("Cross-Validation results")
cv1, cv2 = st.columns([2,2]) cv1, cv2 = st.columns([2,2])
############ ############
...@@ -434,7 +435,7 @@ if Reg: ...@@ -434,7 +435,7 @@ if Reg:
cv_results=pd.DataFrame(Reg.CV_results_) cv_results=pd.DataFrame(Reg.CV_results_)
cv2.write('-- Out-of-Fold Predictions Visualization (All in one) --') cv2.write('-- Out-of-Fold Predictions Visualization (All in one) --')
fig1 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline='ols', color='Folds', symbol="Folds", fig1 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline='ols', color='Folds', symbol="Folds",
color_discrete_sequence=px.colors.qualitative.G10) color_discrete_sequence=px.colors.qualitative.G10)
fig1.add_shape(type='line', x0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), x1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), fig1.add_shape(type='line', x0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), x1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']),
y0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), y1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), line = dict(color='black', dash = "dash")) y0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), y1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), line = dict(color='black', dash = "dash"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment