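# File: 3-prediction.py (19 KiB), captured from a GitLab blame/history page.
# NOTE(review): the "<author>'s avatar" / "<author> committed" lines interleaved below
# are residue from that capture, not Python source.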
from common import *
# Configure the Streamlit page (title, icon, wide layout); `st` is expected to come from `from common import *`.
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
DIANE's avatar
DIANE committed





# layout
# Build the shared page chrome (header, sidebar, column borders) with the project's UiComponents helper.
UiComponents(pagespath = pages_folder, csspath= css_file,imgpath=image_path ,
             header=True, sidebar= True, bgimg=False, colborders=True)
# HTML for the "CEFE - CNRS" banner
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML code
# st.markdown(bandeau_html, unsafe_allow_html=True)
DIANE's avatar
DIANE committed
# PageStyle(pages_folder)
DIANE's avatar
DIANE committed
# local_css(css_file / "style_model.css")
# Running hash of the user's inputs; it is passed as the `change` argument of the cached functions below.
hash_ = ''
def p_hash(add):
    """Fold *add* into the module-level running hash and return the new value.

    The global ``hash_`` is rebound to ``hash_data(<previous hash_> + str(add))``.
    """
    global hash_
    seed = hash_ + str(add)
    hash_ = hash_data(seed)
    return hash_
DIANE's avatar
DIANE committed

# Remove any existing report/out/model directory each time the script runs.
dirpath = Path('report/out/model')
if dirpath.exists() and dirpath.is_dir():
DIANE's avatar
DIANE committed
    rmtree(dirpath)
DIANE's avatar
DIANE committed

DIANE's avatar
DIANE committed
# Persist the "Predict" state across Streamlit reruns; it stays False until the Predict button is clicked.
if 'Predict' not in st.session_state:
    st.session_state['Predict'] = False
# ####################################  Methods ##############################################
# empty temp figures
def delete_files(keep, root='report/', protected=('logo_cefe.png',)):
    """Delete every file under *root* (recursively) except the ones worth keeping.

    A file survives when its name is listed in *protected* or ends with one of
    the suffixes in *keep*. Directories themselves are left in place.

    Args:
        keep: iterable of filename suffixes to preserve (e.g. ``['.py', '.bib']``).
            A single string is treated as one suffix; ``None`` keeps no suffix.
        root: directory to clean; defaults to ``'report/'``.
        protected: exact file names that are never deleted.
    """
    if keep is None:
        suffixes = ()
    elif isinstance(keep, str):
        # A bare string used to be iterated character by character.
        suffixes = (keep,)
    else:
        suffixes = tuple(keep)
    protected = set(protected)
    for folder, _subfolders, names in os.walk(root, topdown=False):
        for name in names:
            # str.endswith accepts a tuple of suffixes (an empty tuple matches nothing).
            if name not in protected and not name.endswith(suffixes):
                os.remove(os.path.join(folder, name))
###################################################################
DIANE's avatar
DIANE committed
                
DIANE's avatar
DIANE committed
st.header("Prediction making using a previously developed model")
c1, c2 = st.columns([2, 1])
c1.image("./images/prediction making.png", use_column_width=True)
# An empty DataFrame *instance* (not the class itself), so the later
# `pred_data.empty` check is a real boolean test.
pred_data = DataFrame()


def preparespecdf(df, min_columns=60):
    """Split a raw table into its spectral (float) columns and its other columns.

    Args:
        df: input table, one sample per row.
        min_columns: minimum number of float columns required for the table to be
            treated as spectra (defaults to 60, the previous hard-coded threshold).

    Returns:
        ``(spec, other, rownames)`` where
        * ``spec`` holds the float columns, indexed by the first non-float column
          when one exists. It is an empty DataFrame when fewer than *min_columns*
          float columns are present.
        * ``other`` holds the non-float columns.
        * ``rownames`` is the first non-float column, or the strings ``"0".."n-1"``.
    """
    other = df.select_dtypes(exclude='float')
    spec = df.select_dtypes(include='float')
    if other.shape[1] > 0:
        rownames = other.iloc[:, 0]
        spec.index = rownames
    else:
        rownames = [str(i) for i in range(df.shape[0])]
    if spec.shape[1] < min_columns:
        # Return an empty DataFrame instance rather than the DataFrame class,
        # so callers can rely on `.empty` / `.shape`.
        spec = DataFrame()
    return spec, other, rownames
DIANE's avatar
DIANE committed

def check_exist(var):
    """Tell whether *var* is currently bound as a module-level (global) name."""
    try:
        globals()[var]
    except KeyError:
        return False
    return True
DIANE's avatar
DIANE committed

DIANE's avatar
DIANE committed
with c2:
    # Step 1: upload the zip archive that holds the previously trained model.
    # NOTE(review): `zip` shadows the built-in zip(), and this uploader's help text
    # describes a csv matrix rather than a zip archive.
    zip = st.file_uploader("Load your zip file:", type = ['.zip'], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
    if not zip:
        st.info('Info: Insert your zip file above!')
    
    # Step 2: the spectra uploader stays disabled until a zip archive is provided.
    disable1 = False if zip else True
    new_data = st.file_uploader("Load NIRS Data for prediction making:", type = ['csv', 'dx'], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns", disabled=disable1)
    if not disable1 :
        info1 = st.info('Info: Insert your NIRS data file above!')
DIANE's avatar
DIANE committed
    if zip:
        @st.cache_data
        def tempdir(prefix, dir):
            """Return a unique directory path located inside *dir*.

            A TemporaryDirectory is created only to obtain a unique name. It is
            removed again when the ``with`` block exits, so callers must recreate
            the directory (``ZipFile.extractall`` creates it as needed).

            Args:
                prefix: prefix of the generated directory name.
                dir: existing parent directory. The name shadows the built-in
                    ``dir`` but is kept for interface compatibility.

            Returns:
                ``os.path.join(dir, <unique name>)``. Returning only the bare name
                made callers extract relative to the current working directory
                instead of inside *dir*.
            """
            from tempfile import TemporaryDirectory
            with TemporaryDirectory(prefix=prefix, dir=dir) as temp_dir:
                tempdirname = os.path.split(temp_dir)[1]
            # NOTE(review): st.cache_data keys on (prefix, dir) only, so reruns and new
            # uploads reuse this same path. Confirm that stale contents are acceptable.
            return os.path.join(dir, tempdirname)
        
        # Directory that will receive the unpacked model archive.
        temp_dir = tempdir(prefix = "pred_temp", dir = "./temp")
        # Open and extract the zip file
        # NOTE(review): the archive is user-uploaded; see the zipfile docs' warning about
        # extracting archives from untrusted sources.
        from zipfile import ZipFile
        with ZipFile(zip, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)
DIANE's avatar
DIANE committed
            
        def find_pkl_files(root_dir):
            """Collect the paths of all ``.pkl`` files below *root_dir*.

            The tree is walked recursively with ``os.walk``. Each path is built as
            ``os.path.join(<directory>, <file name>)`` and returned in walk order.
            """
            return [
                os.path.join(folder, name)
                for folder, _subfolders, names in os.walk(root_dir)
                for name in names
                if name.endswith('.pkl')
            ]
        pkl = find_pkl_files(root_dir=temp_dir)

        # Pickles whose path contains 'file_system'; the first one is loaded below as `system_data`.
        system_file = [path for path in pkl if 'file_system' in path]
DIANE's avatar
DIANE committed
        # NOTE(review): the next line is indented deeper than the one above, so a statement
        # (presumably a guard such as a check that system_file is non-empty) seems to be
        # missing from this copy. Restore it from the repository.
            with open(system_file[0], 'rb') as fi:
DIANE Abderrahim's avatar
DIANE Abderrahim committed
                # SECURITY NOTE(review): joblib.load unpickles data taken from the user-uploaded
                # zip and can execute arbitrary code. Only load archives from trusted sources.
                from joblib import load
DIANE's avatar
DIANE committed
                system_data = load(fi)

# Hide the "insert your NIRS data" hint once a spectra file has been uploaded.
if new_data:
        info1.empty()

with c2:
    if new_data:
        # Fold the uploaded file name into the running hash used as the cache key.
DIANE Abderrahim's avatar
DIANE Abderrahim committed
        hash_ = ObjectHash(current = hash_,add = new_data.name)
DIANE's avatar
DIANE committed
        # File extension (text after the last '.') selects the parser below.
        test = new_data.name.split('.')[-1]
        # Export name: 'Pred_of' + the file name up to its first '.'.
        export_name = 'Pred_of'
DIANE's avatar
DIANE committed
        export_name += new_data.name[:new_data.name.find('.')]

        match test:
            case 'csv':
                # CSV dialect options chosen by the user.
DIANE Abderrahim's avatar
DIANE Abderrahim committed
                c1_1, c2_2 = st.columns([.5, .5])
                with c1_1:
                    qdec = st.radio('decimal(x):', options= [".", ","], horizontal = True)
                    qsep = st.radio("separator(x):", options = [";", ","], horizontal = True)
                with c2_2:
                    qhdr = st.radio("header(x): ", options = ["yes", "no"], horizontal = True)
                    qnames = st.radio("samples name(x):", options = ["yes", "no"], horizontal = True)

                # "yes" -> 0 (first row / first column), "no" -> None (absent).
                qhdr = 0 if qhdr =="yes" else None
                qnames = 0 if qnames =="yes" else None
                # Changing any dialect option changes the hash and invalidates cached results.
                hash_ = ObjectHash(current = hash_,add = [qsep, qhdr, qnames, qdec])



                def read_csv(file = None, change = None, dec = None, sep= None, names = None, hdr = None):
                    """Parse an uploaded CSV with the project's CsvParser.

                    Side effect: first clears generated files under ``report/`` through
                    ``delete_files``, keeping .py/.pyc/.bib files.

                    Args:
                        file: uploaded file object. When None, nothing is parsed and None is returned.
                        change: not used in the body (presumably a leftover cache key; confirm).
                        dec: decimal character. sep: field separator.
                        names: index column (0 or None). hdr: header row (0 or None).

                    Returns:
                        ``(float part, meta data, meta_data_st_, full frame)`` from the parser.
                    """
                    keep_suffixes = ['.py', '.pyc', '.bib']
                    delete_files(keep=keep_suffixes)
                    from utils.data_parsing import CsvParser
                    if file is None:
                        return None
                    parser = CsvParser(file=file)
                    parser.parse(decimal=dec, separator=sep, index_col=names, header=hdr)
                    return parser.float, parser.meta_data, parser.meta_data_st_, parser.df


                try:
                    pred_data, _, _, df = read_csv(file= new_data, change = hash_, dec = qdec, sep = qsep,
                     names =qnames, hdr = qhdr)
                    rownames = pred_data.index
                    st.success('file has been loaded successfully')
                except Exception:
                    # Fallback: plain pandas parsing. `read_csv` here is the local CsvParser
                    # wrapper defined above, which accepts neither `header` nor `decimal`.
                    # The previous fallback also referenced an undefined `col`, so it always
                    # crashed. Call pandas explicitly instead.
                    from pandas import read_csv as pandas_read_csv
                    new_data.seek(0)  # the first parsing attempt already consumed the stream
                    df = pandas_read_csv(new_data, sep=qsep, header=qhdr, decimal=qdec)
                    pred_data, cat, rownames = preparespecdf(df)
DIANE's avatar
DIANE committed

            case "dx":
                with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                    tmp.write(new_data.read())
                    tmp_path = tmp.name
                    with open(tmp.name, 'r') as dd:
                        dxdata = new_data.read()
                        p_hash(str(dxdata)+str(new_data.name))

                    ## load and parse the temp dx file
                    @st.cache_data
                    def dx_loader(change):
                        chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
                        return chem_data, spectra, meta_data, _
                    chem_data, spectra, meta_data, _ = dx_loader(change = hash_)
                    st.success("The data have been loaded successfully", icon="")
                    if chem_data.to_numpy().shape[1]>0:
                        # yname = st.selectbox('Select target', options=chem_data.columns, index=chem_data.columns.to_list().index(system_data['data']['target'].name))
                        yname = system_data['data']['target'].name
                        st.info("Loaded model to predict " + yname)
DIANE's avatar
DIANE committed
                        measured = chem_data.loc[:,yname] == 0
                        y = chem_data.loc[:,yname].loc[measured]
                        pred_data = spectra.loc[measured]
                    
                    else:
                        pred_data = spectra
                os.unlink(tmp_path)
DIANE's avatar
DIANE committed

DIANE's avatar
DIANE committed
# Load parameters
DIANE's avatar
DIANE committed
st.subheader("I - Spectral data preprocessing & visualization", divider='blue')
DIANE's avatar
DIANE committed
# try:
if not pred_data.empty:# proceed only when prediction spectra were loaded
    # Preprocess the spectra with the settings stored in the model archive (system_data).
    @st.cache_data
    # NOTE(review): the `def preprocess_spectra(data, change):` header (implied by the calls
    # below and by the uses of `data` and `return`) and the line(s) computing `x1` (presumably
    # the optionally normalized spectra) appear to be missing from this copy. Restore them.
DIANE Abderrahim's avatar
DIANE Abderrahim committed
        from scipy.signal import savgol_filter

DIANE's avatar
DIANE committed
        # M4.write(ProcessLookupError)
        
        # Human-readable label for the normalization recorded in the model archive.
        if system_data['spec-preprocessing']['normalization'] == 'Snv':
DIANE's avatar
DIANE committed
            norm = 'Standard Normal Variate'
        else:
            norm = 'No Normalization was applied'
DIANE's avatar
DIANE committed
        # Savitzky-Golay filtering; the stored tuple is ordered (polyorder, window_length, deriv).
        x2 = savgol_filter(x1,
                            window_length = int(system_data['spec-preprocessing']['SavGol(polyorder,window_length,deriv)'][1]),
                            polyorder = int(system_data['spec-preprocessing']['SavGol(polyorder,window_length,deriv)'][0]),
                            deriv = int(system_data['spec-preprocessing']['SavGol(polyorder,window_length,deriv)'][2]),
                                delta=1.0, axis=-1, mode="interp", cval=0.0)
        # Re-attach the original sample index and wavelength columns.
        preprocessed = DataFrame(x2, index = data.index, columns = data.columns)
DIANE's avatar
DIANE committed
        return norm, preprocessed
    norm, preprocessed = preprocess_spectra(pred_data, change= hash_)
DIANE's avatar
DIANE committed

                        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    # @st.cache_data
    # def specplot_raw(change):
    #     fig2 = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    #     return fig2
    # rawspectraplot = specplot_raw(change = hash_)
    # Figure of the spectra as loaded (recomputed on every rerun).
    rawspectraplot = plot_spectra(pred_data, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")

    # Left column: figures; right column: preprocessing summary.
    c3, c4 = st.columns([2, 1])
    with c3:
        st.write('Raw spectra')
        st.pyplot(rawspectraplot)

        ## plot preprocessed spectra
        if check_exist("preprocessed"):
            # def specplot_prep(change):
            #     fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
            #     return fig2
            # prepspectraplot = specplot_prep(change = hash_)
            prepspectraplot = plot_spectra(preprocessed, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")
            st.write('Preprocessed spectra')
            st.pyplot(prepspectraplot)
DIANE's avatar
DIANE committed
    with c4:
        @st.cache_data
        def prep_info(change):
            """Build the two bullet strings that describe the model's spectral preprocessing.

            The Savitzky-Golay tuple stored in ``system_data`` is read as
            (polyorder, window_length, deriv).

            Args:
                change: cache key only (the running input hash); not used in the body.

            Returns:
                ``(SG, Norm)`` text shown in the info box.
            """
            # Extract the values first. Reusing the enclosing quote character inside an
            # f-string replacement field is only legal from Python 3.12 (PEP 701), while
            # the rest of the file needs only 3.10 (match statements).
            prep = system_data['spec-preprocessing']
            savgol = prep['SavGol(polyorder,window_length,deriv)']
            polyorder, window_length, deriv = savgol[0], savgol[1], savgol[2]
            normalization = prep['normalization']
            SG = f'- Savitzky-Golay derivative parameters \n:(Window_length:{window_length};  polynomial order: {polyorder};  Derivative order : {deriv})'
            Norm = f'- Spectral Normalization \n: {normalization}'
            return SG, Norm
        SG, Norm = prep_info(change = hash_)
        st.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)

    ################### Predictions making  ##########################
DIANE's avatar
DIANE committed
    st.subheader("II - Prediction making", divider='blue')
DIANE's avatar
DIANE committed
    
    # NOTE(review): pred_data is always bound at module level, so disable2 is always False here.
    disable2 = False if check_exist("pred_data") else True
    pred_button = st.button('Predict', type='primary', disabled= disable2, use_container_width=False)
    # Latch the click in session_state so predictions persist across later reruns.
    if pred_button:st.session_state['Predict'] = True

    if st.session_state['Predict']:
        if  check_exist("pred_data"):# always true (pred_data is bound at module level)
            c5, c6 = st.columns([2, 1])
            with c6:
                model = system_data['model_']
                # Number of predictors the model expects.
                if system_data['model_type'] in ['PLS','TPE-iPLS']:
                    nvar = system_data['model_'].n_features_in_
                elif system_data['model_type']  =='LW-PLS':
                    nvar = system_data['data']['raw-spectra'].shape[1]
                # NOTE(review): any other model_type leaves `nvar` undefined for its later uses.


        if check_exist('preprocessed'):
            # Column positions of the predictors the model was trained on.
            selected = system_data['selected-wls']['idx']
            if isinstance(selected, DataFrame):
                # Presumably each row is an inclusive (start, end) pair of column positions:
                # flatten row-major and expand every pair into a run of indices.
                bounds = selected.values.reshape((-1,))
                idx = np.concatenate([np.arange(bounds[2*i], bounds[2*i+1]+1) for i in range(selected.shape[0])])
            else:
                # No wavelength selection: use the first `nvar` columns.
                idx = np.arange(nvar)

            # Positional indices must be strictly below the column count. With `<=`, an
            # index equal to shape[1] reached iloc and raised IndexError.
            if np.max(idx) < preprocessed.shape[1]:
                preprocesseddf = preprocessed.iloc[:,idx] ### get predictors
            else:
                st.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")
            

        if check_exist("preprocesseddf"):
            # Predict only when the selected predictor count matches what the model expects.
            if st.session_state['Predict'] and nvar == preprocesseddf.shape[1]:
            # if nvar == preprocesseddf.shape[1]:
                match system_data['model_type']:
                    case 'PLS'|'TPE-iPLS':
                        # Models exposing .predict are applied directly to the selected columns.
                        try:
                            result = DataFrame(system_data['model_'].predict(preprocesseddf), index = rownames, columns = ['Results'])
                        except:
                            # NOTE(review): this bare except assumes a length mismatch, and the message
                            # calls predict again, which can re-raise the same error.
                            st.error(f'''Error: Length mismatch: the number of samples indices is {len(rownames)}, while the model produced 
                                            {len(model.predict(preprocesseddf))} values. correct the "indexes column in csv?" parameter''')
                    case 'LW-PLS':
                        # LW-PLS runs in a separate Python process (utils/LWPLSR_Call.py); inputs,
                        # parameters and outputs are exchanged through files in temp/.
                        try:
                            temp_path = Path('temp/')
                            # export data to csv for Julia train/pred
                            # with pretreatments
                            spectra = preprocess_spectra(system_data['data']['raw-spectra'], change= hash_)
                            x_pred = preprocessed
                            rownames = x_pred.index.to_list()
                            y = system_data['data']['target']
                            data_to_work_with = ['spectra_np', 'y_np', 'x_pred_np']
                            # spectra[1] is the preprocessed frame returned by preprocess_spectra.
                            spectra_np, y_np, x_pred_np = spectra[1].to_numpy(), y.to_numpy(), x_pred.to_numpy()
                            # export spectra, y, x_pred to temp folder as csv files
                            for i in data_to_work_with:
                                # Looked up by name in globals(): these arrays are module-level script variables.
                                j = globals()[i]
                                np.savetxt(temp_path / str(i + ".csv"), j, delimiter=",")
                            # export best LWPLSR params
                            with open(temp_path / "lwplsr_best_params.json", "w+") as outfile:
                                json.dump(system_data['lwpls_params'], outfile)
                            # create empty file to specify LWPLSR_Call.py that we want predictions
                            open(temp_path / 'predict', 'w').close()
                            # # run Julia Jchemo as subprocess
                            import subprocess
                            subprocess_path = Path("utils/")
                            subprocess.run([f"{sys.executable}", subprocess_path / "LWPLSR_Call.py"])
                            # retrieve json results from Julia JChemo
                            try:
                                with open(temp_path / "lwplsr_outputs.json", "r") as outfile:
                                    Reg_json = json.load(outfile)
                                    # delete csv files
                                    for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
                                    os.unlink(temp_path / 'predict')
                                # delete json file after import
                                os.unlink(temp_path / "lwplsr_outputs.json")
                                os.unlink(temp_path / "lwplsr_best_params.json")
                                # format result data into Reg object
                                result = DataFrame(Reg_json['y_pred'])  ### keys of the json dict
                                result.index = rownames
                                result.columns = ['Results']
                            except FileNotFoundError as e:
                                # No output file: clean up the inputs. NOTE(review): `Reg` is not used
                                # afterwards, and lwplsr_best_params.json is left behind on this path.
                                Reg = None
                                for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
                                os.unlink(temp_path / 'predict')
                        except:
                            # NOTE(review): bare except hides the underlying error.
                            st.error('Error during LWPLSR predictions')
DIANE's avatar
DIANE committed

DIANE's avatar
DIANE committed
            ################################### results display ###################################
        if check_exist("preprocesseddf"):
            # NOTE(review): results are shown only when more than one predictor column is used.
            if preprocesseddf.shape[1]>1 and check_exist('result'):
                hist = pred_hist(pred=result)
                with c5:
                    st.write('Predicted values distribution')
                    st.pyplot(hist)
                    st.write('Predicted values table')
                    # Transposed so that samples run horizontally.
                    resultT = result.reset_index()
                    st.dataframe(resultT.T)
DIANE's avatar
DIANE committed
                with c6:
DIANE Abderrahim's avatar
DIANE Abderrahim committed
                    from utils.miscellaneous import desc_stats
DIANE's avatar
DIANE committed
                    st.info('descriptive statistics for the model output')
                    st.write(DataFrame(desc_stats(result)))
                    
            elif pred_button and nvar != preprocesseddf.shape[1]:
                with c6:
                    # NOTE(review): the condition compares preprocesseddf, but the message reports
                    # preprocessed.shape[1] (all columns). Confirm which count users should see.
                    st.error(f'Error: The model was trained on {nvar} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match!')

            ################################# Download results #################################
        if check_exist('result'):
            @st.cache_data(show_spinner =False)
            def preparing_results_for_downloading(change):
                match test:
                    # load csv file
                    case 'csv':
                        df.to_csv('report/out/dataset/'+ new_data.name, sep = ';', encoding = 'utf-8', mode = 'a')
DIANE's avatar
DIANE committed
                    case 'dx':
                        with open('report/out/dataset/'+new_data.name, 'w') as dd:
DIANE's avatar
DIANE committed
                            dd.write(dxdata)

                prepspectraplot.savefig('./report/out/figures/raw_spectra.png')
                rawspectraplot.savefig('./report/out/figures/preprocessed_spectra.png')
                hist.savefig('./report/out/figures/histogram.png')
                result.round(4).to_csv('./report/out/The_analysis_result.csv', sep = ";")
DIANE's avatar
DIANE committed
                return change
DIANE's avatar
DIANE committed
            # Cached on `change`, so the export files are rewritten only for a new input hash.
            preparing_results_for_downloading(change = hash_)

            @st.cache_data(show_spinner =False)
            def tempdir(change):
                """Zip ./report/out into Results.zip and return ``(temp dir name, zip bytes)``.

                The archive is moved into a TemporaryDirectory under ./report and read
                back before that directory is removed on exiting the ``with`` block.
                NOTE(review): when ./report/out/figures/ does not hold exactly 3 files,
                ``zip_data`` is never assigned and the return raises UnboundLocalError,
                which the caller's bare ``except`` swallows.

                Args:
                    change: cache key only (the running input hash).
                """
DIANE Abderrahim's avatar
DIANE Abderrahim committed
                from tempfile import TemporaryDirectory
                with  TemporaryDirectory( prefix="results", dir="./report") as temp_dir:# create a temp directory
DIANE's avatar
DIANE committed
                    tempdirname = os.path.split(temp_dir)[1]
                    if len(os.listdir('./report/out/figures/'))==3:
                        make_archive(base_name="./report/Results", format="zip", base_dir="out", root_dir = "./report")# create a zip file
                        move("./report/Results.zip", f"./report/{tempdirname}/Results.zip")# put the zip inside the temp dir
                        with open(f"./report/{tempdirname}/Results.zip", "rb") as f:
DIANE's avatar
DIANE committed
                            zip_data = f.read()
                return tempdirname, zip_data

            # Timestamp (yymmddHHMM) embedded in the download file name.
            date_time = datetime.now().strftime('%y%m%d%H%M')
            try :
                tempdirname, zip_data = tempdir(change = hash_)
                st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_Pred_.zip', mime ="application/zip",
                            args = None, kwargs = None,type = "primary",use_container_width = True)
            except:
                # Archive not ready (or any other failure): show a placeholder instead of the button.
                # NOTE(review): the bare except also hides the underlying error.
DIANE Abderrahim's avatar
DIANE Abderrahim committed
                st.write('-')
DIANE's avatar
DIANE committed
        # except:
        #     c2.error('''Error: Data loading failed. Please check your file. Consider fine-tuning the dialect settings or ensure the file isn't corrupted.''')


else:
    with c2:
        if new_data:
DIANE Abderrahim's avatar
DIANE Abderrahim committed
            st.error("Error!:The The data you provided for making predictions doesn't appear to be multivariable.!")