From bdeab7d8a934de2029f85e471e65055617d0560b Mon Sep 17 00:00:00 2001 From: DIANE <abderrahim.diane@cefe.cnrs.fr> Date: Fri, 18 Oct 2024 15:32:43 +0200 Subject: [PATCH] range selection --- src/pages/2-model_creation.py | 47 +++++++++++++++++++++++++---------- src/pages/3-prediction.py | 17 +++++++------ src/utils/miscellaneous.py | 2 +- src/utils/visualize.py | 2 +- 4 files changed, 46 insertions(+), 22 deletions(-) diff --git a/src/pages/2-model_creation.py b/src/pages/2-model_creation.py index 2fe2f53..a513d8f 100644 --- a/src/pages/2-model_creation.py +++ b/src/pages/2-model_creation.py @@ -80,6 +80,7 @@ match file: namesx = 0 if pnamesx =="yes" else None try: spectra, _, _, xfile = read_csv(file= xcal_csv, change = hash_, dec = decx, sep = sepx, names =namesx, hdr = hdrx) + N,P = spectra.shape st.success('xfile has been loaded successfully') except: st.error('Error: The xfile has not been loaded successfully, please consider tuning the dialect settings!') @@ -215,23 +216,17 @@ match file: ################################################### BEGIN : visualize and split the data #################################################### st.subheader("I - Data visualization", divider = 'blue') if not spectra.empty and not y.empty: - # p_hash(y) - # p_hash(np.mean(spectra)) - if np.array(spectra.columns).dtype.kind in ['i', 'f']: - colnames = spectra.columns - else: - colnames = np.arange(spectra.shape[1]) + # if np.array(spectra.columns).dtype.kind in ['i', 'f']: + # colnames = spectra.columns + # else: + # colnames = np.arange(spectra.shape[1]) + - from utils.miscellaneous import data_split - X_train, X_test, y_train, y_test, train_index, test_index = data_split(x=spectra, y=y) - #### insight on loaded data spectra_plot = plot_spectra(spectra, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity") - target_plot = hist(y = y, y_train = y_train, y_test = y_test, target_name=yname) from utils.miscellaneous import desc_stats - stats = DataFrame([desc_stats(y_train), desc_stats(y_test), desc_stats(y)], index =['train', 'test', 'total'] ).round(2) # fig1, ax1 = plt.subplots( figsize = (12, 3)) # spectra.T.plot(legend = False, ax = ax1, linestyle = '-', linewidth = 0.6) @@ -241,11 +236,37 @@ if not spectra.empty and not y.empty: c2, c3 = st.columns([1, .4]) with c2: st.pyplot(spectra_plot) ######## Loaded graph - st.pyplot(target_plot) + if st.session_state.interface =='advanced': + with st.container(): + values = st.slider('Select a range of values', min_value = 0, max_value = 100, value = (0, P)) + hash_ = ObjectHash(current=hash_, add= values) + spectra = spectra.iloc[:,values[0]:values[1]] + nwl = spectra.shape + + + if np.array(spectra.columns).dtype.kind in ['i', 'f']: + colnames = spectra.columns + else: + colnames = np.arange(spectra.shape[1]) + + + + hash_ = ObjectHash(current= hash_, add=values) + st.pyplot(plot_spectra(spectra, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")) + + from utils.miscellaneous import data_split + X_train, X_test, y_train, y_test, train_index, test_index = data_split(x=spectra, y=y) + with c3: st.write('Loaded data summary') + stats = DataFrame([desc_stats(y_train), desc_stats(y_test), desc_stats(y)], index =[f'{yname} (Cal)', f'{yname} (Val)', f'{yname} (Total)'] ).round(2) st.write(stats) + ## histogramms + target_plot = hist(y = y, y_train = y_train, y_test = y_test, target_name=yname) + st.pyplot(target_plot) + st.info('Info: 70/30 split ratio was used to split the dataset into calibration and prediction subsets') + ################################################### END : visualize and split the data ####################################################### @@ -323,7 +344,7 @@ if not spectra.empty and not y.empty: # target_plot.savefig("./report/figures/histogram.png") # st.session_state['hash_Reg'] = str(np.random.randint(2000000000)) folds = KF_CV.CV(X_train, y_train, nb_folds)# split train data into nb_folds for cross_validation - + match model_type: case 'PLS': from utils.regress import Plsr diff --git a/src/pages/3-prediction.py b/src/pages/3-prediction.py index 1069f9f..acc1c04 100644 --- a/src/pages/3-prediction.py +++ b/src/pages/3-prediction.py @@ -20,10 +20,10 @@ UiComponents(pagespath = pages_folder, csspath= css_file,imgpath=image_path , # local_css(css_file / "style_model.css") hash_ = '' -def p_hash(add): - global hash_ - hash_ = hash_data(hash_+str(add)) - return hash_ +# def p_hash(add): +# global hash_ +# hash_ = hash_data(hash_+str(add)) +# return hash_ dirpath = Path('report/out/model') if dirpath.exists() and dirpath.is_dir(): @@ -163,13 +163,16 @@ with c2: tmp_path = tmp.name with open(tmp.name, 'r') as dd: dxdata = new_data.read() - p_hash(str(dxdata)+str(new_data.name)) + hash_ = ObjectHash(current= hash_, add = str(dxdata)+str(new_data.name)) ## load and parse the temp dx file @st.cache_data def dx_loader(change): - chem_data, spectra, meta_data, _ = read_dx(file = tmp_path) - return chem_data, spectra, meta_data, _ + from utils.data_parsing import JcampParser + M = JcampParser(path = tmp_path) + M.parse() + return M.chem_data, M.specs_df_, M.meta_data, M.meta_data_st_ + chem_data, spectra, meta_data, _ = dx_loader(change = hash_) st.success("The data have been loaded successfully", icon="✅") if chem_data.to_numpy().shape[1]>0: diff --git a/src/utils/miscellaneous.py b/src/utils/miscellaneous.py index 7de390f..f890e2e 100644 --- a/src/utils/miscellaneous.py +++ b/src/utils/miscellaneous.py @@ -24,7 +24,7 @@ def download_results(data, export_name): def data_split(x, y): from kennard_stone import train_test_split # Split data into training and test sets using the kennard_stone method and correlation metric, 25% of data is used for testing - X_train, X_test, y_train, y_test = train_test_split(x, y, test_size = 0.25) + X_train, X_test, y_train, y_test = train_test_split(x, y, test_size = 0.30) train_index, test_index = np.array(X_train.index), np.array(X_test.index) return X_train, X_test, y_train, y_test, train_index, test_index diff --git a/src/utils/visualize.py b/src/utils/visualize.py index 8c3add6..94183b9 100644 --- a/src/utils/visualize.py +++ b/src/utils/visualize.py @@ -50,7 +50,7 @@ def plot_spectra(specdf, xunits, yunits): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Cal/val hist ~~~~~~~~~~~~~~~~~~~~~~~~~~ @st.cache_data def hist(y, y_train, y_test, target_name = 'y'): - fig, ax = plt.subplots(figsize = (12,3)) + fig, ax = plt.subplots(figsize = (5,2)) sns.histplot(y, color = "#004e9e", kde = True, label = str(target_name), ax = ax, fill = True) sns.histplot(y_train, color = "#2C6B6F", kde = True, label = str(target_name)+" (Cal)", ax = ax, fill = True) sns.histplot(y_test, color = "#d0f7be", kde = True, label = str(target_name)+" (Val)", ax = ax, fill = True) -- GitLab