from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from Class_Mod.DATA_HANDLING import *

# HTML for the "CEFE - CNRS" banner
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML
# st.markdown(bandeau_html, unsafe_allow_html=True)

add_header()
add_sidebar(pages_folder)
local_css(css_file / "style_model.css")

import shutil

hash_ = ''
def p_hash(add):
    global hash_
    hash_ = hash_data(hash_ + str(add))
    return hash_
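# Note on the caching pattern used throughout this page: `hash_` is a rolling
# hash of every user input seen so far (`hash_data` is provided by the project's
# imports). Each cached function below takes it as a dummy `change` argument,
# so any new upload or widget choice alters the key and forces st.cache_data
# to recompute instead of serving a stale result.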
dirpath = Path('Report/out/model')
if dirpath.exists() and dirpath.is_dir():
    shutil.rmtree(dirpath)

# #################################### Methods ##############################################
# empty temp figures
def delete_files(keep):
    # Walk through the report directory and delete every file whose extension
    # is not in `keep`, sparing the logo
    for root, dirs, files in os.walk('Report/', topdown=False):
        for file in files:
            if file != 'logo_cefe.png' and not any(file.endswith(ext) for ext in keep):
                os.remove(os.path.join(root, file))
###################################################################

st.title("Prediction making using a previously developed model")
M10, M20 = st.columns([2, 1])
M10.image("./images/prediction making.png", use_column_width=True)

def preparespecdf(df):
    # Split an imported table into float columns (the spectra) and the rest
    # (metadata); the first non-float column provides the sample names
    other = df.select_dtypes(exclude='float')
    rownames = other.iloc[:, 0]
    spec = df.select_dtypes(include='float')
    spec.index = rownames
    return spec, other, rownames

def check_exist(var):
    return var in globals()

files_format = ['.csv', '.dx']
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
reg_algo = ["Interval-PLS"]

with M20:
    file = st.file_uploader("Load NIRS Data for prediction making:", type=files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
    if not file:
        st.info('Info: Insert your spectral data file above!')
    else:
        p_hash(file.name)
        test = file.name[file.name.find('.'):]
        export_name += file.name[:file.name.find('.')]

        if test == files_format[0]:
            qsep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/' + file.name)),
                            options=[";", ","], index=[";", ","].index(str(find_delimiter('data/' + file.name))), key=2, horizontal=True)
            qhdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/' + file.name)),
                            options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/' + file.name))), key=3, horizontal=True)
            col = 0 if qhdr == 'yes' else None
            p_hash([qsep, qhdr])
            df = pd.read_csv(file, sep=qsep, header=col)
            pred_data, cat, rownames = preparespecdf(df)

        elif test == files_format[1]:
            # write the upload to a temp file, then read it back once the
            # handle is closed and flushed (re-reading the exhausted upload
            # stream would return empty content)
            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                tmp.write(file.read())
                tmp_path = tmp.name
            with open(tmp_path, 'r') as dd:
                dxdata = dd.read()
            p_hash(str(dxdata) + str(file.name))

            ## load and parse the temp dx file
            @st.cache_data
            def dx_loader(change):
                chem_data, spectra, meta_data, _ = read_dx(file=tmp_path)
                return chem_data, spectra, meta_data, _
            chem_data, spectra, meta_data, _ = dx_loader(change=hash_)
            st.success("The data have been loaded successfully", icon="✅")

            if chem_data.to_numpy().shape[1] > 0:
                yname = st.selectbox('Select target', options=chem_data.columns)
                measured = chem_data.loc[:, yname] == 0
                y = chem_data.loc[:, yname].loc[measured]
                pred_data = spectra.loc[measured]
            else:
                pred_data = spectra
            rownames = pred_data.index  # sample names for the results table
            os.unlink(tmp_path)

# Load parameters
st.header("I - Spectral data preprocessing & visualization", divider='blue')
try:
    if check_exist("pred_data"):
        @st.cache_data
        def specplot_raw(change):
            return plot_spectra(pred_data, xunits='lab', yunits="meta_data.loc[:,'yunits'][0]")
        rawspectraplot = specplot_raw(change=hash_)

        M1, M2 = st.columns([2, 1])
        with M1:
            st.write('Raw spectra')
            st.pyplot(rawspectraplot)

        with M2:
            params = st.file_uploader("Load preprocessing params", type='.json', help=" .json file")
            if params:
                prep = json.load(params)
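                # A minimal sketch of the expected .json params file. The key
                # names are exactly the ones read below; the values shown are
                # only illustrative:
                # {"normalization": "Snv", "window_length": 11, "polyorder": 2, "deriv": 1}
                # Any "normalization" value other than "Snv" means no
                # normalization is applied.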
                p_hash(prep)

                @st.cache_data
                def preprocess_spectra(change):
                    if prep['normalization'] == 'Snv':
                        x1 = Snv(pred_data)
                        norm = 'Standard Normal Variate'
                    else:
                        norm = 'No Normalization was applied'
                        x1 = pred_data
                    x2 = savgol_filter(x1, window_length=int(prep["window_length"]), polyorder=int(prep["polyorder"]),
                                       deriv=int(prep["deriv"]), delta=1.0, axis=-1, mode="interp", cval=0.0)
                    preprocessed = pd.DataFrame(x2, index=pred_data.index, columns=pred_data.columns)
                    return norm, prep, preprocessed
                norm, prep, preprocessed = preprocess_spectra(change=hash_)

    ################################################################################################
    ## plot preprocessed spectra
    if check_exist("preprocessed"):
        p_hash(preprocessed)
        M3, M4 = st.columns([2, 1])
        with M3:
            st.write('Preprocessed spectra')

            @st.cache_data
            def specplot_prep(change):
                return plot_spectra(preprocessed, xunits='lab', yunits="meta_data.loc[:,'yunits'][0]")
            prepspectraplot = specplot_prep(change=hash_)
            st.pyplot(prepspectraplot)

        with M4:
            @st.cache_data
            def prep_info(change):
                SG = rf'''- Savitzky-Golay derivative parameters \: (window length: {prep['window_length']}; polynomial order: {prep['polyorder']}; derivative order: {prep['deriv']})'''
                Norm = rf'''- Spectral normalization \: {norm}'''
                return SG, Norm
            SG, Norm = prep_info(change=hash_)
            st.info('The spectra were preprocessed using:\n' + SG + "\n" + Norm)

    ################### Predictions making ##########################
    st.header("II - Prediction making", divider='blue')
    if check_exist("pred_data") and params:
        M5, M6 = st.columns([2, 1])
        # Load the model with joblib
        model_file = M6.file_uploader("Load your model", type='.pkl', help=" .pkl file")
        if model_file:
            with M6:
                try:
                    model = joblib.load(model_file)
                    st.success("The model has been loaded successfully", icon="✅")
                    nvar = model.n_features_in_
                except:
                    st.error("Error: Something went wrong, the model was not loaded!", icon="❌")

        with M6:
            s = st.checkbox('Check this box if your model is of ipls type!', disabled=False if 'model' in globals() else True)
            index = st.file_uploader("select wavelengths index file", type="csv", disabled=not s)

            if check_exist('preprocessed'):
                if s:
                    if index:
                        intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
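                        # Expected layout of the wavelengths index file (an
                        # assumption based only on how `intervalls` is used
                        # below): one row per selected interval, ';'-separated,
                        # with the first column an index and the 3rd and 4th
                        # data columns holding the first and last column
                        # positions of each interval, e.g.
                        # ;from;to;from_idx;to_idx
                        # interval_0;1100;1250;0;75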
                        idx = []
                        for i in range(intervalls.shape[0]):
                            idx.extend(np.arange(intervalls[i, 2], intervalls[i, 3] + 1))
                        if max(idx) <= preprocessed.shape[1]:
                            preprocesseddf = preprocessed.iloc[:, idx]  # get predictors
                        else:
                            st.error("Error: The number of columns in your data does not match the number of wavelengths used to train the model. Please ensure they are the same.")
                else:
                    preprocesseddf = preprocessed

        # The Predict button stays disabled until a model is loaded and, for an
        # ipls model, a wavelengths index file is provided
        if not check_exist("model"):
            disable = True
        else:
            disable = (s and not index) or (not s and index)

        pred_button = M6.button('Predict', type='primary', disabled=disable)

        if check_exist("preprocesseddf"):
            if pred_button and nvar == preprocesseddf.shape[1]:
                try:
                    result = pd.DataFrame(model.predict(preprocesseddf), index=rownames, columns=['Results'])
                except:
                    st.error(f'''Error: Length mismatch: the number of sample indices is {len(rownames)}, while the model produced {len(model.predict(preprocesseddf))} values. Correct the "indexes column in csv?" parameter.''')

                with M5:
                    if preprocesseddf.shape[1] > 1 and check_exist('result'):
                        st.write('Predicted values distribution')
                        # Create the histogram figure
                        hist, axs = plt.subplots(1, 1, figsize=(15, 3), tight_layout=True)

                        # Add x, y gridlines
                        axs.grid(color='grey', linestyle='-.', linewidth=0.5, alpha=0.6)

                        # Remove axes spines
                        for side in ['top', 'bottom', 'left', 'right']:
                            axs.spines[side].set_visible(False)

                        # Remove x, y ticks
                        axs.xaxis.set_ticks_position('none')
                        axs.yaxis.set_ticks_position('none')

                        # Add padding between axes and labels
                        axs.xaxis.set_tick_params(pad=5)
                        axs.yaxis.set_tick_params(pad=10)

                        # Draw the histogram
                        N, bins, patches = axs.hist(result, bins=12)

                        # Color the bars by height
                        fracs = (N ** (1 / 5)) / N.max()
                        color_norm = colors.Normalize(fracs.min(), fracs.max())
                        for thisfrac, thispatch in zip(fracs, patches):
                            thispatch.set_facecolor(plt.cm.viridis(color_norm(thisfrac)))

                        st.pyplot(hist)
                        st.write('Predicted values table')
                        st.dataframe(result.T)
            #################################
            elif pred_button and nvar != preprocesseddf.shape[1]:
                M6.error(f'Error: The model was trained on {nvar} wavelengths, but you provided {preprocesseddf.shape[1]} wavelengths for prediction. Please ensure they match!')

    if check_exist('result'):
        @st.cache_data(show_spinner=False)
        def preparing_results_for_downloading(change):
            # copy the input data file into the report folder
            match test:
                case '.csv':
                    df.to_csv('Report/out/dataset/' + file.name, sep=';', encoding='utf-8', mode='a')
                case '.dx':
                    with open('Report/out/dataset/' + file.name, 'w') as dd:
                        dd.write(dxdata)

            rawspectraplot.savefig('./Report/out/figures/raw_spectra.png')
            prepspectraplot.savefig('./Report/out/figures/preprocessed_spectra.png')
            hist.savefig('./Report/out/figures/histogram.png')
            result.round(4).to_csv('./Report/out/The analysis result.csv', sep=';')
            return change
        preparing_results_for_downloading(change=hash_)

        import tempfile

        @st.cache_data(show_spinner=False)
        def tempdir(change):
            # create a temp directory, zip the report into it, and read the zip back
            with tempfile.TemporaryDirectory(prefix="results", dir="./Report") as temp_dir:
                tempdirname = os.path.split(temp_dir)[1]
                if len(os.listdir('./Report/out/figures/')) == 3:
                    shutil.make_archive(base_name="./Report/Results", format="zip", base_dir="out", root_dir="./Report")  # create a zip file
                    shutil.move("./Report/Results.zip", f"./Report/{tempdirname}/Results.zip")  # put the zip inside the temp dir
                    with open(f"./Report/{tempdirname}/Results.zip", "rb") as f:
                        zip_data = f.read()
            return tempdirname, zip_data

        date_time = datetime.datetime.now().strftime('%y%m%d%H%M')
        try:
            tempdirname, zip_data = tempdir(change=hash_)
            st.download_button(label='Download', data=zip_data, file_name=f'Nirs_Workflow_{date_time}_Pred_.zip',
                               mime="application/zip", type="primary", use_container_width=True)
        except:
            st.error('Error: Failed to prepare the results archive for downloading.')

except:
    M20.error('''Error: Data loading failed. Please check your file. Consider fine-tuning the dialect settings or ensure the file isn't corrupted.''')