Skip to content
Snippets Groups Projects
Commit 1ad0c689 authored by DIANE's avatar DIANE
Browse files

sample selection page update

parent 4c93e8db
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ from tempfile import NamedTemporaryFile ...@@ -25,6 +25,7 @@ from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
from datetime import datetime from datetime import datetime
import json import json
from shutil import rmtree, move, make_archive
from utils.data_parsing import JcampParser, CsvParser from utils.data_parsing import JcampParser, CsvParser
from style.layout import UiComponents from style.layout import UiComponents
......
...@@ -38,7 +38,7 @@ if Path('report/out/model').exists() and Path('report/out/model').is_dir(): ...@@ -38,7 +38,7 @@ if Path('report/out/model').exists() and Path('report/out/model').is_dir():
match st.session_state["interface"]: match st.session_state["interface"]:
case 'simple': case 'simple':
dim_red_methods = ['PCA'] dim_red_methods = ['PCA']
cluster_methods = ['Kmeans'] # List of clustering algos cluster_methods = ['KS'] # List of clustering algos
selec_strategy = ['center'] selec_strategy = ['center']
case 'advanced': case 'advanced':
...@@ -109,15 +109,16 @@ else: ...@@ -109,15 +109,16 @@ else:
data_str = str(stringio.read()) data_str = str(stringio.read())
@st.cache_data @st.cache_data
def csv_loader(file = file, change = None): def read_csv(file = file, change = None):
from utils.data_parsing import CsvParser from utils.data_parsing import CsvParser
par = CsvParser(file= file) par = CsvParser(file= file)
float_data, non_float = par.parse(decimal = dec, separator = sep, index_col = names, header = hdr) par.parse(decimal = dec, separator = sep, index_col = names, header = hdr)
return float_data, non_float return par.float, par.meta_data, par.meta_data_st_, par.df
spectra, meta_data, md_df_st_, imp = read_csv(file= file, change = hash_)
try : try :
spectra, meta_data = csv_loader(file= file, change = hash_) spectra, meta_data, md_df_st_, imp = read_csv(file= file)
st.success("The data have been loaded successfully", icon="") st.success("The data have been loaded successfully", icon="")
except: except:
...@@ -147,7 +148,12 @@ else: ...@@ -147,7 +148,12 @@ else:
st.success("The data have been loaded successfully", icon="") st.success("The data have been loaded successfully", icon="")
################################################### END : I- Data loading and preparation #################################################### ################################################### END : I- Data loading and preparation ####################################################
if not spectra.empty:
with c2:
st.write('Data summary:')
st.write(f'- the number of spectra:{spectra.shape[0]}')
st.write(f'- the number of wavelengths:{spectra.shape[1]}')
st.write(f'- the number of categorical variables:{meta_data.shape[1]}')
################################################### BEGIN : visualize and split the data #################################################### ################################################### BEGIN : visualize and split the data ####################################################
st.subheader("I - Spectral Data Visualization", divider='blue') st.subheader("I - Spectral Data Visualization", divider='blue')
...@@ -265,13 +271,16 @@ if not spectra.empty: ...@@ -265,13 +271,16 @@ if not spectra.empty:
elif sel_ratio < 1.00: elif sel_ratio < 1.00:
ratio = int(sel_ratio*spectra.shape[0]) ratio = int(sel_ratio*spectra.shape[0])
ObjectHash(sel_ratio) ObjectHash(sel_ratio)
if st.session_state["interface"] =='simple':
clus_method = 'KS'
if dr_model and not clus_method: else:
clus_method = st.radio('Select samples selection strategy:', options = ['RDM', 'KS']) if dr_model and not clus_method:
clus_method = st.radio('Select samples selection strategy:', options = ['RDM', 'KS'])
elif dr_model and clus_method: elif dr_model and clus_method:
disabled1 = False if clus_method in cluster_methods else True disabled1 = False if clus_method in cluster_methods else True
selection = st.radio('Select samples selection strategy:', options = selec_strategy, disabled = disabled1) selection = st.radio('Select samples selection strategy:', options = selec_strategy, disabled = disabled1)
...@@ -338,8 +347,10 @@ elif labels: ...@@ -338,8 +347,10 @@ elif labels:
for i in np.unique(s): for i in np.unique(s):
C = np.where(np.array(labels) == i)[0] C = np.where(np.array(labels) == i)[0]
if C.shape[0] >= selection_number: if C.shape[0] >= selection_number:
from sklearn.cluster import KMeans
km2 = KMeans(n_clusters = selection_number) km2 = KMeans(n_clusters = selection_number)
km2.fit(tcr.iloc[C,:]) km2.fit(tcr.iloc[C,:])
from sklearn.metrics import pairwise_distances_argmin_min
clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:]) clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:])
selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index) selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index)
else: else:
...@@ -356,6 +367,18 @@ if not t.empty: ...@@ -356,6 +367,18 @@ if not t.empty:
filter = [''] + md_df_st_.columns.tolist() filter = [''] + md_df_st_.columns.tolist()
elif meta_data.empty and not clus_method in cluster_methods: elif meta_data.empty and not clus_method in cluster_methods:
filter = [] filter = []
if st.session_state["interface"] =='simple':
desactivatelist = True
if meta_data.empty:
desactivatelist = True
filter = ['']
elif not meta_data.empty:
filter = [''] + md_df_st_.columns.tolist()
desactivatelist = False
else:
desactivatelist = False
with c12: with c12:
st.write('Scores plot') st.write('Scores plot')
...@@ -363,7 +386,7 @@ if not t.empty: ...@@ -363,7 +386,7 @@ if not t.empty:
if len(axis)== 1: if len(axis)== 1:
tcr_plot['1d'] = np.random.uniform(-.5, .5, tcr_plot.shape[0]) tcr_plot['1d'] = np.random.uniform(-.5, .5, tcr_plot.shape[0])
colfilter = st.selectbox('Color by:', options= filter,format_func = lambda x: x if x else "<Select>") colfilter = st.selectbox('Color by:', options= filter,format_func = lambda x: x if x else "<Select>", disabled = desactivatelist)
ObjectHash(colfilter) ObjectHash(colfilter)
if colfilter in cluster_methods: if colfilter in cluster_methods:
tcr_plot[colfilter] = labels tcr_plot[colfilter] = labels
...@@ -500,14 +523,14 @@ if not spectra.empty: ...@@ -500,14 +523,14 @@ if not spectra.empty:
out3 = leverage > tresh3 out3 = leverage > tresh3
out4 = residuals > tresh4 out4 = residuals > tresh4
for i in range(n_samples): # for i in range(n_samples):
if out3[i]: # if out3[i]:
if not meta_data.empty: # if not meta_data.empty:
ann = meta_data.loc[:,'name'][i] # ann = meta_data.loc[:,'name'][i]
else: # else:
ann = t.index[i] # ann = t.index[i]
influence_plot.add_annotation(dict(x = leverage[i], y = residuals[i], showarrow=True, text = str(ann),font= dict(color= "black", size= 15), # influence_plot.add_annotation(dict(x = leverage[i], y = residuals[i], showarrow=True, text = str(ann),font= dict(color= "black", size= 15),
xanchor = 'auto', yanchor = 'auto')) # xanchor = 'auto', yanchor = 'auto'))
influence_plot.update_traces(marker=dict(size= 6), showlegend= True) influence_plot.update_traces(marker=dict(size= 6), showlegend= True)
influence_plot.update_layout(font=dict(size=23), width=800, height=500) influence_plot.update_layout(font=dict(size=23), width=800, height=500)
...@@ -623,7 +646,7 @@ if not sam.empty: ...@@ -623,7 +646,7 @@ if not sam.empty:
################################################### ###################################################
# ## generate report # ## generate report
@st.cache_data @st.cache_data
def export_report(variable): def export_report(change):
latex_report = report.report('Representative subset selection', file.name, dim_red_method, latex_report = report.report('Representative subset selection', file.name, dim_red_method,
clus_method, Nb_ech, ncluster, selection, selection_number, nb_clu,tcr, sam) clus_method, Nb_ech, ncluster, selection, selection_number, nb_clu,tcr, sam)
...@@ -638,7 +661,7 @@ if not sam.empty: ...@@ -638,7 +661,7 @@ if not sam.empty:
with open('report/out/dataset/'+file.name, 'w') as dd: with open('report/out/dataset/'+file.name, 'w') as dd:
dd.write(dxdata) dd.write(dxdata)
fig_spectra.savefig(report_path_rel/"out/figures/spectra_plot.png", dpi=400) ## Export report fig_spectra.savefig(report_path_rel/"out/figures/spectra_plot.png", dpi = 400) ## Export report
if len(axis) == 3: if len(axis) == 3:
for i in range(len(comb)): for i in range(len(comb)):
...@@ -650,6 +673,7 @@ if not sam.empty: ...@@ -650,6 +673,7 @@ if not sam.empty:
# Export du graphique # Export du graphique
if dim_red_method in ['PCA','NMF']: if dim_red_method in ['PCA','NMF']:
import plotly.io as pio
img = pio.to_image(loadingsplot, format="png") img = pio.to_image(loadingsplot, format="png")
with open(report_path_rel/"out/figures/loadings_plot.png", "wb") as f: with open(report_path_rel/"out/figures/loadings_plot.png", "wb") as f:
f.write(img) f.write(img)
...@@ -658,25 +682,27 @@ if not sam.empty: ...@@ -658,25 +682,27 @@ if not sam.empty:
influence_plot.write_image(report_path_rel/'out/figures/influence_plot.png', engine = 'kaleido') influence_plot.write_image(report_path_rel/'out/figures/influence_plot.png', engine = 'kaleido')
sam.to_csv(report_path_rel/'out/Selected_subset_for_calib_development.csv', sep = ';') sam.to_csv(report_path_rel/'out/Selected_subset_for_calib_development.csv', sep = ';')
export_report(variable) export_report(change = hash_)
if Path(report_path_rel/"report.tex").exists(): if Path(report_path_rel/"report.tex").exists():
report.generate_report(variable = 25) report.generate_report(change = hash_)
if Path(report_path_rel/"report.pdf").exists(): if Path(report_path_rel/"report.pdf").exists():
move(report_path_rel/"report.pdf", "./report/out/report.pdf") move(report_path_rel/"report.pdf", "./report/out/report.pdf")
return change return change
preparing_results_for_downloading(variable = 25) preparing_results_for_downloading(change = hash_)
report.generate_report(variable = 25) report.generate_report(change = hash_)
@st.cache_data @st.cache_data
def tempdir(change): def tempdir(change):
from tempfile import TemporaryDirectory
with TemporaryDirectory( prefix="results", dir="./report") as temp_dir:# create a temp directory with TemporaryDirectory( prefix="results", dir="./report") as temp_dir:# create a temp directory
tempdirname = os.path.split(temp_dir)[1] tempdirname = os.path.split(temp_dir)[1]
if len(os.listdir(report_path_rel/'out/figures/'))>=2: if len(os.listdir(report_path_rel/'out/figures/'))>=2:
make_archive(base_name= report_path_rel/"Results", format="zip", base_dir="out", root_dir = "./report")# create a zip file make_archive(base_name= report_path_rel/"Results", format="zip", base_dir="out", root_dir = "./report")# create a zip file
move(report_path_rel/"Results.zip", f"./report/{tempdirname}/Results.zip")# put the inside the temp dir move(report_path_rel/"Results.zip", f"./report/{tempdirname}/Results.zip")# put the inside the temp dir
with open(report_path_rel/f"{tempdirname}/Results.zip", "rb") as f: with open(report_path_rel/f"{tempdirname}/Results.zip", "rb") as f:
......
...@@ -86,6 +86,7 @@ class JcampParser: ...@@ -86,6 +86,7 @@ class JcampParser:
@property @property
def specs_df_(self): def specs_df_(self):
return self.spectra return self.spectra
@property @property
def meta_data_st_(self): def meta_data_st_(self):
me = self.metadata_.drop("concentrations", axis = 1) me = self.metadata_.drop("concentrations", axis = 1)
...@@ -114,14 +115,22 @@ class CsvParser: ...@@ -114,14 +115,22 @@ class CsvParser:
def parse(self, decimal, separator, index_col, header): def parse(self, decimal, separator, index_col, header):
from pandas import read_csv from pandas import read_csv
df = read_csv(self.file, decimal = decimal, sep = separator, index_col = index_col, header = header) self.df = read_csv(self.file, decimal = decimal, sep = separator, index_col = index_col, header = header)
if len(set(df.index))<df.shape[0]: if len(set(self.df.index))<self.df.shape[0]:
df = read_csv(self.file, decimal = decimal, sep = separator, index_col = None, header = header) self.df = read_csv(self.file, decimal = decimal, sep = separator, index_col = None, header = header)
float, non_float = df.select_dtypes(include='float'), df.select_dtypes(exclude='float') self.float, self.non_float = self.df.select_dtypes(include='float'), self.df.select_dtypes(exclude='float')
return float, non_float
@property
def meta_data_st_(self):
    """Return the non-float columns, upper-cased, keeping only varying ones.

    String cells of ``self.non_float`` are converted to upper case (cells of
    any other type are left untouched); columns that then hold no more than
    one distinct value are dropped.

    Returns:
        pandas.DataFrame: the standardised columns of ``self.non_float``
        whose number of distinct values is greater than one.
    """
    def _upper_if_str(cell):
        # Only strings have a meaningful upper-case form; keep the rest as-is.
        return cell.upper() if isinstance(cell, str) else cell

    frame = self.non_float
    frame_cls = type(frame)
    # ``DataFrame.applymap`` is deprecated since pandas 2.1 in favour of
    # ``DataFrame.map``. The method is looked up on the class (not the
    # instance) so that a column literally named "map" cannot be picked up by
    # attribute access on older pandas versions.
    elementwise = getattr(frame_cls, "map", None) or frame_cls.applymap
    me = elementwise(frame, _upper_if_str)
    # Keep only the columns with more than one distinct value.
    meta_data_st = me.loc[:, me.nunique(axis=0) > 1]
    return meta_data_st
@property
def meta_data(self):
    """Return the ``non_float`` attribute of this parser unchanged."""
    non_numeric_columns = self.non_float
    return non_numeric_columns
# def parse(self): # def parse(self):
# import pandas as pd # import pandas as pd
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment