# 3-prediction.py — Streamlit page: prediction making using a previously developed model
from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from Class_Mod.DATA_HANDLING import *
# HTML for the "CEFE - CNRS" banner
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML code
# st.markdown(bandeau_html, unsafe_allow_html=True)
# Render the page header through the `add_header` helper (not defined in this
# file, so it comes from the star-imports above). The commented-out inline
# banner above is presumably the older version it replaced — confirm before
# deleting it.
add_header()
# Apply the page stylesheet. `local_css` and `css_file` are not defined in this
# file; they are provided by the star-imports at the top.
local_css(css_file / "style_model.css")
import shutil
# Running fingerprint of every user input seen so far in this script run.
# It is passed as the `change` argument of the cached helpers below so that
# their cache entries depend on the inputs.
hash_ = ''


def p_hash(add):
    """Fold ``add`` into the module-level fingerprint ``hash_`` and return it.

    The string form of ``add`` is appended to the current fingerprint and the
    concatenation is passed through ``hash_data``; the result replaces
    ``hash_``.
    """
    global hash_
    combined = hash_ + str(add)
    hash_ = hash_data(combined)
    return hash_

# Remove the model report output folder left over from a previous run, if any.
dirpath = Path('Report/out/model')
# Path.is_dir() is already False for a path that does not exist.
if dirpath.is_dir():
    shutil.rmtree(dirpath)

# ####################################  Methods ##############################################
# empty temp figures
def delete_files(keep, root_dir='Report/'):
    """Delete every file under ``root_dir`` except the logo and kept suffixes.

    Parameters
    ----------
    keep : iterable of str
        File-name suffixes (e.g. ``'.py'``) that must NOT be deleted.
    root_dir : str, optional
        Directory tree to clean. Defaults to ``'Report/'``, the folder the
        original implementation always used.

    Notes
    -----
    ``logo_cefe.png`` is always preserved. Only files are removed; the
    directory structure itself is left in place.
    """
    # Materialise once: str.endswith accepts a tuple of suffixes, and this also
    # makes a one-shot iterable (generator) safe to reuse for every file.
    kept_suffixes = tuple(keep)
    for folder, _subdirs, filenames in os.walk(root_dir, topdown=False):
        for filename in filenames:
            if filename == 'logo_cefe.png' or filename.endswith(kept_suffixes):
                continue
            os.remove(os.path.join(folder, filename))
###################################################################
# Page title, then a two-column row with relative widths 2:1. The left column
# (M10) shows the illustration; the right column (M20) is used further down
# for the spectra uploader and for the top-level error message.
st.title("Prediction making using a previously developed model")
M10, M20= st.columns([2, 1])
M10.image("./images/prediction making.png", use_column_width=True)
def preparespecdf(df):
    """Split a loaded table into spectra (float columns) and everything else.

    Parameters
    ----------
    df : pandas.DataFrame
        Table as read from the uploaded csv file.

    Returns
    -------
    spec : pandas.DataFrame
        The float-typed columns of ``df``, re-indexed by ``rownames``.
    other : pandas.DataFrame
        The non-float columns of ``df``.
    rownames : pandas.Series
        The first non-float column, used as sample identifiers. When ``df``
        has no non-float column at all, the original index of ``df`` is used
        instead (previously this case raised ``IndexError``).
    """
    other = df.select_dtypes(exclude='float')
    spec = df.select_dtypes(include='float')
    if other.shape[1] > 0:
        rownames = other.iloc[:, 0]
    else:
        # No identifier column is available: fall back to the existing index
        # rather than failing on other.iloc[:, 0].
        rownames = df.index.to_series()
    spec.index = rownames
    return spec, other, rownames

def check_exist(var):
    """Return True when a module-level (global) name ``var`` is defined.

    ``var`` is the variable name given as a string. The script uses this to
    find out whether an earlier, conditional step has already created a
    variable.
    """
    return var in globals()
# Extensions accepted by the spectra uploader. The order matters: the upload
# handling below compares the file extension with index 0 for the csv branch
# and with index 1 for the dx branch.
files_format = ['.csv', '.dx']
# Export location and file-name prefix; the uploaded file's base name is
# appended to `export_name` once a file has been selected.
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
# NOTE(review): not referenced in the visible part of this script — presumably
# kept for the star-imported modules or a later feature; confirm.
reg_algo = ["Interval-PLS"]


