from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from utils.DATA_HANDLING import *
# HTML for the "CEFE - CNRS" banner
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML code
# st.markdown(bandeau_html, unsafe_allow_html=True)
# Apply the page stylesheet (`local_css` and `css_file` come from the star imports above).
local_css(css_file / "style_model.css")
# Running hash string: updated by `p_hash` and passed as `change=` to the helpers below.
hash_ = ''
def p_hash(add):
    """Fold ``add`` into the module-level ``hash_`` and return the new value.

    The current ``hash_`` is concatenated with ``str(add)``, the result is passed
    through ``hash_data``, and that digest is stored back into the global
    ``hash_`` before being returned.
    """
    global hash_
    seed = hash_ + str(add)
    hash_ = hash_data(seed)
    return hash_
# Report sub-folder inspected below.
dirpath = Path('Report/out/model')
if dirpath.exists() and dirpath.is_dir():
# NOTE(review): the `if` above has no indented body in this copy (SyntaxError);
# a statement acting on `dirpath` appears to be missing — TODO restore.
if 'Predict' not in st.session_state:
    # Initialise the flag that the 'Predict' button sets further down the page.
    st.session_state['Predict'] = False
# ####################################  Methods ##############################################
# empty temp figures
def delete_files(keep, root_dir='Report/'):
    """Delete temporary files found under ``root_dir`` (default ``'Report/'``).

    Every file reached by walking ``root_dir`` is removed unless it is named
    ``logo_cefe.png`` or its name ends with one of the suffixes in ``keep``.
    Directories themselves are left in place.

    Args:
        keep: iterable of filename suffixes to preserve (e.g. ``['.pdf', '.zip']``).
            It is materialised once, so a one-shot iterator is handled correctly.
        root_dir: directory tree to clean; defaults to the app's report folder.
    """
    # str.endswith accepts a tuple of candidate suffixes; building it once also
    # avoids re-iterating `keep` for every file.
    suffixes = tuple(keep)
    for root, _dirs, files in os.walk(root_dir, topdown=False):
        for name in files:
            if name != 'logo_cefe.png' and not name.endswith(suffixes):
                os.remove(os.path.join(root, name))
st.header("Prediction making using a previously developed model")
# Two-column layout: illustration on the left, upload widgets (c2) on the right.
c1, c2 = st.columns([2, 1])
c1.image("./images/prediction making.png", use_column_width=True)
# Empty placeholder until spectra are loaded. Use an *instance* (not the
# DataFrame class) so that `pred_data.empty`, checked below, is a real boolean.
pred_data = DataFrame()

def preparespecdf(df):
    """Split a loaded table into spectral and non-spectral columns.

    Float columns are treated as spectra; every other column is returned in
    ``other``. If ``other`` has at least one column, its first column is used as
    the sample identifiers (and as the index of the spectra). Otherwise the
    identifiers are the row positions as strings (``'0'``, ``'1'``, ...).

    Args:
        df: table with samples as rows.

    Returns:
        ``(spec, other, rownames)``. ``spec`` is replaced by an empty DataFrame
        when it has fewer than 60 float columns, i.e. the input does not look
        like a spectral matrix.
    """
    other = df.select_dtypes(exclude='float')
    spec = df.select_dtypes(include='float')
    if other.shape[1] > 0:
        # First non-float column holds the sample identifiers.
        rownames = other.iloc[:, 0]
        spec.index = rownames
    else:
        # No identifier column: fall back to positional labels.
        rownames = [str(i) for i in range(df.shape[0])]
    if spec.shape[1] < 60:
        # Too few wavelengths to be treated as spectra; signal "no data" to callers.
        spec = DataFrame()
    return spec, other, rownames
def check_exist(var):
    """Tell whether ``var`` names a binding in this module's global namespace."""
    namespace = globals()
    return var in namespace
# Right-hand column: uploaders for the model archive (.zip) and the spectra to predict.
with c2:
    # NOTE(review): `zip` shadows the built-in `zip`; the help text mentions a csv
    # matrix although this widget accepts .zip files.
    zip = st.file_uploader("Load your zip file:", type = ['.zip'], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
    # NOTE(review): the next line is not valid Python as written (unbalanced `)`);
    # the call that displays this message appears to be missing — TODO restore.
    if not zip:'Info: Insert your zip file above!')
    # The spectra uploader stays disabled until a zip archive has been provided.
    disable1 = False if zip else True
    new_data = st.file_uploader("Load NIRS Data for prediction making:", type = ['csv', 'dx'], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns", disabled=disable1)
    if not disable1 :
        # NOTE(review): unbalanced `)` below; the right-hand side looks truncated.
        info1 ='Info: Insert your NIRS data file above!')
    if zip:
        def tempdir(prefix, dir):
            """Return the base name of a freshly created temporary directory.

            The directory is created inside ``dir`` with name prefix ``prefix``
            and is removed again when the ``with`` block exits, so only its unique
            base name (no parent path) survives the call.
            NOTE(review): callers get the name of a directory that no longer
            exists on disk; confirm that this is intended.
            """
            with TemporaryDirectory(prefix=prefix, dir=dir) as created_path:
                base_name = os.path.basename(created_path)
            return base_name
        # `tempdir` returns only the base name (no "./temp/" prefix), so later walks
        # of `temp_dir` are relative to the current working directory.
        temp_dir = tempdir(prefix = "pred_temp", dir = "./temp")
        # Open and extract the zip file
        from zipfile import ZipFile
        with ZipFile(zip, 'r') as zip_ref:
            # NOTE(review): this `with` has no body in this copy (SyntaxError); the
            # extraction call appears to be missing — TODO restore.
        def find_pkl_files(root_dir):
            """Return the paths of every ``.pkl`` file found under ``root_dir``.

            The tree is walked recursively with ``os.walk``; each returned path is
            ``os.path.join(<walked directory>, <file name>)``, in walk order.
            """
            pkl_files = []
            for current_dir, _subdirs, filenames in os.walk(root_dir):
                for filename in filenames:
                    if filename.endswith('.pkl'):
                        # Record the match; without this the result is always empty.
                        pkl_files.append(os.path.join(current_dir, filename))
            return pkl_files
        pkl = find_pkl_files(root_dir=temp_dir)

        # The archive must contain exactly one pickle whose path contains 'file_system';
        # keys read from it further down include 'model_', 'model_type',
        # 'spec-preprocessing', 'selected-wls' and 'data'.
        system_file = [path for path in pkl if 'file_system' in path]
        if len(system_file) ==1 :
            with open(system_file[0], 'rb') as fi:
                # NOTE(review): if `load` is pickle/joblib, this deserialises a user-uploaded
                # file and can execute arbitrary code; only accept trusted archives.
                system_data = load(fi)

if new_data:
# NOTE(review): the `if` above has no indented body in this copy (SyntaxError);
# whatever it guarded appears to be missing — TODO restore.

with c2:
    if new_data:
        # Extension of the uploaded file, used by the `match` below ('csv' or 'dx').
        # NOTE(review): the expression whose result is split on '.' is missing in
        # this copy (unbalanced `)`) — TODO restore.
        test ='.')[-1]
        export_name = 'Pred of'
        # NOTE(review): incomplete expression below — TODO restore.
        export_name +=['.')]

        match test:
            case 'csv':
                # NOTE(review): the widget calls on the right-hand side of the next two
                # lines appear truncated (unbalanced `)`).
                qsep ="Select csv separator : " , options = [';', ','], key = 2, horizontal = True)
                qhdr ="indexes column in csv? : " , options = ['yes', 'no'], key = 3, horizontal = True)
                # NOTE(review): the "indexes column" answer is passed as `header=`
                # (0 -> first row holds column names), not as `index_col`; confirm intent.
                col = 0 if qhdr == 'yes' else None

                df = read_csv(new_data, sep=qsep, header= col, decimal=".")
                # Float columns become the spectra to predict; see preparespecdf.
                pred_data, cat, rownames =  preparespecdf(df)

            case "dx":
                # delete=False keeps the temporary .dx file on disk after the `with` exits.
                with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                    # NOTE(review): the three lines below are incomplete in this copy
                    # (missing right-hand sides / path argument) — TODO restore.
                    tmp_path =
                    with open(, 'r') as dd:
                        dxdata =

                    ## load and parse the temp dx file
                    def dx_loader(change):
                        """Parse the temporary .dx file at ``tmp_path`` with ``read_dx``.

                        ``change`` is not used in the body (presumably it serves only as
                        a cache key). Returns the four values unpacked from ``read_dx``.
                        """
                        parsed = read_dx(file=tmp_path)
                        chem_data, spectra, meta_data, extra = parsed
                        return chem_data, spectra, meta_data, extra
                    chem_data, spectra, meta_data, _ = dx_loader(change = hash_)
                    st.success("The data have been loaded successfully", icon="")
                    # Only when the .dx file carries chemical (target) columns.
                    if chem_data.to_numpy().shape[1]>0:
                        yname = st.selectbox('Select target', options=chem_data.columns)
                        # Rows whose selected target value equals 0.
                        measured = chem_data.loc[:,yname] == 0
                        y = chem_data.loc[:,yname].loc[measured]
                        pred_data = spectra.loc[measured]
                        # NOTE(review): this unconditionally overwrites the filtered
                        # `pred_data` above; an `else:` seems to be missing — TODO confirm.
                        pred_data = spectra
