Skip to content
Snippets Groups Projects
3-prediction.py 3.08 KiB
Newer Older
  • Learn to ignore specific revisions
  • from Packages import *
    st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
    from Modules import *
    from Class_Mod.DATA_HANDLING import *
    
    # HTML pour le bandeau "CEFE - CNRS"
    bandeau_html = """
    <div style="width: 100%; background-color: #4682B4; padding: 10px; margin-bottom: 10px;">
      <h1 style="text-align: center; color: white;">CEFE - CNRS</h1>
    </div>
    """
    
    # Injecter le code HTML du bandeau
    st.markdown(bandeau_html, unsafe_allow_html=True)
    
    st.session_state["interface"] = st.session_state.get('interface')
    
    DIANE's avatar
    DIANE committed
    
    
    st.header("Predictions making", divider='blue')
    model_column, space, file_column= st.columns((2, 1, 1))
    
    
    #M9, M10, M11 = st.columns([2,2,2])
    
    DIANE's avatar
    DIANE committed
    NIRS_csv = file_column.file_uploader("Select NIRS Data to predict", type="csv", help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")
    
    export_folder = './data/predictions/'
    export_name = 'Predictions_of_'
    
    reg_algo = ["Interval-PLS"]
    
    if NIRS_csv:
    
            export_name += str(NIRS_csv.name[:-4])
            qsep = file_column.selectbox("Select csv separator - _detected_: " + str(find_delimiter('data/'+NIRS_csv.name)), options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+NIRS_csv.name))), key=2)
            qhdr = file_column.selectbox("indexes column in csv? - _detected_: " + str(find_col_index('data/'+NIRS_csv.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+NIRS_csv.name))), key=3)
    
    DIANE's avatar
    DIANE committed
            if qhdr == 'yes':
                col = 0
            else:
                col = False
            pred_data = pd.read_csv(NIRS_csv, sep=qsep, index_col=col)
    
    
            # Load the model with joblib
            model_column.write("Load your saved predictive model")
    
    DIANE's avatar
    DIANE committed
            
    
            model_name_import = model_column.selectbox('Choose file:', options=os.listdir('data/models/'), key = 21)
            if model_name_import != ' ':
                export_name += '_with_' + str(model_name_import[:-4])
                with open('data/models/'+ model_name_import,'rb') as f:
                    model_loaded = joblib.load(f)
                if model_loaded:
    
    DIANE's avatar
    DIANE committed
                    s = model_column.checkbox('the model is of ipls type?')
    
                    model_column.success("The model has been loaded successfully", icon="")
    
    DIANE's avatar
    DIANE committed
                    if s:
                          index = model_column.file_uploader("select wavelengths index file", type="csv")
                          if index:
                            idx = pd.read_csv(index, sep=';', index_col=0).iloc[:,0].to_numpy()
    
    
    DIANE's avatar
    DIANE committed
    #result = ''
    
    DIANE's avatar
    DIANE committed
    if st.button("Predict"):
            if s:
                 result = model_loaded.predict(pred_data.iloc[:,idx])
            else:
    
            # use prediction function from application_functions.py to predict chemical values
    
    DIANE's avatar
    DIANE committed
                result = model_loaded.predict(pred_data)
    
            st.write('Predicted values are: ')
            st.dataframe(result.T)
    
    DIANE's avatar
    DIANE committed
            pd.DataFrame(result).to_csv(export_folder + export_name + '.csv', sep = ';')
    
            # export to local drive - Download
            download_results(export_folder + export_name + '.csv', export_name + '.csv')
            # create a report with information on the prediction
            ## see https://stackoverflow.com/a/59578663