Skip to content
Snippets Groups Projects
3-prediction.py 8.24 KiB
Newer Older
from Packages import *
# Set the browser-tab title and icon and switch the page to the wide layout.
# NOTE(review): Streamlit expects this call before any other st.* command,
# presumably why it sits between the import lines - confirm before moving it.
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from Class_Mod.DATA_HANDLING import *
# HTML for the "CEFE - CNRS" banner
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner's HTML code
# st.markdown(bandeau_html, unsafe_allow_html=True)
# --- Page chrome -------------------------------------------------------------
# Project helpers: draw the header and apply the stylesheet used by this page.
add_header()
local_css(css_file / "style_model.css")

# --- Title strip -------------------------------------------------------------
# Two columns (2:1): the illustration goes left (M10); M20 is kept for the
# upload widgets created further down the script.
st.title("Prediction making using a previously developed model")
M10, M20 = st.columns([2, 1])
M10.image("./images/prediction making.png", use_column_width=True)
DIANE's avatar
DIANE committed


# st.header("Prediction making", divider='blue')
# M5, M6 = st.columns([2, 0.01])
DIANE's avatar
DIANE committed

DIANE's avatar
DIANE committed

DIANE's avatar
DIANE committed
# ---------------------------------------------------------------------------
# Input data: let the user upload a spectra file (.csv or .dx) and load it
# into `pred_data` (one row per sample). Everything below in the script is
# gated on `pred_data` being non-empty.
# ---------------------------------------------------------------------------
files_format = ['.csv', '.dx']
file = M20.file_uploader("Select NIRS Data to predict", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
reg_algo = ["Interval-PLS"]
pred_data = pd.DataFrame()
loaded_model = None

if not file:
    M20.warning('Insert your spectral data file here!')
else:
    # Split on the LAST dot so names such as "run.2024.csv" keep their real
    # extension, and names without any dot are not truncated.
    stem, test = os.path.splitext(file.name)
    test = test.lower()
    export_name += stem

    if test == files_format[0]:
        # NOTE(review): the detection helpers are given the path
        # 'data/<uploaded name>' rather than the uploaded buffer - confirm a
        # local copy with that name is expected to exist.
        local_copy = 'data/' + file.name

        # Run each detection once and reuse the result for label and default.
        sep_options = [";", ","]
        detected_sep = str(find_delimiter(local_copy))
        # Fall back to the first option instead of raising ValueError when the
        # detected value is not one of the offered choices.
        sep_default = sep_options.index(detected_sep) if detected_sep in sep_options else 0
        qsep = M20.selectbox("Select csv separator - _detected_: " + detected_sep, options=sep_options, index=sep_default, key=2)

        hdr_options = ["no", "yes"]
        detected_hdr = str(find_col_index(local_copy))
        hdr_default = hdr_options.index(detected_hdr) if detected_hdr in hdr_options else 0
        qhdr = M20.selectbox("indexes column in csv? - _detected_: " + detected_hdr, options=hdr_options, index=hdr_default, key=3)

        # index_col=0 uses the first column as row labels; False keeps every
        # column as data.
        col = 0 if qhdr == 'yes' else False
        pred_data = pd.read_csv(file, sep=qsep, index_col=col)

    elif test == files_format[1]:
        # read_dx takes a path, so spill the upload to a temporary file.
        # The handle is closed BEFORE parsing so the bytes are flushed to disk
        # (and so the file can be reopened by name on every platform).
        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
            tmp.write(file.read())
            tmp_path = tmp.name
        try:
            chem_data, spectra, meta_data, _ = read_dx(file =  tmp_path)
        finally:
            # Always remove the temporary copy, even if parsing failed.
            os.unlink(tmp_path)

        # NOTE(review): the icon literal is empty in this copy of the file -
        # possibly an emoji lost in extraction; confirm the intended value.
        M20.success("The data have been loaded successfully", icon="")

        if chem_data.shape[1] > 0:
            yname = M20.selectbox('Select target', options=chem_data.columns)
            # Keep only the rows whose selected target equals 0.
            # NOTE(review): presumably these are the samples without a
            # reference value (the ones to predict) - verify against read_dx.
            measured = chem_data.loc[:,yname] == 0
            y = chem_data.loc[:,yname].loc[measured]
            pred_data = spectra.loc[measured]
        else:
            pred_data = spectra

DIANE's avatar
DIANE committed

# Load parameters
DIANE's avatar
DIANE committed
st.header("I - Spectral data preprocessing & visualization", divider='blue')

# Defaults so later sections can test these names even when no data (or no
# parameter file) has been provided yet.
preprocessed = pd.DataFrame()   # an empty frame, not the DataFrame class
params = None

if not pred_data.empty:
    # Layout for this section: plot on the left, parameter upload on the right.
    # NOTE(review): M1/M2 were used without being created anywhere in the
    # visible script; the 2:1 split mirrors the other column pairs - confirm.
    M1, M2 = st.columns([2, 1])

    # --- raw spectra -------------------------------------------------------
    M1.write('Raw spectra')
    # NOTE(review): yunits is passed as a literal string that looks like an
    # expression; left exactly as it was.
    fig = plot_spectra(pred_data, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    M1.pyplot(fig)

    # --- preprocessing -----------------------------------------------------
    # The JSON file supplies the keys read below: 'normalization',
    # 'window_length', 'polyorder' and 'deriv'.
    params = M2.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
    if params:
        prep = json.load(params)

        # Optional normalization step (Snv is a project helper).
        if prep['normalization'] == 'Snv':
            x1 = Snv(pred_data)
            norm = 'Standard Normal Variate'
        else:
            norm = 'No Normalization was applied'
            x1 = pred_data

        # Savitzky-Golay filtering along the last axis of the spectra matrix.
        x2 = savgol_filter(x1,
                           window_length = prep["window_length"],
                           polyorder = prep["polyorder"],
                           deriv=prep["deriv"],
                           delta=1.0, axis=-1, mode="interp", cval=0.0)
        preprocessed = pd.DataFrame(x2, index = pred_data.index, columns = pred_data.columns)
DIANE's avatar
DIANE committed

################################################################################################
DIANE's avatar
DIANE committed
## plot preprocessed spectra
if not preprocessed.empty:
    # Plot on the left, textual summary of the applied steps on the right.
    # NOTE(review): M3/M4 were used without being created anywhere in the
    # visible script; the 2:1 split mirrors the other column pairs - confirm.
    M3, M4 = st.columns([2, 1])

    M3.write('Preprocessed spectra')
    fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    M3.pyplot(fig2)

    # Double-quoted f-strings so the single-quoted dict keys parse on Python
    # versions before 3.12; '\\:' yields the same backslash-colon text that
    # the former invalid '\:' escape produced.
    SG = f"- Savitzky-Golay derivative parameters \\:(Window_length:{prep['window_length']};  polynomial order: {prep['polyorder']};  Derivative order : {prep['deriv']})"
    Norm = f"- Spectral Normalization \\: {norm}"
    M4.info('The spectra were preprocessed using:\n'+SG+"\n"+Norm)
DIANE's avatar
DIANE committed

################### Predictions making  ##########################
DIANE's avatar
DIANE committed
st.header("II - Prediction making", divider='blue')
# Only offer model selection once both the spectra and the preprocessing
# parameters have been provided.
if not pred_data.empty and params:
    M5, M6 = st.columns([2, 1])

    # Leading '' entry is the "nothing selected" choice, rendered as "<Select>".
    model_files = [''] + os.listdir('data/models/')
    model_name = M6.selectbox("Select your model from the dropdown list:", options = model_files, key = 21, format_func=lambda x: x if x else "<Select>")

    if model_name:
        # Strip only the final extension from the model file name.
        export_name += '_with_' + os.path.splitext(model_name)[0]

        # NOTE(review): joblib.load unpickles the file - only models from a
        # trusted 'data/models/' folder should be listed here.
        with open('data/models/'+ model_name,'rb') as f:
            loaded_model = joblib.load(f)
        # Number of predictors the model reports having been fitted on.
        ncols = loaded_model.n_features_in_

        if loaded_model:
            # NOTE(review): the icon literal is empty in this copy of the
            # file - possibly an emoji lost in extraction; confirm.
            M6.success("The model has been loaded successfully", icon="")
            s = M6.checkbox('the model is of ipls type?')
            if s:
                index = M6.file_uploader("select wavelengths index file", type="csv")
                if index:
                    # Each row holds the inclusive [start, end] column
                    # positions of one selected interval.
                    intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
                    idx = []
                    for lower, upper in intervalls[:, :2]:
                        idx.extend(range(int(lower), int(upper) + 1))

                    # iloc is 0-based: the largest valid position is
                    # shape[1] - 1, so the bound must be strict. An empty
                    # interval list is reported instead of crashing max().
                    if idx and max(idx) < preprocessed.shape[1]:
                        preprocessed = preprocessed.iloc[:,idx] ### get predictors
                    else:
                        M6.error("Error: The number of columns in your data does not match the number of columns used to train the model. Please ensure they are the same.")

DIANE's avatar
DIANE committed

# Run the prediction when a model is loaded and the user presses "Predict".
if loaded_model and M6.button('Predict', type='primary'):
    if ncols != preprocessed.shape[1]:
        # The model and the supplied spectra disagree on the predictor count.
        M6.error(f'Error: The model was trained with {ncols} wavelengths, but you provided {preprocessed.shape[1]} wavelengths for prediction. Please ensure they match.')
    else:
        result = pd.DataFrame(loaded_model.predict(preprocessed), index = preprocessed.index)

        if preprocessed.shape[1]>1:
            M5.write('Predicted values distribution')

            # Wide, short canvas for the histogram of predicted values.
            fig, axs = plt.subplots(1, 1, figsize =(15, 3), tight_layout = True)

            # Light dash-dot grid, no frame, no tick marks.
            axs.grid(color ='grey', linestyle ='-.', linewidth = 0.5, alpha = 0.6)
            for side in ('top', 'bottom', 'left', 'right'):
                axs.spines[side].set_visible(False)
            for axis in (axs.xaxis, axs.yaxis):
                axis.set_ticks_position('none')

            # Spacing between the axes and their tick labels.
            axs.xaxis.set_tick_params(pad = 5)
            axs.yaxis.set_tick_params(pad = 10)

            # Draw the bars, then shade each one from its (compressed) count.
            counts, bins, patches = axs.hist(result, bins = 12)
            fracs = ((counts**(1 / 5)) / counts.max())
            shade_scale = colors.Normalize(fracs.min(), fracs.max())
            for bar_frac, bar in zip(fracs, patches):
                bar.set_facecolor(plt.cm.viridis(shade_scale(bar_frac)))

            M5.pyplot(fig)

        st.write('Predicted values table')
        st.dataframe(result.T)

        # result.to_csv(export_folder + export_name + '.csv', sep = ';')
        # NOTE(review): with the export above disabled, nothing in the visible
        # script writes the path handed to download_results - confirm the
        # helper does not need the file to exist.
        download_results(export_folder + export_name + '.csv', export_name + '.csv')
        # create a report with information on the prediction
        ## see https://stackoverflow.com/a/59578663
DIANE's avatar
DIANE committed