diff --git a/src/pages/3-prediction.py b/src/pages/3-prediction.py index 8600dee93c0ea6b3d83d84485eed64ec05f04495..f73aa24998196e1a44f341b117b75309a4527fdf 100644 --- a/src/pages/3-prediction.py +++ b/src/pages/3-prediction.py @@ -115,12 +115,13 @@ if not pred_data.empty and params:# Load the model with joblib #dir = os.listdir('data/models/')[1:] dir = os.listdir('data/models/') dir.insert(0,'') - model_name = M6.selectbox("Select your model from the dropdown list:", options = dir, key = 21) + model_name = M6.selectbox("Select your model from the dropdown list:", options = dir, key = 21, format_func=lambda x: x if x else "<Select>") - if model_name and model_name !='': + if model_name: export_name += '_with_' + model_name[:model_name.find('.')] with open('data/models/'+ model_name,'rb') as f: loaded_model = joblib.load(f) + ncols = loaded_model.n_features_in_ if loaded_model: M6.success("The model has been loaded successfully", icon="✅") @@ -132,50 +133,55 @@ if not pred_data.empty and params:# Load the model with joblib idx = [] for i in range(intervalls.shape[0]): idx.extend(np.arange(intervalls[i,0], intervalls[i,1]+1)) + if max(idx) <= preprocessed.shape[1]: + preprocessed = preprocessed.iloc[:,idx] ### get predictors + else: + M6.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.") + if loaded_model: if M6.button('Predict', type='primary'): - if s: - result = loaded_model.predict(preprocessed.iloc[:,idx]) + if ncols == preprocessed.shape[1]: + result = pd.DataFrame(loaded_model.predict(preprocessed), index = preprocessed.index) + + ############################# + if preprocessed.shape[1]>1: + M5.write('Predicted values distribution') + # Creating histogram + fig, axs = plt.subplots(1, 1, figsize =(15, 3), + tight_layout = True) + + # Add x, y gridlines + axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6) + # Remove axes splines + for s in ['top', 'bottom', 'left', 'right']: + axs.spines[s].set_visible(False) + # Remove x, y ticks + axs.xaxis.set_ticks_position('none') + axs.yaxis.set_ticks_position('none') + # Add padding between axes and labels + axs.xaxis.set_tick_params(pad = 5) + axs.yaxis.set_tick_params(pad = 10) + # Creating histogram + N, bins, patches = axs.hist(result, bins = 12) + # Setting color + fracs = ((N**(1 / 5)) / N.max()) + norm = colors.Normalize(fracs.min(), fracs.max()) + + for thisfrac, thispatch in zip(fracs, patches): + color = plt.cm.viridis(norm(thisfrac)) + thispatch.set_facecolor(color) + + M5.pyplot(fig) + st.write('Predicted values table') + st.dataframe(result.T) + ################################## + + # result.to_csv(export_folder + export_name + '.csv', sep = ';') + # export to local drive - Download + download_results(export_folder + export_name + '.csv', export_name + '.csv') + # create a report with information on the prediction + ## see https://stackoverflow.com/a/59578663 else: - # use prediction function from application_functions.py to predict chemical values - result = loaded_model.predict(x2) - result = pd.DataFrame(result, index = pred_data.index) - - ############################# - M5.write('Predicted values distribution') - # Creating histogram - fig, axs = plt.subplots(1, 1, figsize =(15, 3), - tight_layout = True) - - # Add x, y gridlines - axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6) - # Remove axes splines - for s in ['top', 'bottom', 'left', 'right']: - axs.spines[s].set_visible(False) - # Remove x, y ticks - axs.xaxis.set_ticks_position('none') - axs.yaxis.set_ticks_position('none') - # Add padding between axes and labels - axs.xaxis.set_tick_params(pad = 5) - axs.yaxis.set_tick_params(pad = 10) - # Creating histogram - N, bins, patches = axs.hist(result, bins = 12) - # Setting color - fracs = ((N**(1 / 5)) / N.max()) - norm = colors.Normalize(fracs.min(), fracs.max()) + M6.error(f'Error: The model was trained with {ncols} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match.') - for thisfrac, thispatch in zip(fracs, patches): - color = plt.cm.viridis(norm(thisfrac)) - thispatch.set_facecolor(color) - - M5.pyplot(fig) - st.write('Predicted values table') - st.dataframe(result.T) - ################################## - - result.to_csv(export_folder + export_name + '.csv', sep = ';') - # export to local drive - Download - download_results(export_folder + export_name + '.csv', export_name + '.csv') - # create a report with information on the prediction - ## see https://stackoverflow.com/a/59578663 \ No newline at end of file