Skip to content
Snippets Groups Projects
Commit 7ccfbffa authored by DIANE's avatar DIANE
Browse files

gestion d erreur

parent 86d7c2b2
No related branches found
No related tags found
No related merge requests found
...@@ -115,12 +115,13 @@ if not pred_data.empty and params:# Load the model with joblib ...@@ -115,12 +115,13 @@ if not pred_data.empty and params:# Load the model with joblib
#dir = os.listdir('data/models/')[1:] #dir = os.listdir('data/models/')[1:]
dir = os.listdir('data/models/') dir = os.listdir('data/models/')
dir.insert(0,'') dir.insert(0,'')
model_name = M6.selectbox("Select your model from the dropdown list:", options = dir, key = 21) model_name = M6.selectbox("Select your model from the dropdown list:", options = dir, key = 21, format_func=lambda x: x if x else "<Select>")
if model_name and model_name !='': if model_name:
export_name += '_with_' + model_name[:model_name.find('.')] export_name += '_with_' + model_name[:model_name.find('.')]
with open('data/models/'+ model_name,'rb') as f: with open('data/models/'+ model_name,'rb') as f:
loaded_model = joblib.load(f) loaded_model = joblib.load(f)
ncols = loaded_model.n_features_in_
if loaded_model: if loaded_model:
M6.success("The model has been loaded successfully", icon="") M6.success("The model has been loaded successfully", icon="")
...@@ -132,50 +133,55 @@ if not pred_data.empty and params:# Load the model with joblib ...@@ -132,50 +133,55 @@ if not pred_data.empty and params:# Load the model with joblib
idx = [] idx = []
for i in range(intervalls.shape[0]): for i in range(intervalls.shape[0]):
idx.extend(np.arange(intervalls[i,0], intervalls[i,1]+1)) idx.extend(np.arange(intervalls[i,0], intervalls[i,1]+1))
if max(idx) <= preprocessed.shape[1]:
preprocessed = preprocessed.iloc[:,idx] ### get predictors
else:
M6.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")
if loaded_model: if loaded_model:
if M6.button('Predict', type='primary'): if M6.button('Predict', type='primary'):
if s: if ncols == preprocessed.shape[1]:
result = loaded_model.predict(preprocessed.iloc[:,idx]) result = pd.DataFrame(loaded_model.predict(preprocessed), index = preprocessed.index)
#############################
if preprocessed.shape[1]>1:
M5.write('Predicted values distribution')
# Creating histogram
fig, axs = plt.subplots(1, 1, figsize =(15, 3),
tight_layout = True)
# Add x, y gridlines
axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6)
# Remove axes splines
for s in ['top', 'bottom', 'left', 'right']:
axs.spines[s].set_visible(False)
# Remove x, y ticks
axs.xaxis.set_ticks_position('none')
axs.yaxis.set_ticks_position('none')
# Add padding between axes and labels
axs.xaxis.set_tick_params(pad = 5)
axs.yaxis.set_tick_params(pad = 10)
# Creating histogram
N, bins, patches = axs.hist(result, bins = 12)
# Setting color
fracs = ((N**(1 / 5)) / N.max())
norm = colors.Normalize(fracs.min(), fracs.max())
for thisfrac, thispatch in zip(fracs, patches):
color = plt.cm.viridis(norm(thisfrac))
thispatch.set_facecolor(color)
M5.pyplot(fig)
st.write('Predicted values table')
st.dataframe(result.T)
##################################
# result.to_csv(export_folder + export_name + '.csv', sep = ';')
# export to local drive - Download
download_results(export_folder + export_name + '.csv', export_name + '.csv')
# create a report with information on the prediction
## see https://stackoverflow.com/a/59578663
else: else:
# use prediction function from application_functions.py to predict chemical values M6.error(f'Error: The model was trained with {ncols} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match.')
result = loaded_model.predict(x2)
result = pd.DataFrame(result, index = pred_data.index)
#############################
M5.write('Predicted values distribution')
# Creating histogram
fig, axs = plt.subplots(1, 1, figsize =(15, 3),
tight_layout = True)
# Add x, y gridlines
axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6)
# Remove axes splines
for s in ['top', 'bottom', 'left', 'right']:
axs.spines[s].set_visible(False)
# Remove x, y ticks
axs.xaxis.set_ticks_position('none')
axs.yaxis.set_ticks_position('none')
# Add padding between axes and labels
axs.xaxis.set_tick_params(pad = 5)
axs.yaxis.set_tick_params(pad = 10)
# Creating histogram
N, bins, patches = axs.hist(result, bins = 12)
# Setting color
fracs = ((N**(1 / 5)) / N.max())
norm = colors.Normalize(fracs.min(), fracs.max())
for thisfrac, thispatch in zip(fracs, patches):
color = plt.cm.viridis(norm(thisfrac))
thispatch.set_facecolor(color)
M5.pyplot(fig)
st.write('Predicted values table')
st.dataframe(result.T)
##################################
result.to_csv(export_folder + export_name + '.csv', sep = ';')
# export to local drive - Download
download_results(export_folder + export_name + '.csv', export_name + '.csv')
# create a report with information on the prediction
## see https://stackoverflow.com/a/59578663
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment