Skip to content
Snippets Groups Projects
3-prediction.py 3.86 KiB
Newer Older
from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *
from Class_Mod.DATA_HANDLING import *
# HTML for the "CEFE - CNRS" banner (currently disabled)
# bandeau_html = """
# <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
#   <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
# </div>
# """
# # Inject the banner HTML
# st.markdown(bandeau_html, unsafe_allow_html=True)
add_header()

# Make sure the 'interface' key exists in the session state (None when never set).
st.session_state["interface"] = st.session_state.get('interface')

st.header("Predictions making", divider='blue')
# Three-column layout: model picker (wide), an empty spacer column, and the data uploader.
model_column, space, file_column = st.columns([2, 1, 1])

# Accepted upload extensions; the branches below dispatch on these exact strings.
files_format = ['.csv', '.dx']
file = file_column.file_uploader("Select NIRS Data to predict", type = files_format, help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")

# Predictions are saved as export_folder + export_name + '.csv'; export_name is
# extended further down with the data-file and model names.
export_folder = './data/predictions/'
export_name = 'Predictions_of_'
reg_algo = ["Interval-PLS"]  # NOTE(review): not referenced elsewhere in this script as shown

# Start from an empty DataFrame *instance*: binding the class itself (`pd.DataFrame`)
# made `pred_data.empty` a property object instead of a boolean.
pred_data = pd.DataFrame()
loaded_model = None  # set once a saved model file has been chosen and loaded
if file:
    # Split on the LAST dot so "batch.2.csv" -> ("batch.2", ".csv"). Splitting on the
    # first dot misclassified such names, and find() returning -1 chopped the last
    # character off extension-less names.
    stem, test = os.path.splitext(file.name)
    test = test.lower()
    export_name += stem

    if test == files_format[0]:
        # find_delimiter / find_col_index take a file path, but the upload only lives in
        # memory (probing 'data/' + file.name only worked if a same-named file happened
        # to exist there), so probe a temporary copy of the uploaded bytes instead.
        with NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            tmp.write(file.getvalue())
            tmp_path = tmp.name
        try:
            detected_sep = str(find_delimiter(tmp_path))
            detected_hdr = str(find_col_index(tmp_path))
        finally:
            os.unlink(tmp_path)

        sep_options = [";", ","]
        hdr_options = ["no", "yes"]
        # Preselect the detected value when it is one of the offered choices.
        qsep = file_column.selectbox("Select csv separator - _detected_: " + detected_sep, options=sep_options,
                                     index=sep_options.index(detected_sep) if detected_sep in sep_options else 0, key=2)
        qhdr = file_column.selectbox("indexes column in csv? - _detected_: " + detected_hdr, options=hdr_options,
                                     index=hdr_options.index(detected_hdr) if detected_hdr in hdr_options else 0, key=3)
        # index_col=0 -> first column holds the row labels; False -> no index column.
        col = 0 if qhdr == 'yes' else False

        file.seek(0)  # rewind in case the upload buffer was consumed earlier
        pred_data = pd.read_csv(file, sep=qsep, index_col=col)

    elif test == files_format[1]:
        # read_dx also takes a path. Close the temporary file BEFORE reading it so the
        # written bytes are flushed to disk (and the file can be reopened on Windows),
        # and always remove it afterwards.
        with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
            tmp.write(file.getvalue())
            tmp_path = tmp.name
        try:
            chem_data, spectra, meta_data = read_dx(file=tmp_path)
        finally:
            os.unlink(tmp_path)

        # No icon: icon="" is not a valid emoji and may be rejected by Streamlit.
        file_column.success("The data have been loaded successfully")
        yname = file_column.selectbox('Select target', options=chem_data.columns)
        # NOTE(review): keeps the rows whose target value equals 0 as the samples to
        # predict — presumably 0 means "not measured"; confirm against the data convention.
        measured = chem_data.loc[:, yname] == 0
        y = chem_data.loc[:, yname].loc[measured]
        pred_data = spectra.loc[measured]

################### Predictions making
# Defaults so the prediction step below can always test these, even when no model
# (or no iPLS index file) has been chosen during this Streamlit rerun.
s = False   # True when the user flags the loaded model as an iPLS model
idx = None  # column positions read from the wavelengths index file (iPLS only)

if not pred_data.empty:  # only offer models once spectra have been loaded
    model_column.write("Load your saved predictive model")
    models_folder = 'data/models/'
    # os.listdir() order is arbitrary, so dropping its first entry could hide a real
    # model. Skip hidden files (e.g. .gitkeep) and sub-directories explicitly instead.
    # The leading '' option means "nothing selected yet".
    if os.path.isdir(models_folder):
        model_files = sorted(
            name for name in os.listdir(models_folder)
            if not name.startswith('.') and os.path.isfile(os.path.join(models_folder, name))
        )
    else:
        model_files = []
    model_name = model_column.selectbox('Choose file:', options=[''] + model_files, key=21)

    if model_name:
        # Strip only the final extension ("my.model.pkl" -> "my.model").
        export_name += '_with_' + os.path.splitext(model_name)[0]
        # NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
        # only trusted files should ever be placed in data/models/.
        with open(os.path.join(models_folder, model_name), 'rb') as f:
            loaded_model = joblib.load(f)

        if loaded_model is not None:
            s = model_column.checkbox('the model is of ipls type?')
            # No icon: icon="" is not a valid emoji and may be rejected by Streamlit.
            model_column.success("The model has been loaded successfully")

            if s:
                index = model_column.file_uploader("select wavelengths index file", type="csv")
                if index:
                    # Values of the first data column are used below as column positions (iloc).
                    idx = pd.read_csv(index, sep=';', index_col=0).iloc[:, 0].to_numpy()

if loaded_model is not None:
    if st.button('Predict'):
        if s and idx is None:
            # Without the index file the iPLS column selection is impossible
            # (this path used to crash with a NameError on `idx`).
            st.error("Please upload the wavelengths index file required by the iPLS model.")
        else:
            if s:
                # Feed the model only the wavelength columns listed in the index file.
                result = loaded_model.predict(pred_data.iloc[:, idx])
            else:
                result = loaded_model.predict(pred_data)

            st.write('Predicted values')
            st.dataframe(result.T)

            # Save the predictions server-side (creating the folder if needed), then
            # offer the same file for download.
            os.makedirs(export_folder, exist_ok=True)
            export_path = export_folder + export_name + '.csv'
            pd.DataFrame(result).to_csv(export_path, sep = ';')
            download_results(export_path, export_name + '.csv')
            # TODO: create a report with information on the prediction
            #       (see https://stackoverflow.com/a/59578663)