# Right-hand column: upload the spectra to predict on and parse them into
# `pred_data` (one row per sample). Everything defined here is a module-level
# global on purpose: later sections probe for it with check_exist().
with M20:
    file = st.file_uploader("Load NIRS Data for prediction making:", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")

    if not file:
        st.info('Info: Insert your spectral data file above!')
    else:
        p_hash(file.name)
        # Split on the LAST dot: the previous `find('.')` split on the first
        # one, so "my.data.csv" produced the extension ".data.csv" and matched
        # no branch. Lower-casing lets ".CSV" / ".DX" match as well.
        uploaded_name = Path(file.name)
        test = uploaded_name.suffix.lower()
        export_name += uploaded_name.stem

        if test == files_format[0]:
            # NOTE(review): the detection helpers receive 'data/<name>', i.e. a
            # path on disk rather than the uploaded buffer — confirm this is
            # intended. Each helper is now called once instead of twice.
            detected_sep = str(find_delimiter('data/'+file.name))
            detected_hdr = str(find_col_index('data/'+file.name))
            qsep = st.radio("Select csv separator - _detected_: " + detected_sep, options=[";", ","],index=[";", ","].index(detected_sep), key=2, horizontal= True)
            qhdr = st.radio("indexes column in csv? - _detected_: " + detected_hdr, options=["no", "yes"], index=["no", "yes"].index(detected_hdr), key=3, horizontal= True)
            col = 0 if qhdr == 'yes' else None
            p_hash([qsep,qhdr])

            df = pd.read_csv(file, sep=qsep, header= col)
            pred_data, cat, rownames =  preparespecdf(df)

        elif test == files_format[1]:
            # read_dx() needs a real path, so persist the upload to a temp file.
            # Leaving the `with` block closes (and therefore flushes) the file
            # BEFORE it is read back or parsed.
            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                tmp.write(file.read())
                tmp_path = tmp.name
            try:
                # Read the text back from the temp file: the upload stream has
                # already been consumed above, so a second file.read() would
                # only return empty bytes.
                with open(tmp_path, 'r') as dd:
                    dxdata = dd.read()
                p_hash(str(dxdata)+str(file.name))

                ## load and parse the temp dx file
                @st.cache_data
                def dx_loader(change):
                    """Parse the temporary .dx file; `change` only serves as the cache key."""
                    chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
                    return chem_data, spectra, meta_data, _
                chem_data, spectra, meta_data, _ = dx_loader(change = hash_)
                st.success("The data have been loaded successfully", icon="")
                if chem_data.to_numpy().shape[1]>0:
                    yname = st.selectbox('Select target', options=chem_data.columns)
                    # Keep the samples whose selected target equals 0.
                    # NOTE(review): presumably 0 marks "not measured yet", i.e.
                    # the samples that need a prediction — confirm.
                    measured = chem_data.loc[:,yname] == 0
                    y = chem_data.loc[:,yname].loc[measured]
                    pred_data = spectra.loc[measured]

                else:
                    pred_data = spectra
                # The csv branch defines `rownames`; define it here too so the
                # prediction step can label its results for .dx input.
                rownames = pred_data.index
            finally:
                # Always remove the temp file, even when parsing fails.
                os.unlink(tmp_path)

# Load parameters
st.header("I - Spectral data preprocessing & visualization", divider='blue')
try:
    # Section I is only rendered once spectra have been loaded into `pred_data`.
    if check_exist("pred_data"):
        @st.cache_data
        def specplot_raw(change):
            """Plot the raw spectra; `change` only serves as the cache key."""
            # NOTE(review): `yunits` receives the literal text
            # "meta_data.loc[:,'yunits'][0]", not the evaluated expression
            # (`meta_data` does not exist for csv input). Confirm whether this
            # placeholder label is intended.
            fig2 = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
            return fig2
        rawspectraplot = specplot_raw(change = hash_)
        M1, M2= st.columns([2, 1])
        with M1:
            st.write('Raw spectra')
            st.pyplot(rawspectraplot)

        with M2:
            # JSON file holding the preprocessing settings. The keys read
            # below are 'normalization', 'window_length', 'polyorder' and
            # 'deriv'.
            params = st.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
            if params:
                prep = json.load(params)
                p_hash(prep)

                @st.cache_data
                def preprocess_spectra(change):
                    """Apply the uploaded preprocessing settings to `pred_data`.

                    Returns the normalisation description, the settings dict
                    and the preprocessed spectra as a DataFrame with the same
                    index and columns as `pred_data`. `change` only serves as
                    the cache key.
                    """
                    if prep['normalization'] == 'Snv':
                        x1 = Snv(pred_data)
                        norm = 'Standard Normal Variate'
                    else:
                        norm = 'No Normalization was applied'
                        x1 = pred_data
                    # Savitzky-Golay filter along the last axis, i.e. along
                    # the columns of each sample row.
                    x2 = savgol_filter(x1,
                                        window_length = int(prep["window_length"]),
                                        polyorder = int(prep["polyorder"]),
                                        deriv = int(prep["deriv"]),
                                            delta=1.0, axis=-1, mode="interp", cval=0.0)
                    preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)
                    return norm, prep, preprocessed
                norm, prep, preprocessed = preprocess_spectra(change= hash_)

    ################################################################################################
    ## plot preprocessed spectra
    if check_exist("preprocessed"):
        p_hash(preprocessed)
        M3, M4= st.columns([2, 1])
        with M3:
            st.write('Preprocessed spectra')
            def specplot_prep(change):
                """Plot the preprocessed spectra (`change` is accepted but not used)."""
                fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
                return fig2
            prepspectraplot = specplot_prep(change = hash_)
            st.pyplot(prepspectraplot)

        with M4:
            @st.cache_data
            def prep_info(change):
                """Build the two summary lines describing the applied preprocessing.

                `change` only serves as the cache key.
                """
                # These must be (raw) f-strings: as plain raw strings the
                # {prep[...]} / {norm} placeholders were displayed literally
                # instead of the actual values. Double quotes are used inside
                # the braces so the quotes cannot clash with the delimiters.
                SG = rf'''- Savitzky-Golay derivative parameters \:(Window_length:{prep["window_length"]};  polynomial order: {prep["polyorder"]};  Derivative order : {prep["deriv"]})'''
                Norm = rf'''- Spectral Normalization \: {norm}'''
                return SG, Norm
            SG, Norm = prep_info(change = hash_)
            st.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)

    ################### Predictions making  ##########################
    st.header("II - Prediction making", divider='blue')
    # Needs both the spectra and an uploaded preprocessing-parameters file.
    if check_exist("pred_data") and params:
        M5, M6 = st.columns([2, 1])
        model_file = M6.file_uploader("Load your model", type = '.pkl', help=" .pkl file")
        if model_file:
            with M6:
                try:
                    # SECURITY NOTE(review): joblib.load unpickles the upload,
                    # which can execute arbitrary code. Only load model files
                    # from trusted sources.
                    loaded_model = joblib.load(model_file)
                    # Publish `model` and `nvar` together, so a model without
                    # n_features_in_ cannot enable prediction while `nvar` is
                    # still undefined.
                    nvar = loaded_model.n_features_in_
                    model = loaded_model
                    st.success("The model has been loaded successfully", icon="")
                except Exception:
                    st.error("Error: Something went wrong, the model was not loaded !", icon="")

        with M6:
            s = st.checkbox('Check this box if your model is of ipls type!', disabled = not check_exist('model'))
            index = st.file_uploader("select wavelengths index file", type="csv", disabled = not s)
            if check_exist('preprocessed'):
                if s:
                    if index:
                        intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
                        # Columns 2 and 3 hold the inclusive bounds of each
                        # selected interval, as column positions.
                        # NOTE(review): schema inferred from usage — confirm.
                        idx = []
                        for lower, upper in zip(intervalls[:, 2], intervalls[:, 3]):
                            idx.extend(range(int(lower), int(upper) + 1))
                        # Valid positions are 0 .. shape[1]-1, so the largest
                        # one must be strictly below the column count (the
                        # previous `<=` let an out-of-range position through).
                        # An empty selection is rejected as well.
                        if idx and max(idx) < preprocessed.shape[1]:
                            preprocesseddf = preprocessed.iloc[:,idx] ### get predictors
                        else:
                            st.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")
                else:
                    preprocesseddf = preprocessed

            # Predict is enabled only when a model is loaded and the ipls
            # checkbox agrees with the presence of an index file (both set or
            # both unset).
            disable = (not check_exist("model")) or (bool(s) != bool(index))

            pred_button = M6.button('Predict', type='primary', disabled= disable)

            if check_exist("preprocesseddf"):
                if pred_button and nvar == preprocesseddf.shape[1]:
                    # `rownames` only exists when the loading step defined it;
                    # otherwise label the results with the spectra index.
                    sample_ids = rownames if check_exist('rownames') else preprocesseddf.index
                    predicted = None
                    try:
                        predicted = model.predict(preprocesseddf)
                        result = pd.DataFrame(predicted, index = sample_ids, columns = ['Results'])
                    except Exception as err:
                        # Report without calling predict() a second time: the
                        # old handler re-ran it and could fail again itself.
                        if predicted is None:
                            st.error(f'Error: The model could not compute predictions ({type(err).__name__}: {err})')
                        else:
                            st.error(f'''Error: Length mismatch: the number of samples indices is {len(sample_ids)}, while the model produced 
                                {len(predicted)} values. correct the "indexes column in csv?" parameter''')
                    with M5:
                        if preprocesseddf.shape[1]>1 and check_exist('result'):
                            st.write('Predicted values distribution')
                            # Creating histogram
                            hist, axs = plt.subplots(1, 1, figsize =(15, 3),
                                                    tight_layout = True)

                            # Add x, y gridlines
                            axs.grid( color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6)
                            # Remove axes spines. The loop variable is not
                            # called `s`, which would overwrite the checkbox.
                            for side in ['top', 'bottom', 'left', 'right']:
                                axs.spines[side].set_visible(False)
                            # Remove x, y ticks
                            axs.xaxis.set_ticks_position('none')
                            axs.yaxis.set_ticks_position('none')
                            # Add padding between axes and labels
                            axs.xaxis.set_tick_params(pad = 5)
                            axs.yaxis.set_tick_params(pad = 10)
                            # Creating histogram
                            N, bins, patches = axs.hist(result, bins = 12)
                            # Colour each bar by its relative height. The
                            # scale is not called `norm`, which would overwrite
                            # the normalisation label.
                            fracs = ((N**(1 / 5)) / N.max())
                            color_scale = colors.Normalize(fracs.min(), fracs.max())

                            for thisfrac, thispatch in zip(fracs, patches):
                                color = plt.cm.viridis(color_scale(thisfrac))
                                thispatch.set_facecolor(color)

                            st.pyplot(hist)
                            st.write('Predicted values table')
                            st.dataframe(result.T)
                elif pred_button:
                    # Report the number of columns actually fed to the model
                    # (after any ipls selection), not the unfiltered count.
                    M6.error(f'Error: The model was trained on {nvar} wavelengths, but you provided {preprocesseddf.shape[1]} wavelengths for prediction. Please ensure they match!')


    if check_exist('result'):
        @st.cache_data(show_spinner =False)
        def preparing_results_for_downloading(change):
            match test:
                # load csv file
                case '.csv':
                    df.to_csv('Report/out/dataset/'+ file.name, sep = ';', encoding = 'utf-8', mode = 'a')
                case '.dx':
                    with open('Report/out/dataset/'+file.name, 'w') as dd:
                        dd.write(dxdata)

            prepspectraplot.savefig('./Report/out/figures/raw_spectra.png')
            rawspectraplot.savefig('./Report/out/figures/preprocessed_spectra.png')
            hist.savefig('./Report/out/figures/histogram.png')
            result.round(4).to_csv('./Report/out/The analysis result.csv', sep = ";", index_col=0)

            return change
        preparing_results_for_downloading(change = hash_)

        import tempfile
        @st.cache_data(show_spinner =False)
        def tempdir(change):
            with  tempfile.TemporaryDirectory( prefix="results", dir="./Report") as temp_dir:# create a temp directory
                tempdirname = os.path.split(temp_dir)[1]
                if len(os.listdir('./Report/out/figures/'))==3:
                    shutil.make_archive(base_name="./Report/Results", format="zip", base_dir="out", root_dir = "./Report")# create a zip file
                    shutil.move("./Report/Results.zip", f"./Report/{tempdirname}/Results.zip")# put the inside the temp dir
                    with open(f"./Report/{tempdirname}/Results.zip", "rb") as f:
                        zip_data = f.read()
            return tempdirname, zip_data

        date_time = datetime.datetime.now().strftime('%y%m%d%H%M')
        try :
            tempdirname, zip_data = tempdir(change = hash_)
            st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_Pred_.zip', mime ="application/zip",
                        args = None, kwargs = None,type = "primary",use_container_width = True)
        except:
            st.write('rtt')
except:
    M20.error('''Error: Data loading failed. Please check your file. Consider fine-tuning the dialect settings or ensure the file isn't corrupted.''')