# Skip to content
# Snippets Groups Projects
# 3-prediction.py 7.2 KiB
# Newer Older
from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from Class_Mod.DATA_HANDLING import *
# HTML for the "CEFE - CNRS" banner (currently commented out)
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML code
# st.markdown(bandeau_html, unsafe_allow_html=True)
add_header()

# Make sure the "interface" key exists in the session state (None if it was never set).
st.session_state["interface"] = st.session_state.get('interface')
local_css(css_file / "style_model.css")


def _section(title, widths):
    """Draw a section header with a blue divider, then return its column containers.

    The header is emitted before the columns, so the columns render underneath it.
    """
    st.header(title, divider='blue')
    return st.columns(widths)


# One section per workflow step; each is split into a wide and a narrow column.
M1, M2 = _section("Data loading", [2, 1])
M3, M4 = _section('Data preprocessing', [2, 1])
# NOTE(review): M6 gets only 0.01 of the width yet is later used for the
# predicted-values table — confirm this is intended.
M5, M6 = _section("Prediction making", [2, 0.01])
# Where and under which (partial) name predictions are exported; the name is
# extended later with the data and model file names.
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
reg_algo = ["Interval-PLS"]

# Accepted input extensions; entry 0 (csv) and entry 1 (dx) drive the loader dispatch.
files_format = ['.csv', '.dx']
file = M2.file_uploader(
    "Select NIRS Data to predict",
    type = files_format,
    help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns",
)

# Defaults until data is uploaded and a model is selected.
pred_data = pd.DataFrame()
loaded_model = None
if file:
    # Split on the LAST dot so names such as "batch.2024.csv" still yield ".csv";
    # find('.') would have produced ".2024.csv" and matched no supported format.
    file_stem, test = os.path.splitext(file.name)
    test = test.lower()
    export_name += file_stem

    if test == files_format[0]:
        # The detection helpers take a filesystem path, but the upload only lives in
        # memory: spill a copy to a temporary file instead of assuming the same file
        # also exists under data/. Each detector now runs once instead of twice.
        with NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            tmp.write(file.getvalue())  # getvalue() leaves the read position untouched
            csv_path = tmp.name
        try:
            detected_sep = str(find_delimiter(csv_path))
            detected_hdr = str(find_col_index(csv_path))
        finally:
            os.unlink(csv_path)

        sep_options = [";", ","]
        hdr_options = ["no", "yes"]
        # Fall back to the first option when the detector returns a value the
        # selectbox does not offer (list.index would raise ValueError).
        sep_default = sep_options.index(detected_sep) if detected_sep in sep_options else 0
        hdr_default = hdr_options.index(detected_hdr) if detected_hdr in hdr_options else 0
        qsep = M2.selectbox("Select csv separator - _detected_: " + detected_sep,
                            options=sep_options, index=sep_default, key=2)
        qhdr = M2.selectbox("indexes column in csv? - _detected_: " + detected_hdr,
                            options=hdr_options, index=hdr_default, key=3)
        # First column holds the sample index only when the user says so.
        col = 0 if qhdr == 'yes' else False
        pred_data = pd.read_csv(file, sep=qsep, index_col=col)

    elif test == files_format[1]:
        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
            tmp.write(file.read())
            tmp_path = tmp.name
        # Parse only after the with-block has closed (and therefore flushed) the
        # temporary file; the file is removed even if parsing fails.
        try:
            chem_data, spectra, meta_data, _ = read_dx(file = tmp_path)
        finally:
            os.unlink(tmp_path)
        M2.success("The data have been loaded successfully", icon="")

        if chem_data.to_numpy().shape[1] > 0:
            yname = M2.selectbox('Select target', options=chem_data.columns)
            # NOTE(review): keeps only samples whose target value equals 0 (presumably
            # the not-yet-measured ones to predict) — confirm this is intended.
            measured = chem_data.loc[:, yname] == 0
            y = chem_data.loc[:, yname].loc[measured]
            pred_data = spectra.loc[measured]
        else:
            pred_data = spectra
# Plot the raw (unprocessed) spectra as soon as data is available.
if not pred_data.empty:
    M1.write('Raw spectra')
    # NOTE(review): yunits receives the literal text "meta_data.loc[:,'yunits'][0]",
    # not the value of that expression (presumably shown as-is by plot_spectra).
    # meta_data only exists for .dx uploads, so a real fix needs a csv fallback — TODO.
    raw_spectra_fig = plot_spectra(
        pred_data,
        xunits='lab',
        yunits="meta_data.loc[:,'yunits'][0]",
    )
    M1.pyplot(raw_spectra_fig)
### preprocessing
# Start from an empty DataFrame *instance*. Binding the class itself (`pd.DataFrame`)
# only worked by accident: `pd.DataFrame.empty` is then a property object, which is
# truthy, so `not preprocessed.empty` merely happened to be False.
preprocessed = pd.DataFrame()
if not pred_data.empty:
    params = M4.file_uploader("Load preprocessings params", type = '.json', help=" .json file")
    if params:
        # Keys read below: 'normalization', 'window_length', 'polyorder', 'deriv'.
        prep = json.load(params)

        # Optional SNV normalization, followed by a Savitzky-Golay filter along the
        # last axis (the wavelength axis, assuming samples are rows — TODO confirm).
        if prep['normalization'] == 'Snv':
            x1 = Snv(pred_data)
            norm = 'Standard Normal Variate'
        else:
            norm = 'No Normalization was applied'
            x1 = pred_data
        x2 = savgol_filter(x1,
                           window_length=prep["window_length"],
                           polyorder=prep["polyorder"],
                           deriv=prep["deriv"],
                           delta=1.0, axis=-1, mode="interp", cval=0.0)
        # savgol_filter returns a bare ndarray; re-attach the sample index and columns.
        preprocessed = pd.DataFrame(x2, index=pred_data.index, columns=pred_data.columns)
################################################################################################
## plot preprocessed spectra
if not preprocessed.empty:
    M3.write('Preprocessed spectra')
    # NOTE(review): yunits receives the literal text "meta_data.loc[:,'yunits'][0]",
    # not the value of that expression — TODO pass the real unit.
    fig2 = plot_spectra(preprocessed, xunits = 'lab', yunits = "meta_data.loc[:,'yunits'][0]")
    M3.pyplot(fig2)
    # Summarize the preprocessing parameters. The outer f-string quotes differ from the
    # inner subscript quotes: reusing ' inside f'...' is a SyntaxError before Python 3.12.
    # "\\:" yields the same backslash-colon text the original "\:" produced, without the
    # invalid-escape warning.
    SG = f"- Savitzky-Golay derivative parameters \\:(Window_length:{prep['window_length']};  polynomial order: {prep['polyorder']};  Derivative order : {prep['deriv']})"
    Norm = f'- Spectral Normalization \\: {norm}'
    M4.write('The spectra were preprocessed using:\n' + SG + "\n" + Norm)
################### Predictions making  ##########################
if not pred_data.empty:
    # Offer the files in data/models/, sorted for a stable dropdown order and without
    # hidden files. The leading '' entry stands for "no model selected yet".
    model_files = [''] + sorted(f for f in os.listdir('data/models/') if not f.startswith('.'))
    model_name = M5.selectbox("Select your model from the dropdown list:", options = model_files, key = 21)

    if model_name:
        # splitext keeps the full stem even for names without an extension
        # (find('.') returned -1 there and silently dropped the last character).
        export_name += '_with_' + os.path.splitext(model_name)[0]
        # NOTE(review): joblib.load unpickles arbitrary objects; only trusted files
        # should be placed in data/models/.
        with open('data/models/' + model_name, 'rb') as f:
            loaded_model = joblib.load(f)

        if loaded_model:
            M5.success("The model has been loaded successfully", icon="")
            s = M5.checkbox('the model is of ipls type?')
            if s:
                index = M5.file_uploader("select wavelengths index file", type="csv")
                if index:
                    # Each row holds an inclusive [start, end] interval of column
                    # positions; expand all intervals into one flat list of positions.
                    # Cast to int so float-typed csv values still work with iloc.
                    intervalls = pd.read_csv(index, sep=';', index_col=0).to_numpy()
                    idx = []
                    for start, end in intervalls[:, :2].astype(int):
                        idx.extend(range(start, end + 1))
if loaded_model:
    if M5.button('Predict'):
        # Check the inputs each prediction path needs. Without these checks a missing
        # wavelength-index file left `idx` undefined (NameError), and missing
        # preprocessing parameters left `x2` undefined and `preprocessed` unusable.
        if s and not index:
            M5.warning("Upload the wavelengths index file of the iPLS model before predicting.")
        elif preprocessed.empty:
            M5.warning("Load the preprocessing parameters (.json) before predicting.")
        else:
            if s:
                # iPLS model: keep only the selected wavelength columns (by position).
                result = loaded_model.predict(preprocessed.iloc[:, idx])
            else:
                # Full-spectrum model: predict from the preprocessed array.
                result = loaded_model.predict(x2)
            result = pd.DataFrame(result, index = pred_data.index)

            #############################
            M5.write('Predicted values distribution')
            # Histogram of the predicted values.
            fig, axs = plt.subplots(1, 1, figsize=(15, 3), tight_layout=True)
            axs.grid(color='grey', linestyle='-.', linewidth=0.5, alpha=0.6)
            # Hide the frame and tick marks, and pad the tick labels.
            for side in ['top', 'bottom', 'left', 'right']:
                axs.spines[side].set_visible(False)
            axs.xaxis.set_ticks_position('none')
            axs.yaxis.set_ticks_position('none')
            axs.xaxis.set_tick_params(pad=5)
            axs.yaxis.set_tick_params(pad=10)
            N, bins, patches = axs.hist(result, bins=12)
            # Color each bar on the viridis colormap by its fifth-root-compressed count.
            fracs = (N ** (1 / 5)) / N.max()
            color_norm = colors.Normalize(fracs.min(), fracs.max())
            for thisfrac, thispatch in zip(fracs, patches):
                thispatch.set_facecolor(plt.cm.viridis(color_norm(thisfrac)))

            M5.pyplot(fig)
            # pyplot keeps figures alive until closed; close it so repeated reruns
            # do not accumulate open figures.
            plt.close(fig)
            M6.write('Predicted values table')
            M6.dataframe(result.T)
            ##################################

            # Save the predictions (creating the export folder if needed), then
            # offer them for download.
            os.makedirs(export_folder, exist_ok=True)
            result.to_csv(export_folder + export_name + '.csv', sep = ';')
            # export to local drive - Download
            download_results(export_folder + export_name + '.csv', export_name + '.csv')
            # create a report with information on the prediction
            ## see https://stackoverflow.com/a/59578663