From 204a2f346ffe6107f3791f8d5a303c8caca80f5c Mon Sep 17 00:00:00 2001
From: DIANE <abderrahim.diane@cefe.cnrs.fr>
Date: Mon, 29 Jul 2024 10:38:11 +0200
Subject: [PATCH] refactor if/elif chains into match/case

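Replace the cascaded if/elif dispatches (interface mode, file format,
dimensionality reduction, clustering, sample selection, regression
algorithm, model export, report generation) with Python 3.10+
match/case statements. Matching on the literal option strings ('PCA',
'.csv', 'Kmeans', ...) instead of positional lookups such as
dim_red_methods[1] makes every branch self-describing.

The pattern applied throughout, sketched on the dimensionality
reduction dispatch (illustrative outline, not a verbatim hunk):

    match dim_red_method:
        case 'PCA':
            dr_model = LinearPCA(xc, Ncomp=8)
        case 'NMF':
            dr_model = Nmf(spectra, Ncomp=3)
        case '':
            bb1.warning('Please choose an algorithm!')

Note: match/case requires Python >= 3.10.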
---
 src/pages/1-samples_selection.py | 282 ++++++++-------
 src/pages/2-model_creation.py    | 581 ++++++++++++++-----------------
 2 files changed, 394 insertions(+), 469 deletions(-)

diff --git a/src/pages/1-samples_selection.py b/src/pages/1-samples_selection.py
index af79bea..4c700a4 100644
--- a/src/pages/1-samples_selection.py
+++ b/src/pages/1-samples_selection.py
@@ -24,26 +24,27 @@ dim_red_methods=['', 'PCA','UMAP', 'NMF']  # List of dimensionality reduction al
 cluster_methods = ['', 'Kmeans','HDBSCAN', 'AP', 'KS', 'RDM'] # List of clustering algos
 selec_strategy = ['center','random']
 
-if st.session_state["interface"] == 'simple':
-    st.write(':red[Automated Simple Interface]')
-    # hide_pages("Predictions")
-    if 37 not in st.session_state:
-        default_reduction_option = 1
-    else:
-        default_reduction_option = dim_red_methods.index(st.session_state.get(37))
-    if 38 not in st.session_state:
-        default_clustering_option = 1
-    else:
-        default_clustering_option = cluster_methods.index(st.session_state.get(38))
-    if 102 not in st.session_state:
-        default_sample_selection_option = 1
-    else:
-        default_sample_selection_option = selec_strategy.index(st.session_state.get(102))
-
-if st.session_state["interface"] == 'advanced':
-    default_reduction_option = 0
-    default_clustering_option = 0
-    default_sample_selection_option = 0
+match st.session_state["interface"]:
+    case 'simple':
+        st.write(':red[Automated Simple Interface]')
+        # hide_pages("Predictions")
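+        # 37, 38 and 102 are the Streamlit widget keys of the dimensionality
+        # reduction, clustering and sample-selection widgets defined further down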
+        if 37 not in st.session_state:
+            default_reduction_option = 1
+        else:
+            default_reduction_option = dim_red_methods.index(st.session_state.get(37))
+        if 38 not in st.session_state:
+            default_clustering_option = 1
+        else:
+            default_clustering_option = cluster_methods.index(st.session_state.get(38))
+        if 102 not in st.session_state:
+            default_sample_selection_option = 1
+        else:
+            default_sample_selection_option = selec_strategy.index(st.session_state.get(102))
+        
+    case 'advanced':
+        default_reduction_option = 0
+        default_clustering_option = 0
+        default_sample_selection_option = 0
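+    case _:
+        # defensive fallback: treat any unknown interface value like 'advanced'
+        default_reduction_option = 0
+        default_clustering_option = 0
+        default_sample_selection_option = 0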
 
 ################################### I - Data Loading and Visualization ########################################
 st.title("Calibration Subset Selection")
@@ -74,33 +75,34 @@ if not data_file:
 else:
     # Retrieve the extension of the file
     test = data_file.name[data_file.name.find('.'):]
+    match test:
     ## Load .csv file
-    if test== '.csv':
-        with col1:
-            # Select list for CSV delimiter
-            psep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/'+data_file.name)), options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+data_file.name))), key=9)
-                # Select list for CSV header True / False
-            phdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/'+data_file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+data_file.name))), key=31)
-            if phdr == 'yes':
-                col = 0
-            else:
-                col = False
-            imp = pd.read_csv(data_file, sep=psep, index_col=col)
-            # spectra = col_cat(imp)[0]
-            # meta_data = col_cat(imp)[1]
-            spectra, md_df_st_ = col_cat(imp)
-            meta_data = md_df_st_
-            st.success("The data have been loaded successfully", icon="✅")
-    ## Load .dx file
-    elif test == '.dx':
-        # Create a temporary file to save the uploaded file
-        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
-            tmp.write(data_file.read())
-            tmp_path = tmp.name
+        case '.csv':
             with col1:
-                _, spectra, meta_data, md_df_st_ = read_dx(file = tmp_path)
+                # Select list for CSV delimiter
+                psep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/'+data_file.name)), options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+data_file.name))), key=9)
+                    # Select list for CSV header True / False
+                phdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/'+data_file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+data_file.name))), key=31)
+                if phdr == 'yes':
+                    col = 0
+                else:
+                    col = False
+                imp = pd.read_csv(data_file, sep=psep, index_col=col)
+                # spectra = col_cat(imp)[0]
+                # meta_data = col_cat(imp)[1]
+                spectra, md_df_st_ = col_cat(imp)
+                meta_data = md_df_st_
                 st.success("The data have been loaded successfully", icon="✅")
-        os.unlink(tmp_path)
+        ## Load .dx file
+        case '.dx':
+            # Create a temporary file to save the uploaded file
+            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
+                tmp.write(data_file.read())
+                tmp_path = tmp.name
+                with col1:
+                    _, spectra, meta_data, md_df_st_ = read_dx(file = tmp_path)
+                    st.success("The data have been loaded successfully", icon="✅")
+            os.unlink(tmp_path)
 
 
     
@@ -163,28 +165,29 @@ if not spectra.empty:
     dim_red_method = bb1.selectbox("Dimensionality reduction techniques: ", options = dim_red_methods, index = default_reduction_option, key = 37)
     clus_method = bb2.selectbox("Clustering/sampling techniques: ", options = cluster_methods, index = default_clustering_option, key = 38)
     xc = standardize(spectra, center=True, scale=False)
-
-
-    if dim_red_method == dim_red_methods[0]:
-        bb1.warning('⚠️ Please choose an algothithm !')
-    elif dim_red_method == dim_red_methods[1]:
-        dr_model = LinearPCA(xc, Ncomp=8)
-
-    elif dim_red_method == dim_red_methods[2]:
-        if not meta_data.empty:
-            filter = md_df_st_.columns
-            filter = filter.insert(0, 'Nothing')
-            col = bb1.selectbox('Supervised UMAP by:', options= filter, key=108)
-            if col == 'Nothing':
-                supervised = None
+    
+    match dim_red_method:
+        case "":
+                bb1.warning('⚠️ Please choose an algorithm !')
+        
+        case "PCA":
+            dr_model = LinearPCA(xc, Ncomp=8)
+
+        case "UMAP":
+            if not meta_data.empty:
+                filter = md_df_st_.columns
+                filter = filter.insert(0, 'Nothing')
+                col = bb1.selectbox('Supervised UMAP by:', options= filter, key=108)
+                if col == 'Nothing':
+                    supervised = None
+                else:
+                    supervised = md_df_st_[col]
             else:
-                supervised = md_df_st_[col]
-        else:
-            supervised = None
-        dr_model = Umap(numerical_data = MinMaxScale(spectra), cat_data = supervised)
+                supervised = None
+            dr_model = Umap(numerical_data = MinMaxScale(spectra), cat_data = supervised)
 
-    elif dim_red_method == dim_red_methods[3]:
-        dr_model = Nmf(spectra, Ncomp= 3)
+        case 'NMF':
+            dr_model = Nmf(spectra, Ncomp= 3)
 
     if dr_model:
         axis1 = bb3.selectbox("x-axis", options = dr_model.scores_.columns, index=0)
@@ -196,69 +199,56 @@ if not spectra.empty:
 
 
 ###### II - clustering #######
-
 if not t.empty:
+    clustered = np.arange(n_samples)
+    non_clustered = None
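+    # every sample is treated as clustered by default; the HDBSCAN case below
+    # overrides non_clustered for the points it labels 'Non clustered'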
+
     if dim_red_method == 'UMAP':
         scores = st.container()
     else:
         scores, loadings= st.columns([3,3])
 
     tcr = standardize(t)
-        # Clustering
-    # 1- K-MEANS Clustering
-    if clus_method == cluster_methods[0]:
-        bb2.warning('⚠️ Please choose an algothithm !')
-
-    if clus_method == cluster_methods[1]:
-        cl_model = Sk_Kmeans(tcr, max_clusters = 25)
-        ncluster = scores.number_input(min_value=2, max_value=25, value=cl_model.suggested_n_clusters_, label = 'Select the desired number of clusters')
-        # fig2 = px.bar(cl_model.inertia_.T, y = 'inertia')
-        # scores.write(f"Suggested n_clusters : {cl_model.suggested_n_clusters_}")
-        # scores.plotly_chart(fig2,use_container_width=True)
-        # img = pio.to_image(fig2, format="png")
-        # with open("./Report/figures/Elbow.png", "wb") as f:
-        #         f.write(img)    
-        data, labels, clu_centers = cl_model.fit_optimal(nclusters = ncluster)
-
-    # 2- HDBSCAN clustering
-    elif clus_method == cluster_methods[2]:
-        optimized_hdbscan = Hdbscan(np.array(tcr))
-        # all_labels, hdbscan_score, clu_centers = optimized_hdbscan.HDBSCAN_scores_
-        all_labels, clu_centers = optimized_hdbscan.HDBSCAN_scores_
-        labels = [f'cluster#{i+1}' if i !=-1 else 'Non clustered' for i in all_labels]
-        ncluster = len(clu_centers)
-
-    # 3- Affinity propagation
-    elif clus_method == cluster_methods[3]:
-        cl_model = AP(X = tcr)
-        data, labels, clu_centers = cl_model.fit_optimal_
-        ncluster = len(clu_centers)
-
-    elif clus_method == cluster_methods[4]:
-        rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
-        cl_model = KS(x = tcr, rset = rset)
-        calset = cl_model.calset
-        labels = ["ind"]*n_samples
-        ncluster = "1"
-        selection_number = 'None'
-
-    elif clus_method == cluster_methods[5]:
-        rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
-        cl_model = RDM(x = tcr, rset = rset)
-        calset = cl_model.calset
-        labels = ["ind"]*n_samples
-        ncluster = "1"
-        selection_number = 'None'
     
-    if clus_method == cluster_methods[2]:
-        #clustered = np.where(np.array(labels) != 'Non clustered')[0]
-        clustered = np.arange(n_samples)
-        non_clustered = np.where(np.array(labels) == 'Non clustered')[0]
+    # Clustering
+    match clus_method:
+        case '':
+            bb2.warning('⚠️ Please choose an algorithm!')
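+        # 1- KMeans clustering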
+        case 'Kmeans':
+            cl_model = Sk_Kmeans(tcr, max_clusters = 25)
+            ncluster = scores.number_input(min_value=2, max_value=25, value=cl_model.suggested_n_clusters_, label = 'Select the desired number of clusters')  
+            data, labels, clu_centers = cl_model.fit_optimal(nclusters = ncluster)
+
+        # 2- HDBSCAN clustering
+        case 'HDBSCAN':
+            optimized_hdbscan = Hdbscan(np.array(tcr))
+            all_labels, clu_centers = optimized_hdbscan.HDBSCAN_scores_
+            labels = [f'cluster#{i+1}' if i !=-1 else 'Non clustered' for i in all_labels]
+            ncluster = len(clu_centers)
+            non_clustered = np.where(np.array(labels) == 'Non clustered')[0]
+
+        # 3- Affinity propagation
+        case 'AP':
+            cl_model = AP(X = tcr)
+            data, labels, clu_centers = cl_model.fit_optimal_
+            ncluster = len(clu_centers)
+
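+        # 4- Kennard-Stone (KS) sampling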
+        case 'KS':
+            rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
+            cl_model = KS(x = tcr, rset = rset)
+            calset = cl_model.calset
+            labels = ["ind"]*n_samples
+            ncluster = "1"
+            selection_number = 'None'
+
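+        # 5- Random (RDM) sampling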
+        case 'RDM':
+            rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
+            cl_model = RDM(x = tcr, rset = rset)
+            calset = cl_model.calset
+            labels = ["ind"]*n_samples
+            ncluster = "1"
+            selection_number = 'None'
 
-    else:
-        clustered = np.arange(n_samples)
-        non_clustered = None
-    
     new_tcr = tcr.iloc[clustered,:]    
     
 
@@ -273,35 +263,37 @@ elif labels:
     num_clusters = len(np.unique(labels))
     custom_color_palette = px.colors.qualitative.Plotly[:num_clusters]
     if clus_method:
-        if clus_method == cluster_methods[4] or clus_method == cluster_methods[5]:
+        if clus_method in ['KS', 'RDM']:
             selected_samples_idx = calset[1]
             selection = 'None'
         else:
             selection = scores.radio('Select samples selection strategy:',
                                         options = selec_strategy, index = default_sample_selection_option, key=102)
+        
+        match selection:
         # Strategy 0
-        if selection == selec_strategy[0]:
-            # list samples at clusters centers - Use sklearn.metrics.pairwise_distances_argmin if you want more than 1 sample per cluster
-            closest, _ = pairwise_distances_argmin_min(clu_centers, new_tcr)
-            selected_samples_idx = np.array(new_tcr.index)[list(closest)]
-            selected_samples_idx = selected_samples_idx.tolist()
-            
-        #### Strategy 1
-        elif selection == selec_strategy[1]:
-            selection_number = scores.number_input('How many samples per cluster?',
-                                                    min_value = 1, step=1, value = 3)
-            s = np.array(labels)[np.where(np.array(labels) !='Non clustered')[0]]
-            for i in np.unique(s):
-                C = np.where(np.array(labels) == i)[0]
-                if C.shape[0] >= selection_number:
-                    # scores.write(list(tcr.index)[labels== i])
-                    km2 = KMeans(n_clusters = selection_number)
-                    km2.fit(tcr.iloc[C,:])
-                    clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:])
-                    selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index)
-                else:
-                    selected_samples_idx.extend(new_tcr.iloc[C,:].index.to_list())
-                # list indexes of selected samples for colored plot    
+            case 'center':
+                # list samples at clusters centers - Use sklearn.metrics.pairwise_distances_argmin if you want more than 1 sample per cluster
+                closest, _ = pairwise_distances_argmin_min(clu_centers, new_tcr)
+                selected_samples_idx = np.array(new_tcr.index)[list(closest)]
+                selected_samples_idx = selected_samples_idx.tolist()
+                
+            #### Strategy 1
+            case 'random':
+                selection_number = scores.number_input('How many samples per cluster?',
+                                                        min_value = 1, step=1, value = 3)
+                s = np.array(labels)[np.where(np.array(labels) !='Non clustered')[0]]
+                for i in np.unique(s):
+                    C = np.where(np.array(labels) == i)[0]
+                    if C.shape[0] >= selection_number:
+                        # scores.write(list(tcr.index)[labels== i])
+                        km2 = KMeans(n_clusters = selection_number)
+                        km2.fit(tcr.iloc[C,:])
+                        clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:])
+                        selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index)
+                    else:
+                        selected_samples_idx.extend(new_tcr.iloc[C,:].index.to_list())
+                    # list indexes of selected samples for colored plot    
 
 ################################      Plots visualization          ############################################
 
@@ -385,7 +377,7 @@ if not t.empty:
 
 
 if not spectra.empty:
-    if dim_red_method == dim_red_methods[1] or dim_red_method == dim_red_methods[3]:
+    if dim_red_method in ['PCA','NMF']:
         with loadings:
             st.write('Loadings plot')
             p = dr_model.loadings_
@@ -421,7 +413,7 @@ if not spectra.empty:
             with open("./Report/figures/loadings_plot.png", "wb") as f:
                 f.write(img)
 #############################################################################################################
-    if dim_red_method == dim_red_methods[1]:
+    if dim_red_method == 'PCA':
         influence, hotelling = st.columns([3, 3])
         with influence:
             st.write('Influence plot')
@@ -549,10 +541,12 @@ if labels:
         sam1.index = np.arange(len(selected_samples_idx))+1
         info.info(f'Information !\n - The total number of samples: {n_samples}.\n- The number of samples selected for reference analysis: {sam1.shape[0]}.\n - The proportion of samples selected for reference analysis: {round(sam1.shape[0]/n_samples*100)}%.')
         sam = sam1
         if clus_method == cluster_methods[2]:
             unclus = sel.checkbox("Include non clustered samples (for HDBSCAN clustering)", value=True)
 
-        if clus_method == cluster_methods[2]:
             if selected_samples_idx:
                 if unclus:
                     if meta_data.empty:
diff --git a/src/pages/2-model_creation.py b/src/pages/2-model_creation.py
index 06d0645..daa10e7 100644
--- a/src/pages/2-model_creation.py
+++ b/src/pages/2-model_creation.py
@@ -1,5 +1,3 @@
-# import streamlit
-import pandas as pd
 from Packages import *
 st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
 from Modules import *
@@ -19,98 +17,104 @@ if os.path.exists(repertoire_a_vider):
 
 local_css(css_file / "style_model.css")
 
-    ####################################### page Design #######################################
-st.title("Calibration Model Development")
+####################################### page preamble #######################################
+st.title("Calibration Model Development") # page title
 st.markdown("Create a predictive model, then use it for predicting your target variable (chemical data) from NIRS spectra")
 M0, M00 = st.columns([1, .4])
-M0.image("C:/Users/diane/Desktop/nirs_workflow/src/images/graphical_abstract.jpg", use_column_width=True)
-# st.header("II - Model creation", divider='blue')
-# st.header("Cross-Validation results")
-# cv1, cv2 = st.columns([2,2])
-cv3 = st.container()
-
-
-    ##############################################################################################
-
-
-files_format = ['.csv', '.dx']
-file = M00.radio('Select files format:', options = files_format)
-spectra = pd.DataFrame()
-y = pd.DataFrame()
-regression_algo = None
-Reg = None
-# load .csv file
-if file == files_format[0]:
-    xcal_csv = M00.file_uploader("Select NIRS Data", type="csv", help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
-    if xcal_csv:
-        sepx = M00.radio("Select separator (X file) - _detected_: " + str(find_delimiter('data/'+xcal_csv.name)),
-                                options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+xcal_csv.name))), key=0)
-        hdrx = M00.radio("samples name (X file)? - _detected_: " + str(find_col_index('data/'+xcal_csv.name)),
-                                options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+xcal_csv.name))), key=1)
-        if hdrx == "yes": col = 0
-        else: col = False
-    else:
-        M00.warning('Insert your spectral data file here!')
+M0.image("C:/Users/diane/Desktop/nirs_workflow/src/images/graphical_abstract.jpg", use_column_width=True) # graphical abstract
+
+
+
+####################################### I- Data preparation
+files_format = ['.csv', '.dx'] # supported file formats
+file = M00.radio('Select files format:', options = files_format) # Select a file format
+spectra = pd.DataFrame() # preallocate the spectral data block
+y = pd.DataFrame() # preallocate the target(s) data block
+match file:
+    ## load .csv file
+    case '.csv':
+        # Load X-block data
+        xcal_csv = M00.file_uploader("Select NIRS Data", type="csv", help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
+        if xcal_csv:
+            sepx = M00.radio("Select separator (X file) - _detected_: " + str(find_delimiter('data/'+xcal_csv.name)),
+                                    options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+xcal_csv.name))), key=0)
+            hdrx = M00.radio("samples name (X file)? - _detected_: " + str(find_col_index('data/'+xcal_csv.name)),
+                                    options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+xcal_csv.name))), key=1)
+            match hdrx:
+                case "yes":
+                    col = 0
+                case "no":
+                    col = False
+        else:
+            M00.warning('Insert your spectral data file here!')
         
-    ycal_csv = M00.file_uploader("Select corresponding Chemical Data", type="csv", help=" :mushroom: select a csv matrix with samples as rows and chemical values as a column")
-    if ycal_csv:
-        sepy = M00.radio("Select separator (Y file) - _detected_: " + str(find_delimiter('data/'+ycal_csv.name)),
-                         options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+ycal_csv.name))), key=2)
-        hdry = M00.radio("samples name (Y file)? - _detected_: " + str(find_col_index('data/'+ycal_csv.name)),
-                         options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+ycal_csv.name))), key=3)
-        if hdry == "yes": col = 0
-        else: col = False
-    else:
-        M00.warning('Insert your target data file here!')
-    
-    if xcal_csv and ycal_csv:
-        file_name = str(xcal_csv.name) +' and '+ str(ycal_csv.name)
-        xfile = pd.read_csv(xcal_csv, decimal='.', sep=sepx, index_col=col, header=0)
-        yfile =  pd.read_csv(ycal_csv, decimal='.', sep=sepy, index_col=col)
-        if yfile.shape[1]>0 and xfile.shape[1]>0 :
-            spectra, meta_data = col_cat(xfile)
-            chem_data, idx = col_cat(yfile)
-            if chem_data.shape[1]>1:
-                yname = M00.selectbox('Select target', options=chem_data.columns)
-                y = chem_data.loc[:,yname]
-            else:
-                y = chem_data.iloc[:,0]
+        # Load Y-block data
+        ycal_csv = M00.file_uploader("Select corresponding Chemical Data", type="csv", help=" :mushroom: select a csv matrix with samples as rows and chemical values as a column")
+        if ycal_csv:
+            sepy = M00.radio("Select separator (Y file) - _detected_: " + str(find_delimiter('data/'+ycal_csv.name)),
+                            options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+ycal_csv.name))), key=2)
+            hdry = M00.radio("samples name (Y file)? - _detected_: " + str(find_col_index('data/'+ycal_csv.name)),
+                            options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+ycal_csv.name))), key=3)
             
+            match hdry:
+                case "yes":
+                    col = 0
+                case "no":
+                    col = False
 
-            spectra = pd.DataFrame(spectra).astype(float)
-            # if not meta_data.empty :
-            #     st.write(meta_data)
+        else:
+            M00.warning('Insert your target data file here!')
+        
+        if xcal_csv and ycal_csv:
+            file_name = str(xcal_csv.name) +' and '+ str(ycal_csv.name)
+            xfile = pd.read_csv(xcal_csv, decimal='.', sep=sepx, index_col=col, header=0)
+            yfile =  pd.read_csv(ycal_csv, decimal='.', sep=sepy, index_col=col)
+            if yfile.shape[1]>0 and xfile.shape[1]>0 :
+                spectra, meta_data = col_cat(xfile)
+                chem_data, idx = col_cat(yfile)
+                if chem_data.shape[1]>1:
+                    yname = M00.selectbox('Select target', options=chem_data.columns)
+                    y = chem_data.loc[:,yname]
+                else:
+                    y = chem_data.iloc[:,0]
+                
 
-            if spectra.shape[0] != y.shape[0]:
-                M00.warning('X and Y have different sample size')
-                y = pd.DataFrame
-                spectra = pd.DataFrame
+                spectra = pd.DataFrame(spectra).astype(float)
+                # if not meta_data.empty :
+                #     st.write(meta_data)
+
+                if spectra.shape[0] != y.shape[0]:
+                    M00.warning('X and Y have different sample size')
+                    y = pd.DataFrame()
+                    spectra = pd.DataFrame()
 
-        else:
-            M00.error('Error: The data has not been loaded successfully, please consider tuning the decimal and separator !')
-
-## Load .dx file
-elif file == files_format[1]:
-    data_file = M00.file_uploader("Select Data", type=".dx", help=" :mushroom: select a dx file")
-    if not data_file:
-        M00.warning('Load your file here!')
-    else :
-        file_name = str(data_file.name)
-        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
-            tmp.write(data_file.read())
-            tmp_path = tmp.name
-            chem_data, spectra, meta_data, meta_data_st = read_dx(file =  tmp_path)
-            M00.success("The data have been loaded successfully", icon="✅")
-            if chem_data.shape[1]>0:
-                yname = M00.selectbox('Select target', options=chem_data.columns)
-                measured = chem_data.loc[:,yname] > 0
-                y = chem_data.loc[:,yname].loc[measured]
-                spectra = spectra.loc[measured]
             else:
-                M00.warning('Warning: your file includes no target variables to model !', icon="⚠️")
-        os.unlink(tmp_path)
+                M00.error('Error: The data has not been loaded successfully, please consider tuning the decimal and separator!')
+    
+    ## Load .dx file
+    case '.dx':
+        data_file = M00.file_uploader("Select Data", type=".dx", help=" :mushroom: select a dx file")
+        if not data_file:
+            M00.warning('Load your file here!')
+        else :
+            file_name = str(data_file.name)
+            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
+                tmp.write(data_file.read())
+                tmp_path = tmp.name
+                chem_data, spectra, meta_data, meta_data_st = read_dx(file =  tmp_path)
+                M00.success("The data have been loaded successfully", icon="✅")
+                if chem_data.shape[1]>0:
+                    yname = M00.selectbox('Select target', options=chem_data.columns)
+                    measured = chem_data.loc[:,yname] > 0
+                    y = chem_data.loc[:,yname].loc[measured]
+                    spectra = spectra.loc[measured]
+                else:
+                    M00.warning('Warning: your file includes no target variables to model!', icon="⚠️")
+            os.unlink(tmp_path)
 
-### split the data
+
+
+# visualize and split the data
 st.header("I - Data visualization", divider='blue')
 if not spectra.empty and not y.empty:
     M0, M000 = st.columns([1, .4])
@@ -120,7 +124,6 @@ if not spectra.empty and not y.empty:
         colnames = np.arange(spectra.shape[1])
 
 
-    #rd_seed = M1.slider("Customize Train-test split", min_value=1, max_value=100, value=42, format="%i")
     # Split data into training and test sets using the kennard_stone method and correlation metric, 25% of data is used for testing
     train_index, test_index = train_test_split_idx(spectra, y = y, method="kennard_stone", metric="correlation", test_size=0.25, random_state=42)
 
@@ -152,202 +155,150 @@ if not spectra.empty and not y.empty:
     M000.write('Loaded data summary')
     M000.write(pd.DataFrame([desc_stats(y_train),desc_stats(y_test),desc_stats(y)], index =['train', 'test', 'total'] ).round(2))
     stats=pd.DataFrame([desc_stats(y_train),desc_stats(y_test),desc_stats(y)], index =['train', 'test', 'total'] ).round(2)
-    ####################################### Insight into the loaded data
 
 
     ####################################### Model creation ###################################################
+regression_algo = None # initialize the selected regression algorithm
+Reg = None  # initialize the regression model object
+
 st.header("II - Model creation", divider='blue')
-if not spectra.empty and not y.empty:
+if not spectra.empty and not y.empty:
     M10, M20, M30, M40, M50 = st.columns([1,1,1,1,1])
-    modes = ['regression', 'classification']
-    mode =M10.radio("Supervised modelling mode", options=modes)
-    if mode == 'regression':
-        reg_algo = ["","PLS", "LW-PLS", "TPE-iPLS"]
-        regression_algo = M20.selectbox("Choose the regression algorithm", options= reg_algo, key = 12, placeholder ="Choose an option")
-
-    elif mode == 'classification':
-        reg_algo = ["","PLS", "LW-PLS", "TPE-iPLS"]
-        regression_algo = M20.selectbox("Choose the classification algorithm", options= reg_algo, key = 12, placeholder ="Choose an option")
 
+    # select type of supervised modelling problem
+    modes = ['regression', 'classification']
+    mode = M10.radio("Analysis Methods", options=modes)
+    match mode:
+        case "regression":
+            reg_algo = ["","PLS", "LW-PLS", "TPE-iPLS"]
+            regression_algo = M20.selectbox("Choose the regression algorithm", options= reg_algo, key = 12, format_func=lambda x: x if x else "<Select>")
+        case 'classification':
+            reg_algo = ["","PLS", "LW-PLS", "TPE-iPLS"]
+            regression_algo = M20.selectbox("Choose the classification algorithm", options= reg_algo, key = 12, format_func=lambda x: x if x else "<Select>")
     
-
-    # split train data into nb_folds for cross_validation
+    # Training set preparation for cross-validation (CV)
     nb_folds = 3
-    folds = KF_CV.CV(X_train, y_train, nb_folds)
-
-    if not regression_algo:
-        M20.warning('Choose a modelling algorithm from the dropdown list !')
-    else:
-        M1, M2 = st.columns([2 ,4])
-    if regression_algo == reg_algo[1]:
-        # Train model with model function from application_functions.py
-        Reg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter=1)
-        reg_model = Reg.model_
-        #M2.dataframe(Pin.pred_data_)
-
-    elif regression_algo == reg_algo[2]:
-        M20.write(f'K-Fold for Cross-Validation (K = {str(nb_folds)})')
-        info = M20.info('Starting LWPLSR model creation... Please wait a few minutes.')
-        # export data to csv for Julia train/test
-        data_to_work_with = ['x_train_np', 'y_train_np', 'x_test_np', 'y_test_np']
-        x_train_np, y_train_np, x_test_np, y_test_np = X_train.to_numpy(), y_train.to_numpy(), X_test.to_numpy(), y_test.to_numpy()
-        # Cross-Validation calculation
-
-        d = {}
-        for i in range(nb_folds):
-            d["xtr_fold{0}".format(i+1)], d["ytr_fold{0}".format(i+1)], d["xte_fold{0}".format(i+1)], d["yte_fold{0}".format(i+1)] = np.delete(x_train_np, folds[list(folds)[i]], axis=0), np.delete(y_train_np, folds[list(folds)[i]], axis=0), x_train_np[folds[list(folds)[i]]], y_train_np[folds[list(folds)[i]]]
-            data_to_work_with.append("xtr_fold{0}".format(i+1))
-            data_to_work_with.append("ytr_fold{0}".format(i+1))
-            data_to_work_with.append("xte_fold{0}".format(i+1))
-            data_to_work_with.append("yte_fold{0}".format(i+1))
-        # check best pre-treatment with a global PLSR model
-        preReg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter=20)
-        # M2.write(preReg.best_hyperparams_)
-        temp_path = Path('temp/')
-        with open(temp_path / "lwplsr_preTreatments.json", "w+") as outfile:
-            json.dump(preReg.best_hyperparams_, outfile)
-        # export Xtrain, Xtest, Ytrain, Ytest and all CV folds to temp folder as csv files
-        for i in data_to_work_with:
-            if 'fold' in i:
-                j = d[i]
-            else:
-                j = globals()[i]
-            np.savetxt(temp_path / str(i + ".csv"), j, delimiter=",")
-        # run Julia Jchemo as subprocess
-        import subprocess
-        subprocess_path = Path("Class_Mod/")
-        subprocess.run([f"{sys.executable}", subprocess_path / "LWPLSR_Call.py"])
-        # retrieve json results from Julia JChemo
-        try:
-            with open(temp_path / "lwplsr_outputs.json", "r") as outfile:
-                Reg_json = json.load(outfile)
-                # delete csv files
-                for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
-            # # delete json file after import
-            os.unlink(temp_path / "lwplsr_outputs.json")
-            os.unlink(temp_path / "lwplsr_preTreatments.json")
-            # format result data into Reg object
-            pred = ['pred_data_train', 'pred_data_test']### keys of the dict
-            for i in range(nb_folds):
-                pred.append("CV" + str(i+1)) ### add cv folds keys to pred
-
-            Reg = type('obj', (object,), {'model_' : Reg_json['model'], 'best_hyperparams_' : Reg_json['best_lwplsr_params'],
-                                          'pred_data_' : [pd.json_normalize(Reg_json[i]) for i in pred]})
+    folds = KF_CV.CV(X_train, y_train, nb_folds)  # split train data into nb_folds for cross-validation
+
+    M1, M2 = st.columns([2, 4])
+    # Model creation
+    match regression_algo:
+        case "":
+            M20.warning('Choose a modelling algorithm from the dropdown list!')
+        case "PLS":
+            Reg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter=1)
             reg_model = Reg.model_
-            Reg.CV_results_ = pd.DataFrame()
-            Reg.cv_data_ = {'YpredCV' : {}, 'idxCV' : {}}
-            # # set indexes to Reg.pred_data (train, test, folds idx)
-            for i in range(len(pred)):
-                Reg.pred_data_[i] = Reg.pred_data_[i].T.reset_index().drop(columns = ['index'])
-                if i == 0: # data_train
-                    # Reg.pred_data_[i] = np.array(Reg.pred_data_[i])
-                    Reg.pred_data_[i].index = list(y_train.index)
-                    Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
-                elif i == 1: # data_test
-                    # Reg.pred_data_[i] = np.array(Reg.pred_data_[i])
-                    Reg.pred_data_[i].index = list(y_test.index)
-                    Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
+        case 'LW-PLS':
+            M20.write(f'K-Fold for Cross-Validation (K = {str(nb_folds)})')
+            info = M20.info('Starting LWPLSR model creation... Please wait a few minutes.')
+            # export data to csv for Julia train/test
+            data_to_work_with = ['x_train_np', 'y_train_np', 'x_test_np', 'y_test_np']
+            x_train_np, y_train_np, x_test_np, y_test_np = X_train.to_numpy(), y_train.to_numpy(), X_test.to_numpy(), y_test.to_numpy()
+            # Cross-Validation calculation
+
+            d = {}
+            for i in range(nb_folds):
+                fold_idx = folds[list(folds)[i]]
+                d["xtr_fold{0}".format(i+1)] = np.delete(x_train_np, fold_idx, axis=0)
+                d["ytr_fold{0}".format(i+1)] = np.delete(y_train_np, fold_idx, axis=0)
+                d["xte_fold{0}".format(i+1)] = x_train_np[fold_idx]
+                d["yte_fold{0}".format(i+1)] = y_train_np[fold_idx]
+                data_to_work_with.append("xtr_fold{0}".format(i+1))
+                data_to_work_with.append("ytr_fold{0}".format(i+1))
+                data_to_work_with.append("xte_fold{0}".format(i+1))
+                data_to_work_with.append("yte_fold{0}".format(i+1))
+            # check best pre-treatment with a global PLSR model
+            preReg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter=20)
+            temp_path = Path('temp/')
+            with open(temp_path / "lwplsr_preTreatments.json", "w+") as outfile:
+                json.dump(preReg.best_hyperparams_, outfile)
+            # export Xtrain, Xtest, Ytrain, Ytest and all CV folds to temp folder as csv files
+            for i in data_to_work_with:
+                if 'fold' in i:
+                    j = d[i]
                 else:
-                    # CVi
-                    Reg.pred_data_[i].index = folds[list(folds)[i-2]]
-                    # Reg.CV_results_ = pd.concat([Reg.CV_results_, Reg.pred_data_[i]])
-                    Reg.cv_data_['YpredCV']['Fold' + str(i-1)] = np.array(Reg.pred_data_[i]).reshape(-1)
-                    Reg.cv_data_['idxCV']['Fold' + str(i-1)] = np.array(folds[list(folds)[i-2]]).reshape(-1)
-
-            Reg.CV_results_= KF_CV.metrics_cv(y = y_train, ypcv = Reg.cv_data_['YpredCV'], folds = folds)[1]
-            #### cross validation results print
-            Reg.best_hyperparams_print = Reg.best_hyperparams_
-            ## plots
-            Reg.cv_data_ = KF_CV().meas_pred_eq(y = np.array(y_train), ypcv= Reg.cv_data_['YpredCV'], folds=folds)
-            Reg.pretreated_spectra_ = preReg.pretreated_spectra_
-            
-            Reg.best_hyperparams_print = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}
-            Reg.best_hyperparams_ = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}
-            info.empty()
-            M20.success('Model created!')
-        except FileNotFoundError as e:
-            # Display error message on the interface if modeling is wrong
-            info.empty()
-            M20.warning('- ERROR during model creation -')
-            Reg = None
-            for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
-
-
-#######################
-
-
-            
-    elif regression_algo == reg_algo[3]:
-        s = M20.number_input(label='Enter the maximum number of intervals', min_value=1, max_value=6, value=3)
-        it = M20.number_input(label='Enter the number of iterations', min_value=1, max_value=3, value=2)
-        progress_text = "The model is being created. Please wait."
-            
-        Reg = TpeIpls(train = [X_train, y_train], test=[X_test, y_test], n_intervall = s, n_iter=it)
-        pro = M1.progress(0, text="The model is being created. Please wait!")
-        pro.empty()
-        M20.progress(100, text = "The model has successfully been  created!")            
-        time.sleep(1)
-        reg_model = Reg.model_
-
-        
-        intervalls = Reg.selected_features_.T
-        intervalls_with_cols = Reg.selected_features_.T
-        for i in range(intervalls.shape[0]):
-            for j in range(intervalls.shape[1]):
-                intervalls_with_cols.iloc[i,j] = spectra.columns[intervalls.iloc[i,j]]
-        M2.write('-- Important Spectral regions used for model creation --')
-        M2.table(intervalls_with_cols)
-        
-    # elif regression_algo == reg_algo[4]:
-    #     Reg = PlsR(x_train = X_train, x_test = X_test, y_train = y_train, y_test = y_test)
-    #     reg_model = Reg.model_
-
-
+                    j = globals()[i]
+                np.savetxt(temp_path / str(i + ".csv"), j, delimiter=",")
+            # run Julia Jchemo as subprocess
+            import subprocess
+            subprocess_path = Path("Class_Mod/")
+            subprocess.run([f"{sys.executable}", subprocess_path / "LWPLSR_Call.py"])
+            # retrieve json results from Julia JChemo
+            try:
+                with open(temp_path / "lwplsr_outputs.json", "r") as outfile:
+                    Reg_json = json.load(outfile)
+                    # delete csv files
+                    for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
+                # # delete json file after import
+                os.unlink(temp_path / "lwplsr_outputs.json")
+                os.unlink(temp_path / "lwplsr_preTreatments.json")
+                # format result data into Reg object
+                pred = ['pred_data_train', 'pred_data_test']### keys of the dict
+                for i in range(nb_folds):
+                    pred.append("CV" + str(i+1)) ### add cv folds keys to pred
+
+                Reg = type('obj', (object,), {'model_' : Reg_json['model'], 'best_hyperparams_' : Reg_json['best_lwplsr_params'],
+                                            'pred_data_' : [pd.json_normalize(Reg_json[i]) for i in pred]})
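+                # 'Reg' is an ad-hoc namespace object that mimics the attribute
+                # interface (model_, best_hyperparams_, pred_data_) of the Python
+                # regression classes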
+                reg_model = Reg.model_
+                Reg.CV_results_ = pd.DataFrame()
+                Reg.cv_data_ = {'YpredCV' : {}, 'idxCV' : {}}
+                # # set indexes to Reg.pred_data (train, test, folds idx)
+                for i in range(len(pred)):
+                    Reg.pred_data_[i] = Reg.pred_data_[i].T.reset_index().drop(columns = ['index'])
+                    if i == 0: # data_train
+                        # Reg.pred_data_[i] = np.array(Reg.pred_data_[i])
+                        Reg.pred_data_[i].index = list(y_train.index)
+                        Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
+                    elif i == 1: # data_test
+                        # Reg.pred_data_[i] = np.array(Reg.pred_data_[i])
+                        Reg.pred_data_[i].index = list(y_test.index)
+                        Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
+                    else:
+                        # CVi
+                        Reg.pred_data_[i].index = folds[list(folds)[i-2]]
+                        # Reg.CV_results_ = pd.concat([Reg.CV_results_, Reg.pred_data_[i]])
+                        Reg.cv_data_['YpredCV']['Fold' + str(i-1)] = np.array(Reg.pred_data_[i]).reshape(-1)
+                        Reg.cv_data_['idxCV']['Fold' + str(i-1)] = np.array(folds[list(folds)[i-2]]).reshape(-1)
+
+                Reg.CV_results_= KF_CV.metrics_cv(y = y_train, ypcv = Reg.cv_data_['YpredCV'], folds = folds)[1]
+                #### cross validation results print
+                Reg.best_hyperparams_print = Reg.best_hyperparams_
+                ## plots
+                Reg.cv_data_ = KF_CV().meas_pred_eq(y = np.array(y_train), ypcv= Reg.cv_data_['YpredCV'], folds=folds)
+                Reg.pretreated_spectra_ = preReg.pretreated_spectra_
+                
+                Reg.best_hyperparams_print = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}
+                Reg.best_hyperparams_ = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}
+                info.empty()
+                M20.success('Model created!')
+            except FileNotFoundError as e:
+                # Display error message on the interface if modeling is wrong
+                info.empty()
+                M20.warning('- ERROR during model creation -')
+                Reg = None
+                for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
 
+        case 'TPE-iPLS':
+            s = M20.number_input(label='Enter the maximum number of intervals', min_value=1, max_value=6, value=3)
+            it = M20.number_input(label='Enter the number of iterations', min_value=1, max_value=3, value=2)
+            progress_text = "The model is being created. Please wait."
+                
+            Reg = TpeIpls(train = [X_train, y_train], test=[X_test, y_test], n_intervall = s, n_iter=it)
+            pro = M1.info("The model is being created. Please wait!")
+            pro.empty()
+            M20.info("The model has successfully been  created!")            
+            time.sleep(1)
+            reg_model = Reg.model_
+            intervalls = Reg.selected_features_.T
+            intervalls_with_cols = Reg.selected_features_.T
+            for i in range(intervalls.shape[0]):
+                for j in range(intervalls.shape[1]):
+                    intervalls_with_cols.iloc[i,j] = spectra.columns[intervalls.iloc[i,j]]
+            M2.write('-- Important Spectral regions used for model creation --')
+            M2.table(intervalls_with_cols)
 
-#         ###############################################################################################################DDDVVVVVVVVVV
         
-#        ################# Model analysis ############
-if not spectra.empty and not y.empty:
-    if regression_algo in reg_algo[1:] and Reg is not None:
-        #M2.write('-- Pretreated data (train) visualization and important spectral regions in the model --   ')
-
+################# Model analysis ############
+if not spectra.empty and not y.empty:
+    if regression_algo in reg_algo[1:] and Reg:
         fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 6))
         fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02)
-        # fig.append_trace(go.Scatter(x=[3, 4, 5],
-        #                             y=[1000, 1100, 1200],), row=1, col=1)
-
-        # fig.append_trace(go.Scatter(x=[2, 3, 4],
-        #                             y=[100, 110, 120],), row=2, col=1)
-
-        # fig.append_trace(go.Scatter(x=[0, 1, 2],
-        #                             y=[10, 11, 12]), row=3, col=1)
-
-        # fig.update_layout(height=600, width=600, title_text="Stacked Subplots")   
-        # a = Reg.pretreated_spectra_
-        # r = pd.concat([y_train, a], axis = 1)
-        # rr = r.melt("x")
-        # rr.columns = ['y values', 'x_axis', 'y_axis']
-        # fig = px.scatter(rr, x = 'x_axis', y = 'y_axis', color_continuous_scale=px.colors.sequential.Viridis, color = 'y values')
-        # M3.plotly_chart(fig)
-        
-        
-        # from matplotlib.colors import Normalize
-        # color_variable = y_train
-        # norm = Normalize(vmin=color_variable.min(), vmax= color_variable.max())
-        # cmap = plt.get_cmap('viridis')
-        # colors = cmap(norm(color_variable.values))
-        # fig, ax = plt.subplots(figsize = (10,3))
-
-        # for i in range(Reg.pretreated_spectra_.shape[0]):
-        #     ax.plot(Reg.pretreated_spectra_.columns, Reg.pretreated_spectra_.iloc[i,:], color = colors[i])
-        # sm = ScalarMappable(norm = norm, cmap = cmap)
-        # cbar = plt.colorbar(sm, ax = ax)
-        # # cbar.set_label('Target range') 
-        # plt.tight_layout()      
-        # htmlfig = mpld3.fig_to_html(fig)
-        # with M2:
-        #     st.components.v1.html(htmlfig, height=600)
         
         st.header("Cross-Validation results")
         cv1, cv2 = st.columns([2,2])
@@ -376,9 +327,7 @@ if not spectra.empty and not y.empty:
         yc = Reg.pred_data_[0]
         yt = Reg.pred_data_[1]
             
-        #if
         M1.write('-- Spectral preprocessing info --')
-        
         M1.write(Reg.best_hyperparams_print)
         with open("data/params/Preprocessing.json", "w") as outfile:
             json.dump(Reg.best_hyperparams_, outfile)
@@ -389,14 +338,9 @@ if not spectra.empty and not y.empty:
         if regression_algo != reg_algo[2]:
             M1.dataframe(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
         else:
-            M1.dataframe(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
+            M1.dataframe(metrics(t = [y_test, yt], method='regression').scores_)
         model_per=pd.DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method='regression').scores_)
-        #from st_circular_progress import CircularProgress
-        #my_circular_progress = CircularProgress(label = 'Performance',value = 50, key = 'my performance',
-        #                                         size = "medium", track_color = "black", color = "blue")
         
-        #my_circular_progress.st_circular_progress()
-        #my_circular_progress.update_value(progress=20)
         if regression_algo != reg_algo[2]:
             a = reg_plot([y_train, y_test],[yc, yt], train_idx = train_index, test_idx = test_index)
         else:
@@ -404,7 +348,7 @@ if not spectra.empty and not y.empty:
 
 st.header("III - Model Diagnosis", divider='blue')
 if not spectra.empty and not y.empty:
-    if regression_algo in reg_algo[1:] and Reg is not None:
+    if regression_algo in reg_algo[1:] and Reg:
         
         M7, M8 = st.columns([2,2])
         M7.write('Predicted vs Measured values')
@@ -434,35 +378,30 @@ if not spectra.empty and not y.empty:
         
         if regression_algo != reg_algo[2]:
             rega = Reg.selected_features_  ##### ADD FEATURES IMPORTANCE PLOT
-            
-            #model_export = M1.selectbox("Choose way to export", options=["pickle", "joblib"], key=20)
-            
+
         M9 = st.container()
         M9.write("-- Save the model --")
         model_name = M9.text_input('Give it a name')
         date_time = datetime.datetime.strftime(datetime.date.today(), '_%Y_%m_%d_')
         if M9.button('Export Model'):
             path = 'data/models/model_'
-            if file == files_format[0]:
-                #export_package = __import__(model_export)
-                with open(path + model_name + date_time + '_created_on_' + xcal_csv.name[:xcal_csv.name.find(".")] +""+
-                           '_and_' + ycal_csv.name[:ycal_csv.name.find(".")] + '_data_' + '.pkl','wb') as f:
-                    joblib.dump(reg_model, f)
-                    if regression_algo == reg_algo[3]:
-                        Reg.selected_features_.T.to_csv(path + model_name + date_time + '_on_' + xcal_csv.name[:xcal_csv.name.find(".")]
-                                                      + '_and_' + ycal_csv.name[:ycal_csv.name.find(".")] + '_data_'+'Wavelengths_index.csv', sep = ';')
-
-            elif file == files_format[1]:
-                #export_package = __import__(model_export)
-                with open(path + model_name + '_on_'+ data_file.name[:data_file.name.find(".")] + '_data_' + '.pkl','wb') as f:
-                    joblib.dump(reg_model, f)
-                    if regression_algo == reg_algo[3]:
-                        Reg.selected_features_.T.to_csv(path +data_file.name[:data_file.name.find(".")]+ model_name + date_time+ '_on_' + '_data_'+'Wavelengths_index.csv', sep = ';')
-                        st.write('Model Exported ')
-
-                # create a report with information on the model
-                ## see https://stackoverflow.com/a/59578663
-
+            match file:
+                case '.csv':
+                    with open(path + model_name + date_time + '_created_on_' + xcal_csv.name[:xcal_csv.name.find(".")] +
+                            '_and_' + ycal_csv.name[:ycal_csv.name.find(".")] + '_data_' + '.pkl','wb') as f:
+                        joblib.dump(reg_model, f)
+                        if regression_algo == reg_algo[3]:
+                            Reg.selected_features_.T.to_csv(path + model_name + date_time + '_on_' + xcal_csv.name[:xcal_csv.name.find(".")]
+                                                        + '_and_' + ycal_csv.name[:ycal_csv.name.find(".")] + '_data_'+'Wavelengths_index.csv', sep = ';')
+
+                case '.dx':
+                    with open(path + model_name + '_on_'+ data_file.name[:data_file.name.find(".")] + '_data_' + '.pkl','wb') as f:
+                        joblib.dump(reg_model, f)
+                        if regression_algo == reg_algo[3]:
+                            Reg.selected_features_.T.to_csv(path + data_file.name[:data_file.name.find(".")] + model_name + date_time + '_on_' + '_data_' + 'Wavelengths_index.csv', sep = ';')
+                            st.write('Model Exported')
 
         if st.session_state['interface'] == 'simple':
             pages_folder = Path("pages/")
@@ -479,7 +418,7 @@ if not spectra.empty and not y.empty:
 
 
 if not spectra.empty and not y.empty and regression_algo:
-    if regression_algo in reg_algo[1:] and Reg is not None:
+    if regression_algo in reg_algo[1:] and Reg:
         fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 4), sharex=True)
         ax1.plot(colnames, np.mean(X_train, axis = 0), color = 'black', label = 'Average spectrum (Raw)')
         # if regression_algo != reg_algo[2]:
@@ -503,8 +442,6 @@ if not spectra.empty and not y.empty and regression_algo:
 
 
         if regression_algo == reg_algo[1]:
-                # st.write(colnames[np.array(Reg.sel_ratio_.index)])
-                # st.write(colnames[np.array(Reg.sel_ratio_.index)])
                 ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0)[np.array(Reg.sel_ratio_.index)],
                              color = 'red', label = 'Important variables')
                 ax2.scatter(colnames[Reg.sel_ratio_.index], np.mean(Reg.pretreated_spectra_, axis = 0)[np.array(Reg.sel_ratio_.index)],
@@ -515,31 +452,25 @@ if not spectra.empty and not y.empty and regression_algo:
         M2.write('-- Visualization of the spectral regions used for model creation --')
         fig.savefig("./Report/figures/Variable_importance.png")
         M2.pyplot(fig)
-        # if regression_algo == reg_algo[3]:
-        #     M2.write('-- Important Spectral regions used for model creation --')
-        #     M2.table(intervalls_with_cols)
 
-## Load .dx file
-if Reg is not None:
+
+######################## Download report ###############################
+if Reg:
     with st.container():
         if st.button("Download the report"):
-            if regression_algo == reg_algo[1]:
-                    latex_report = report.report('Predictive model development', file_name, stats, list(Reg.best_hyperparams_.values()), regression_algo, model_per, cv_results)
-    
-            elif regression_algo == reg_algo[2]:
-                    latex_report = report.report('Predictive model development', file_name, stats,
-                                                  list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), regression_algo, model_per, cv_results)
-                    
-            elif regression_algo == reg_algo[3]:
-                    latex_report = report.report('Predictive model development', file_name, stats,
-                                                  list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), regression_algo, model_per, cv_results)
-                    
+            match regression_algo:
+                case 'PLS':
+                    latex_report = report.report('Predictive model development', file_name, stats, list(Reg.best_hyperparams_.values()), regression_algo, model_per, cv_results)
+
+                case 'LW-PLS':
+                    latex_report = report.report('Predictive model development', file_name, stats,
+                                                 list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), regression_algo, model_per, cv_results)
+
+                case 'TPE-iPLS':
+                    latex_report = report.report('Predictive model development', file_name, stats,
+                                                 list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), regression_algo, model_per, cv_results)
+
+                case _:
+                    st.warning('Data processing has not been performed or finished yet!', icon = "⚠️")
 
-            
-            if regression_algo is None:
-                st.warning('Data processing has not been performed or finished yet!', icon = "⚠️")
-            else:
-                pass
             report.compile_latex()
-        else:
-            pass
\ No newline at end of file
-- 
GitLab