from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from Class_Mod.DATA_HANDLING import *

# Streamlit page: upload NIRS spectra (.csv or .dx), apply preprocessing
# parameters stored in a .json file, load a saved model from data/models/
# and export the predictions to ./data/predictions/.


def _option_index(options, detected):
    """Return the position of *detected* in *options*, or 0 when it is absent."""
    return options.index(detected) if detected in options else 0


# HTML for the "CEFE - CNRS" banner
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML
# st.markdown(bandeau_html, unsafe_allow_html=True)
add_header()

st.session_state["interface"] = st.session_state.get('interface')
local_css(css_file / "style_model.css")

# Page layout: one row of columns per section, plus a centred column for the histogram.
st.header("Data loading", divider='blue')
model_column1, space1, file_column1 = st.columns([2, 1, 1])
st.header("Prediction making", divider='blue')
model_column2, space2, file_column2 = st.columns([2, 1, 1])
_, space3, _ = st.columns([1, 3, 1])

files_format = ['.csv', '.dx']
file = file_column1.file_uploader("Select NIRS Data to predict", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
reg_algo = ["Interval-PLS"]

# Start from real (empty) objects so the `.empty` checks below are meaningful.
pred_data = pd.DataFrame()
loaded_model = None

if file:
    # splitext keeps only the last suffix, so names such as "run.2024.csv" are handled.
    file_stem, file_ext = os.path.splitext(file.name)
    file_ext = file_ext.lower()
    export_name += file_stem

    if file_ext == files_format[0]:
        # NOTE(review): detection reads 'data/<name>' on disk rather than the
        # uploaded buffer; confirm that the uploaded file is expected to exist there.
        local_copy = 'data/' + file.name
        sep_options = [";", ","]
        hdr_options = ["no", "yes"]
        detected_sep = str(find_delimiter(local_copy))
        detected_hdr = str(find_col_index(local_copy))

        # Fall back to the first option when the detected value is not offered.
        qsep = file_column1.selectbox("Select csv separator - _detected_: " + detected_sep,
                                      options=sep_options,
                                      index=_option_index(sep_options, detected_sep), key=2)
        qhdr = file_column1.selectbox("indexes column in csv? - _detected_: " + detected_hdr,
                                      options=hdr_options,
                                      index=_option_index(hdr_options, detected_hdr), key=3)
        col = 0 if qhdr == 'yes' else False
        pred_data = pd.read_csv(file, sep=qsep, index_col=col)

    elif file_ext == files_format[1]:
        # read_dx is given a path, so the upload is copied to a temporary file first.
        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
            tmp.write(file.read())
            tmp_path = tmp.name
        try:
            # The temporary file is closed (and therefore flushed) before it is read.
            chem_data, spectra, meta_data, _ = read_dx(file = tmp_path)
        finally:
            # Always remove the temporary copy, even when parsing fails.
            os.unlink(tmp_path)

        file_column1.success("The data have been loaded successfully", icon="✅")
        if chem_data.to_numpy().shape[1] > 0:
            yname = file_column1.selectbox('Select target', options=chem_data.columns)
            # NOTE(review): keeps the rows whose target value equals 0 - presumably
            # the samples without a reference measurement; confirm.
            measured = chem_data.loc[:, yname] == 0
            y = chem_data.loc[:, yname].loc[measured]
            pred_data = spectra.loc[measured]
        else:
            pred_data = spectra

# Plot the raw spectra
if not pred_data.empty:
    model_column1.write('Raw spectra')
    # NOTE(review): yunits is passed as a literal string, not as the evaluated
    # expression; kept unchanged - confirm this is intended.
    fig = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    model_column1.pyplot(fig)

### preprocessing
preprocessed = pd.DataFrame()
x2 = None
if not pred_data.empty:
    params = file_column1.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
    if params:
        prep = json.load(params)
        # Optional scatter correction, then the Savitzky-Golay filter.
        x1 = Snv(pred_data) if prep['Scatter'] == 'SNV' else pred_data
        sg_params = prep["Saitzky-Golay derivative parameters"]
        x2 = savgol_filter(x1,
                           window_length = sg_params["window_length"],
                           polyorder = sg_params["polyorder"],
                           deriv = sg_params["deriv"],
                           delta=1.0, axis=-1, mode="interp", cval=0.0)
        preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)

## plot preprocessed spectra
if not preprocessed.empty:
    model_column1.write('Preprocessed spectra')
    fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    model_column1.pyplot(fig2)

################### Predictions making ##########################
is_ipls = False
idx = None
if not pred_data.empty:
    # The empty first entry means "no model selected yet".
    model_files = sorted(os.listdir('data/models/'))
    model_files.insert(0, '')
    model_name = model_column2.selectbox("Select your model from the dropdown list:", options = model_files, key = 21)

    if model_name:
        export_name += '_with_' + os.path.splitext(model_name)[0]
        # NOTE(security): joblib.load unpickles the file and can execute arbitrary
        # code; only files from the trusted local data/models/ folder are loaded.
        with open('data/models/' + model_name, 'rb') as f:
            loaded_model = joblib.load(f)

        if loaded_model is not None:
            model_column2.success("The model has been loaded successfully", icon="✅")
            is_ipls = model_column2.checkbox('the model is of ipls type?')
            if is_ipls:
                index = model_column2.file_uploader("select wavelengths index file", type="csv")
                if index:
                    idx = pd.read_csv(index, sep=';', index_col=0).iloc[:, 0].to_numpy()

if loaded_model is not None and model_column2.button('Predict'):
    if preprocessed.empty:
        # Without the preprocessing parameters there is nothing to feed the model.
        model_column2.warning("Load the preprocessing parameters (.json) before predicting")
    elif is_ipls and idx is None:
        model_column2.warning("Load the wavelengths index file before predicting with an ipls model")
    else:
        if is_ipls:
            result = loaded_model.predict(preprocessed.iloc[:, idx])
        else:
            # use prediction function from application_functions.py to predict chemical values
            result = loaded_model.predict(x2)

        result = pd.DataFrame(result, index = pred_data.index)
        st.write('Predicted values')
        st.dataframe(result.T)

        #############################
        # Creating histogram
        fig, axs = plt.subplots(1, 1, figsize =(12, 6), tight_layout = True)

        # Add x, y gridlines
        axs.grid(color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6)
        axs.set_title('Predicted values distribution')

        # Remove axes splines
        for side in ['top', 'bottom', 'left', 'right']:
            axs.spines[side].set_visible(False)

        # Remove x, y ticks
        axs.xaxis.set_ticks_position('none')
        axs.yaxis.set_ticks_position('none')

        # Add padding between axes and labels
        axs.xaxis.set_tick_params(pad = 5)
        axs.yaxis.set_tick_params(pad = 10)

        # Creating histogram
        N, bins, patches = axs.hist(result, bins = 12)

        # Setting color: each bar is coloured from its (rescaled) count
        fracs = ((N**(1 / 5)) / N.max())
        norm = colors.Normalize(fracs.min(), fracs.max())
        for thisfrac, thispatch in zip(fracs, patches):
            color = plt.cm.viridis(norm(thisfrac))
            thispatch.set_facecolor(color)

        space3.pyplot(fig)
        ##################################

        # Make sure the export folder exists before writing into it.
        os.makedirs(export_folder, exist_ok=True)
        export_path = export_folder + export_name + '.csv'
        result.to_csv(export_path, sep = ';')

        # export to local drive - Download
        download_results(export_path, export_name + '.csv')

        # create a report with information on the prediction
        ## see https://stackoverflow.com/a/59578663