# Load parameters
st.subheader("I - Spectral data preprocessing & visualization", divider='blue')
# try:
# NOTE(review): up to the `norm, preprocessed = ...` line this copy is not valid
# Python. The `if` body jumps to an 8-space indent, and the lines below read like
# the body of a `preprocess_spectra(data, change)` function whose `def` line (and
# the code defining `x1`) is missing; `return` outside a function is a
# SyntaxError. TODO restore.
if not pred_data.empty:# Load the model with joblib
        # M4.write(ProcessLookupError)
        # NOTE(review): with no `else:` between the two assignments below, `norm`
        # always ends up as 'No Normalization was applied' as written.
        if system_data['spec-preprocessing']['normalization'] == 'Snv':
            norm = 'Standard Normal Variate'
            norm = 'No Normalization was applied'
        # Savitzky-Golay filter; the stored triple is ordered (polyorder, window_length, deriv).
        x2 = savgol_filter(x1,
                            window_length = int(system_data['spec-preprocessing']['SavGol(polyorder,window_length,deriv)'][1]),
                            polyorder = int(system_data['spec-preprocessing']['SavGol(polyorder,window_length,deriv)'][0]),
                            deriv = int(system_data['spec-preprocessing']['SavGol(polyorder,window_length,deriv)'][2]),
                                delta=1.0, axis=-1, mode="interp", cval=0.0)
        preprocessed = DataFrame(x2, index = data.index, columns = data.columns)
        return norm, preprocessed
    norm, preprocessed = preprocess_spectra(pred_data, change= hash_)
                        # ---------------------------- raw and preprocessed spectra plots ---------------------------- #
    # @st.cache_data
    # def specplot_raw(change):
    #     fig2 = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    #     return fig2
    # rawspectraplot = specplot_raw(change = hash_)
    rawspectraplot = plot_spectra(pred_data, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")

    c3, c4 = st.columns([2, 1])
    with c3:
        st.write('Raw spectra')
        # NOTE(review): `rawspectraplot` is not rendered in this copy; a display call
        # may be missing — TODO confirm.

        ## plot preprocessed spectra
        if check_exist("preprocessed"):
            # def specplot_prep(change):
            #     fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
            #     return fig2
            # prepspectraplot = specplot_prep(change = hash_)
            prepspectraplot = plot_spectra(preprocessed, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")
            st.write('Preprocessed spectra')
            # NOTE(review): `prepspectraplot` is likewise not rendered here.
    with c4:
        def prep_info(change):
            """Describe the spectral preprocessing stored in the loaded model.

            Reads ``system_data['spec-preprocessing']``: the
            ``'SavGol(polyorder,window_length,deriv)'`` triple and the
            ``'normalization'`` entry. ``change`` is not used in the body.

            Returns:
                ``(SG, Norm)``: two markdown bullet strings for display.
            """
            # Bind the nested lookups to locals first. Reusing the enclosing quote
            # type inside an f-string is legal only from Python 3.12 (PEP 701).
            prep = system_data['spec-preprocessing']
            savgol = prep['SavGol(polyorder,window_length,deriv)']
            SG = f'- Savitzky-Golay derivative parameters \n:(Window_length:{savgol[1]};  polynomial order: {savgol[0]};  Derivative order : {savgol[2]})'
            Norm = f'- Spectral Normalization \n: {prep["normalization"]}'
            return SG, Norm
        # NOTE(review): the next line fuses two statements; the call that displays the
        # summary text seems to be missing (unbalanced `)`) — TODO restore.
        SG, Norm = prep_info(change = hash_)'The spectra were preprocessed using:\n'+SG+"\n"+Norm)

    ################### Predictions making  ##########################
    st.subheader("II - Prediction making", divider='blue')
    # NOTE(review): `pred_data` is always bound at module level (it is assigned
    # near the top of the page), so this check never disables the button.
    disable2 = False if check_exist("pred_data") else True
    pred_button = st.button('Predict', type='primary', disabled= disable2, use_container_width=False)
    # Latch the click in session state so results survive Streamlit reruns.
    if pred_button:st.session_state['Predict'] = True

    if st.session_state['Predict']:
        if  check_exist("pred_data"):# Load the model with joblib
            c5, c6 = st.columns([2, 1])
            with c6:
                model = system_data['model_']
                # Number of predictors the model expects; left unbound for other model types.
                if system_data['model_type'] in ['PLS','TPE-iPLS']:
                    nvar = system_data['model_'].n_features_in_
                elif system_data['model_type']  =='LW-PLS':
                    nvar = system_data['data']['raw-spectra'].shape[1]

        if check_exist('preprocessed'):
            if isinstance(system_data['selected-wls']['idx'], DataFrame):
                # Flatten the stored intervals as [start0, end0, start1, end1, ...] and
                # expand each inclusive [start, end] pair into column positions.
                idx = np.concatenate([np.arange(system_data['selected-wls']['idx'].values.reshape((-1,))[2*i],system_data['selected-wls']['idx'].values.reshape((-1,))[2*i+1]+1) for i in range(system_data['selected-wls']['idx'].shape[0])])
                # NOTE(review): this unconditionally overwrites the interval-based `idx`;
                # an `else:` seems to be missing — TODO confirm.
                idx = np.arange(nvar)

            # NOTE(review): positional `iloc` needs max(idx) < n_columns; `<=` lets an
            # out-of-range index through.
            if np.max(idx) <= preprocessed.shape[1]:
                preprocesseddf = preprocessed.iloc[:,idx] ### get predictors
                # NOTE(review): reached on the success path too; an `else:` seems missing.
                st.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")

        if check_exist("preprocesseddf"):
            if st.session_state['Predict'] and nvar == preprocesseddf.shape[1]:
            # if nvar == preprocesseddf.shape[1]:
                match system_data['model_type']:
                    case 'PLS'|'TPE-iPLS':
                            # NOTE(review): the 28-space indent of the next lines suggests a
                            # missing `try:`/`except:` pair; as written both run unconditionally.
                            result = DataFrame(system_data['model_'].predict(preprocesseddf), index = rownames, columns = ['Results'])
                            st.error(f'''Error: Length mismatch: the number of samples indices is {len(rownames)}, while the model produced 
                                            {len(model.predict(preprocesseddf))} values. correct the "indexes column in csv?" parameter''')
                    case 'LW-PLS':
                        temp_path = Path('temp/')
                        # export data to csv for Julia train/pred
                        # spectra = system_data['data']['raw-spectra'] # without pretreatments
                        # NOTE(review): preprocess_spectra is unpacked into two values
                        # elsewhere, so `spectra` here is a tuple and `.to_numpy()` below fails.
                        spectra = preprocess_spectra(system_data['data']['raw-spectra'], change= hash_)
                        # with pretreatments
                        x_pred = preprocessed
                        y = system_data['data']['target']
                        data_to_work_with = ['spectra', 'y', 'x_pred']
                        spectra_np, y_np, x_pred_np = spectra.to_numpy(), y.to_numpy(), x_pred.to_numpy()
                        # export spectra, y, x_pred to temp folder as csv files
                        # NOTE(review): `j` is undefined here (NameError); presumably the
                        # array matching name `i` was meant — TODO confirm.
                        for i in data_to_work_with:
                            np.savetxt(temp_path / str(i + ".csv"), j, delimiter=",")
                        import subprocess
                        subprocess_path = Path("utils/")
                        # NOTE(review): the subprocess call that runs the Julia script is
                        # missing/commented out in this copy.
                        #[f"{sys.executable}", subprocess_path / ""])
                        # # retrieve json results from Julia JChemo
                        # try:
                        #     with open(temp_path / "lwplsr_outputs.json", "r") as outfile:
                        #         Reg_json = json.load(outfile)
                        #         # delete csv files
                        #         for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
                        #     # delete json file after import
                        #     os.unlink(temp_path / "lwplsr_outputs.json")
                        #     os.unlink(temp_path / "lwplsr_preTreatments.json")
                        #     # format result data into Reg object
                        #     pred = ['pred_data_train', 'pred_data_test']### keys of the dict
                        #     for i in range(nb_folds):
                        #         pred.append("CV" + str(i+1)) ### add cv folds keys to pred
                        # except FileNotFoundError as e:
                        #     Reg = None
                        #     for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
                        # st.write(Reg_json)
            ################################### results display ###################################
        if check_exist("preprocesseddf"):
            if preprocesseddf.shape[1]>1 and check_exist('result'):
                hist = pred_hist(pred=result)
                with c5:
                    st.write('Predicted values distribution')
                    st.write('Predicted values table')
                with c6:
                    # NOTE(review): the next line is a dedented, unbalanced fragment
                    # (IndentationError/SyntaxError); a display call seems to be missing.
          'descriptive statistics for the model output')
            elif pred_button and nvar != preprocesseddf.shape[1]:
                with c6:
                    # NOTE(review): the message reports `preprocessed.shape[1]`, while the
                    # condition compares `preprocesseddf.shape[1]`.
                    st.error(f'Error: The model was trained on {nvar} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match!')

            ################################# Download results #################################
        if check_exist('result'):
            @st.cache_data(show_spinner =False)
            def preparing_results_for_downloading(change):
                """Write the uploaded data and the rounded predictions under ./Report/out.

                ``change`` is returned unchanged; presumably it only serves as the
                ``st.cache_data`` key so the export reruns when it changes.
                """
                match test:
                    # save the uploaded csv next to the results
                    case 'csv':
                        # NOTE(review): the file name after `+` is missing in this copy
                        # (SyntaxError). `mode='a'` appends on every uncached run.
                        df.to_csv('Report/out/dataset/'+, sep = ';', encoding = 'utf-8', mode = 'a')
                    case 'dx':
                        # NOTE(review): this `with` has no body and the path looks
                        # truncated — TODO restore.
                        with open('Report/out/dataset/', 'w') as dd:

                # Predictions rounded to 4 decimals, saved as a ';'-separated CSV.
                result.round(4).to_csv('./Report/out/The analysis result.csv', sep = ";")

                return change
            # `change=hash_` presumably keys the st.cache_data cache of the export.
            preparing_results_for_downloading(change = hash_)

            @st.cache_data(show_spinner =False)
            def tempdir(change):
                """Archive ./Report/out as a zip and return ``(tempdirname, zip_data)``.

                The archive is built only when ./Report/out/figures/ holds exactly
                three entries. NOTE(review): otherwise `zip_data` is never bound and
                the `return` raises UnboundLocalError. This function also reuses the
                name `tempdir` with a different signature than the helper defined
                near the zip upload.
                """
                with  TemporaryDirectory( prefix="results", dir="./Report") as temp_dir:# create a temp directory
                    tempdirname = os.path.split(temp_dir)[1]
                    if len(os.listdir('./Report/out/figures/'))==3:
                        make_archive(base_name="./Report/Results", format="zip", base_dir="out", root_dir = "./Report")# create a zip file
                        # NOTE(review): the paths below end in "/" with no file name; the
                        # archive name appears truncated in this copy — TODO restore.
                        move("./Report/", f"./Report/{tempdirname}/")# put the inside the temp dir
                        with open(f"./Report/{tempdirname}/", "rb") as f:
                            # NOTE(review): right-hand side missing (SyntaxError).
                            zip_data =
                return tempdirname, zip_data

            # NOTE(review): the right-hand side is truncated (unbalanced `)`); the format
            # string suggests a yymmddHHMM timestamp for the archive name.
            date_time ='%y%m%d%H%M')
            # NOTE(review): this `try` has no `except`/`finally` (the handler below is
            # commented out), which is a SyntaxError — TODO restore.
            try :
                tempdirname, zip_data = tempdir(change = hash_)
                # NOTE(review): `file_name` has no ".zip" extension.
                st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}', mime ="application/zip",
                            args = None, kwargs = None,type = "primary",use_container_width = True)
        # except:
        #     c2.error('''Error: Data loading failed. Please check your file. Consider fine-tuning the dialect settings or ensure the file isn't corrupted.''')

    # NOTE(review): at this indent the block sits inside `if not pred_data.empty:`;
    # an `else:` for the "not multivariable" case seems to be missing. The message
    # also repeats "The The".
    with c2:
        if new_data:
            st.error("Error!:The The data you provided for making predictions doesn't appear to be multivariable.!")