diff --git a/.gitignore b/.gitignore
index 2015cd5d54387d056b86f31ffd8e464d784072d4..26aaaf09a17d237bb39b785b8eb06847dd7bf7b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,7 +9,9 @@ _lightning_logs/*
 src/models/*.ipynb
 *.swp
 *~
-thirdparties
+thirdparties/*
+iota2_thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
+iota2_thirdparties/*
diff --git a/Makefile b/Makefile
index 5692cfc81f56d500719a680d88cdb810aa0a04de..d112a0ae67321f6a6ab3887a00a2155e5d993644 100644
--- a/Makefile
+++ b/Makefile
@@ -74,8 +74,11 @@ test_no_PIL:
 test_mask_loss:
 	$(CONDA) && pytest -vv test/test_masked_losses.py
 
+test_inference:
+	$(CONDA) && pytest -vv test/test_mmdc_inference.py
 
 PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py"
+
 #.PHONY:
 pylint:
 	$(CONDA) && pylint --ignore=$(PYLINT_IGNORED) src/
diff --git a/README.rst b/README.rst
new file mode 100644
index 0000000000000000000000000000000000000000..5bc068184ef56f4b256ff867dfc8fb66e4a0230c
--- /dev/null
+++ b/README.rst
@@ -0,0 +1 @@
+# MMDC-SingleDate
diff --git a/configs/iota2/config_ressources.cfg b/configs/iota2/config_ressources.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..9ef37b3e8271362bc58573623eeee90ac5b68494
--- /dev/null
+++ b/configs/iota2/config_ressources.cfg
@@ -0,0 +1,434 @@
+################################################################################
+#               Configuration file used to set HPC resource requests
+
+# All steps are mandatory
+#
+#                  TEMPLATE
+# step_name:{
+#           name : "IOTA2_dir" #no space characters
+#           nb_cpu : 1
+#           ram : "5gb"
+#           walltime : "00:10:00"#HH:MM:SS
+#           process_min : 1 #not mandatory (default = 1)
+#           process_max : 9 #not mandatory (default = number of tasks)
+#           }
+#
+################################################################################
+
+iota2_chain:{
+             name : "IOTA2"
+             nb_cpu : 1
+             ram : "5gb"
+             walltime : "10:00:00"
+            }
+
+iota2_dir:{
+           name : "IOTA2_dir"
+           nb_cpu : 1
+           ram : "5gb"
+           walltime : "00:10:00"
+           process_min : 1
+          }
+
+preprocess_data : {
+                   name:"preprocess_data"
+                   nb_cpu:40
+                   ram:"180gb"
+                   walltime:"10:00:00"
+                   process_min : 1
+                  }
+
+get_common_mask : {
+                   name:"CommonMasks"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+
+coregistration : {
+                   name:"Coregistration"
+                   nb_cpu:4
+                   ram:"20gb"
+                   walltime:"05:00:00"
+                   process_min : 1
+                  }
+
+get_pixValidity : {
+                   name:"NbView"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                   }
+
+envelope : {
+            name:"Envelope"
+            nb_cpu:3
+            ram:"15gb"
+            walltime:"00:10:00"
+            process_min : 1
+            }
+
+regionShape : {
+               name:"regionShape"
+               nb_cpu:3
+               ram:"15gb"
+               walltime:"00:10:00"
+               process_min : 1
+               }
+
+splitRegions : {
+                name:"splitRegions"
+                nb_cpu:3
+                ram:"15gb"
+                walltime:"00:10:00"
+                process_min : 1
+                }
+
+extract_data_region_tiles : {
+                             name:"extract_data_region_tiles"
+                             nb_cpu:3
+                             ram:"15gb"
+                             walltime:"00:10:00"
+                             process_min : 1
+                             }
+
+split_learning_val : {
+                      name:"split_learning_val"
+                      nb_cpu:3
+                      ram:"15gb"
+                      walltime:"00:10:00"
+                      process_min : 1
+                      }
+
+split_learning_val_sub : {
+                          name:"split_learning_val_sub"
+                          nb_cpu:3
+                          ram:"15gb"
+                          walltime:"00:10:00"
+                          process_min : 1
+                          }
+split_samples : {
+                 name:"split_samples"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+samplesFormatting : {
+                     name:"samples_formatting"
+                     nb_cpu:1
+                     ram:"5gb"
+                     walltime:"00:10:00"
+                     process_min : 1
+                     }
+
+samplesMerge: {
+                name:"samples_merges"
+                nb_cpu:2
+                ram:"10gb"
+                walltime:"00:10:00"
+                process_min : 1
+               }
+
+samplesStatistics : {
+                   name:"samples_stats"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+
+samplesSelection : {
+                   name:"samples_selection"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+samplesselection_tiles : {
+                   name:"samplesSelection_tiles"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+samplesManagement : {
+                   name:"samples_management"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+
+vectorSampler : {
+                 name:"vectorSampler"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"10:00:00"
+                 process_min : 1
+                 }
+
+dimensionalityReduction : {
+               name:"dimensionalityReduction"
+               nb_cpu:3
+               nb_MPI_process:2
+               ram:"15gb"
+               nb_chunk:1
+               walltime:"00:20:00"}
+
+samplesAugmentation : {
+                       name:"samplesAugmentation"
+                       nb_cpu:3
+                       ram:"15gb"
+                       walltime:"00:10:00"
+                       }
+
+mergeSample : {
+               name:"mergeSample"
+               nb_cpu:3
+               ram:"15gb"
+               walltime:"00:10:00"
+               process_min : 1
+               }
+
+stats_by_models : {
+                   name:"stats_by_models"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                   }
+
+training : {
+            name:"training"
+            nb_cpu:40
+            ram:"180gb"
+            walltime:"10:00:00"
+            process_min : 1
+            }
+
+#training regression
+TrainRegressionPytorch : {
+                 name:"training"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"30:00:00"
+                 }
+
+PredictRegressionScikit: {
+                 name:"prediction"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"30:00:00"
+                 }
+
+#prediction regression
+PredictRegressionPytorch : {
+                 name:"choice"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"30:00:00"
+                 }
+
+
+cmdClassifications : {
+                      name:"cmdClassifications"
+                      nb_cpu:10
+                      ram:"60gb"
+                      walltime:"10:00:00"
+                      process_min : 1
+                      }
+
+classifications : {
+                   name:"classifications"
+                   nb_cpu:10
+                   ram:"60gb"
+                   walltime:"10:00:00"
+                   process_min : 1
+                   }
+
+classifShaping : {
+                  name:"classifShaping"
+                  nb_cpu:10
+                  ram:"60gb"
+                  walltime:"10:00:00"
+                  process_min : 1
+                  }
+
+gen_confusionMatrix : {
+                       name:"genCmdconfusionMatrix"
+                       nb_cpu:3
+                       ram:"15gb"
+                       walltime:"00:10:00"
+                       process_min : 1
+                       }
+
+confusionMatrix : {
+                   name:"confusionMatrix"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                   }
+
+#regression
+GenerateRegressionMetrics : {
+                 name:"classification"
+                 nb_cpu:10
+                 ram:"50gb"
+                 walltime:"40:00:00"
+                 }
+
+
+fusion : {
+          name:"fusion"
+          nb_cpu:3
+          ram:"15gb"
+          walltime:"00:10:00"
+          process_min : 1
+          }
+
+noData : {
+          name:"noData"
+          nb_cpu:3
+          ram:"15gb"
+          walltime:"00:10:00"
+          process_min : 1
+          }
+
+statsReport : {
+               name:"statsReport"
+               nb_cpu:3
+               ram:"15gb"
+               walltime:"00:10:00"
+               process_min : 1
+               }
+
+confusionMatrixFusion : {
+                         name:"confusionMatrixFusion"
+                         nb_cpu:3
+                         ram:"15gb"
+                         walltime:"00:10:00"
+                         process_min : 1
+                         }
+
+merge_final_classifications : {
+                           name:"merge_final_classifications"
+                           nb_cpu:1
+                           ram:"5gb"
+                           walltime:"00:10:00"
+                          }
+
+reportGen : {
+             name:"reportGeneration"
+             nb_cpu:3
+             ram:"15gb"
+             walltime:"00:10:00"
+             process_min : 1
+             }
+
+mergeOutStats : {
+                 name:"mergeOutStats"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+SAROptConfusionMatrix : {
+                 name:"SAROptConfusionMatrix"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+SAROptConfusionMatrixFusion : {
+                 name:"SAROptConfusionMatrixFusion"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+SAROptFusion : {
+                 name:"SAROptFusion"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+clump : {
+         name:"clump"
+         nb_cpu:1
+         ram:"5gb"
+         walltime:"00:10:00"
+         process_min:1
+         }
+
+grid : {
+         name:"grid"
+         nb_cpu:1
+         ram:"5gb"
+         walltime:"00:10:00"
+         process_min:1
+         }
+
+vectorisation : {
+                 name:"vectorisation"
+                 nb_cpu:1
+                 ram:"5gb"
+                 walltime:"00:10:00"
+                 process_min:1
+                 }
+
+crownsearch : {
+                name:"crownsearch"
+                nb_cpu:1
+                ram:"5gb"
+                walltime:"00:10:00"
+                process_min:1
+                }
+
+crownbuild : {
+                name:"crownbuild"
+                nb_cpu:1
+                ram:"5gb"
+                walltime:"00:10:00"
+                process_min:1}
+
+statistics : {
+                name:"statistics"
+                nb_cpu:1
+                ram:"5gb"
+                walltime:"00:10:00"
+                process_min:1
+                }
+
+join : {
+         name:"join"
+         nb_cpu:1
+         ram:"5gb"
+         walltime:"00:10:00"
+         process_min:1
+         }
+
+tiler : {
+         name:"tiler"
+         nb_cpu:8
+         ram:"120gb"
+         walltime:"01:00:00"
+         process_min:1
+
+}
+features_tiler : {
+         name:"features_tiler"
+         nb_cpu:8
+         ram:"120gb"
+         walltime:"01:00:00"
+         process_min:1
+
+}
diff --git a/configs/iota2/config_sar.cfg b/configs/iota2/config_sar.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..07512be098cb3c53f0bef5d62cabb56d8691d435
--- /dev/null
+++ b/configs/iota2/config_sar.cfg
@@ -0,0 +1,19 @@
+[Paths]
+output = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_preprocess
+s1images = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake/
+srtm = /datalake/static_aux/MNT/SRTM_30_hgt
+geoidfile = /home/uz/vinascj/src/MMDC/mmdc-data/Geoid/egm96.grd
+
+[Processing]
+tiles = 31TCJ
+tilesshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/mgrs_tiles.shp
+referencesfolder = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/
+srtmshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/srtm.shp
+rasterpattern = T31TCJ.tif
+gapFilling_interpolation = spline
+temporalresolution = 10
+borderthreshold = 1e-3
+ramperprocess = 5000
+
+[Filtering]
+window_radius = 2
diff --git a/configs/iota2/externalfeatures_with_userfeatures.cfg b/configs/iota2/externalfeatures_with_userfeatures.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb0faba181ae73ec023d94019fcc3a0063c74d
--- /dev/null
+++ b/configs/iota2/externalfeatures_with_userfeatures.cfg
@@ -0,0 +1,24 @@
+# iota2 configuration file to launch iota2 as a tiler of features.
+# paths with 'XXX' must be replaced by local user paths.
+
+
+chain :
+{
+  output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
+
+  first_step : 'tiler'
+  last_step : 'tiler'
+
+  proj : 'EPSG:2154'
+  # rasters_grid_path : '/XXX/grid_raster'
+  grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+  list_tile : 'T31TCJ' # T31TCK T31TDJ'
+  features_path : '/XXX/non_tiled_features'
+
+  spatial_resolution : 10 # mandatory but not used in this case. The output spatial resolution will be the one found in 'rasters_grid_path' directory.
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..71bc61d1d25887b551723e21b110a3afe6bb05db
--- /dev/null
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -0,0 +1,77 @@
+chain :
+{
+  output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/results/external_user_features'
+  remove_output_path : False
+  check_inputs : False
+
+  nomenclature_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/nomenclature_grosse.txt'
+  color_table : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/colorFile_grosse.txt'
+
+  list_tile : 'T31TCJ'
+  ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp'
+  data_field : 'groscode'
+  s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink'
+  # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
+  user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
+  # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output'
+  s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg'
+
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+  grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+  tile_field : 'Name'
+
+  spatial_resolution : 10
+  first_step : 'init'
+  last_step : 'validation'
+  proj : 'EPSG:2154'
+}
+userFeat:
+{
+  arbo:"/*"
+  patterns:"aspect,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending,slope"
+}
+python_data_managing :
+{
+    padding_size_x : 1
+    padding_size_y : 1
+    chunk_size_mode:"split_number"
+    number_of_chunks:2
+    data_mode_access: "both"
+}
+
+external_features :
+{
+  module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
+  functions : [['test_mnt', {"arg1":1}]]
+  # functions : [['test_s2_concat', {"arg1":1}]]
+  concat_mode : True
+  external_features_flag:True
+}
+arg_classification:
+{
+ enable_probability_map:True
+ generate_final_probability_map:True
+}
+arg_train :
+{
+
+  runs : 1
+  random_seed : 1
+  classifier:"sharkrf"
+
+  sample_selection :
+  {
+    sampler : 'random'
+    strategy : 'percent'
+    'strategy.percent.p' : 1.0
+    ram : 4000
+  }
+}
+
+task_retry_limits :
+{
+  allowed_retry : 0
+  maximum_ram : 180.0
+  maximum_cpu : 40
+}
diff --git a/configs/iota2/i2_classification.cfg b/configs/iota2/i2_classification.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..71c1b7284e9759e95b520311ec2e0c3beb72c2d9
--- /dev/null
+++ b/configs/iota2/i2_classification.cfg
@@ -0,0 +1,61 @@
+chain :
+{
+  output_path : '/work/scratch/sowm/IOTA2_TEST_S2/IOTA2_Outputs/Results_classif'
+  remove_output_path : True
+  nomenclature_path : '/work/scratch/sowm/IOTA2_TEST_S2/nomenclature23.txt'
+  list_tile : 'T31TCJ T31TDJ'
+  s2_path : '/work/scratch/sowm/IOTA2_TEST_S2/sensor_data'
+  s1_path : '/work/scratch/sowm/test-data-cube/iota2/s1_full.cfg'
+  ground_truth : '/work/scratch/sowm/IOTA2_TEST_S2/vector_data/reference_data.shp'
+  data_field : 'code'
+  spatial_resolution : 10
+  color_table : '/work/scratch/sowm/IOTA2_TEST_S2/colorFile.txt'
+  proj : 'EPSG:2154'
+  first_step: "init"
+  last_step: "validation"
+}
+sensors_data_interpolation:
+{
+  write_outputs: False
+}
+
+external_features:
+{
+
+  module: "/home/qt/sowm/scratch/test-data-cube/iota2/external_iota2_code.py"
+  functions: [["apply_convae",
+              { "checkpoint_path": "/home/qt/sowm/scratch/results/Convae_split_roi_psz_128_ep_150_bs_16_lr_1e-05_s1_[64, 128, 256, 512, 1024]_[8]_s2_[64, 128, 256, 512, 1024]_[8]_21916652.admin01",
+                "checkpoint_epoch": 100,
+                "patch_size": 74
+              }]]
+  concat_mode: True
+}
+
+python_data_managing:
+{
+ chunk_size_mode: "split_number"
+ number_of_chunks: 20
+ padding_size_y: 27
+
+}
+
+arg_train :
+{
+  classifier : 'sharkrf'
+  otb_classifier_options : {"classifier.sharkrf.ntrees" : 100 }
+  sample_selection: {"sampler": "random",
+  "strategy": "percent",
+  "strategy.percent.p": 0.1}
+  random_seed: 42
+
+}
+arg_classification :
+{
+  classif_mode : 'separate'
+}
+task_retry_limits :
+{
+  allowed_retry : 0
+  maximum_ram : 180.0
+  maximum_cpu : 40
+}
diff --git a/configs/iota2/i2_grid.cfg b/configs/iota2/i2_grid.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..e13f87883d2ddfe06eb41e448d2a53c83b2911d9
--- /dev/null
+++ b/configs/iota2/i2_grid.cfg
@@ -0,0 +1,19 @@
+chain :
+{
+  output_path : '/work/scratch/vinascj/MMDC/iota2/grid_test'
+  spatial_resolution : 100 # output target resolution
+  first_step : 'tiler' # do not change
+  last_step : 'tiler' # do not change
+
+  proj : 'EPSG:2154' # output target crs
+  grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+  tile_field : 'Name'
+  list_tile : '31TCK 31TCJ' # list of tiles in grid to build
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/configs/iota2/i2_tiler.cfg b/configs/iota2/i2_tiler.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..193d757d7bf28c475f36cf673d4cb32c8ff0c2a1
--- /dev/null
+++ b/configs/iota2/i2_tiler.cfg
@@ -0,0 +1,22 @@
+chain :
+{
+  output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data'
+  spatial_resolution : 10
+  first_step : 'tiler'
+  last_step : 'tiler'
+
+  proj : 'EPSG:2154'
+  rasters_grid_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/T31TCJ'
+  tile_field : 'Name'
+  list_tile : 'T31TCJ'
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+  s2_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s2_data_datalake'
+  # add S1 data
+  s1_dir:"/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake"
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/configs/iota2/iota2_grid_full.cfg b/configs/iota2/iota2_grid_full.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..d5379d4997ca4edf95e8ea310d76361999f93483
--- /dev/null
+++ b/configs/iota2/iota2_grid_full.cfg
@@ -0,0 +1,21 @@
+chain :
+{
+  output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
+  spatial_resolution : 10
+  first_step : 'tiler'
+  last_step : 'tiler'
+
+  proj : 'EPSG:2154'
+  rasters_grid_path : '/work/OT/theia/oso/arthur/TMP/raster_grid'
+  tile_field : 'Name'
+  list_tile : 'T31TCJ'
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+  s2_path:'/work/OT/theia/oso/arthur/TMP/test_s2_angles'
+  s1_dir:'/work/OT/theia/oso/arthur/s1_emma/all_tiles'
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/configs/iota2/iota2_mmdc_full.cfg b/configs/iota2/iota2_mmdc_full.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..fb651a84712caf32e05b65e6a5b196aca2e2357a
--- /dev/null
+++ b/configs/iota2/iota2_mmdc_full.cfg
@@ -0,0 +1,76 @@
+# Entry Point Iota2 Config
+chain :
+{
+  # Input and Output data
+  output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/iota2_results'
+  remove_output_path : False
+  check_inputs : False
+  list_tile : 'T31TCJ'
+  data_field : 'code'
+  s2_path : '/work/scratch/vinascj/MMDC/iota2/i2_training_data/s2_data_full'
+  ground_truth : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/reference_data.shp'
+
+  # Add S1 data
+  s1_path : "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_sar.cfg"
+
+  # Classification related parameters
+  spatial_resolution : 10
+  color_table : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/colorFile.txt'
+  nomenclature_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/nomenclature.txt'
+  first_step : 'init'
+  last_step : 'validation'
+  proj : 'EPSG:2154'
+
+  # Path to the User data
+  user_feat_path: '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data' #/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data/T31TCJ
+}
+
+# Parametrize the classifier
+arg_train :
+{
+  random_seed : 0 # Set the random seed to split the ground truth
+  runs : 1
+  classifier : 'sharkrf'
+  otb_classifier_options : {'classifier.sharkrf.nodesize': 25}
+  sample_selection :
+  {
+    sampler : 'random'
+    strategy : 'percent'
+    'strategy.percent.p' : 0.1
+  }
+}
+
+# Use User Provided data
+userFeat:
+{
+  arbo:"/*"
+  patterns:"elevation,aspect,slope,worldclim,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending,"
+}
+
+# Parameters for the user provided data
+python_data_managing :
+{
+    padding_size_x : 1
+    padding_size_y : 1
+    chunk_size_mode:"split_number"
+    number_of_chunks:4
+    data_mode_access: "both"
+}
+
+# Process the User and Iota2 data
+external_features :
+{
+  # module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
+  module:"/home/uz/vinascj/src/MMDC/mmdc-singledate/src/mmdc_singledate/inference/mmdc_iota2_inference.py"
+  functions : [['test_mnt', {"arg1":1}]]
+  concat_mode : True
+  external_features_flag:True
+}
+
+# Parametrize nodes
+task_retry_limits :
+{
+  allowed_retry : 0
+  maximum_ram : 180.0
+  maximum_cpu : 40
+}
diff --git a/create-conda-env.sh b/create-conda-env.sh
index 36d9e095ac741a8f583be21825b83183102951ce..61f3fc75559987423d0242d841d339532bea7f91 100755
--- a/create-conda-env.sh
+++ b/create-conda-env.sh
@@ -51,6 +51,8 @@ mamba install --yes "pytorch>=2.0.0.dev202302=py3.10_cuda11.7_cudnn8.5.0_0" "tor
 conda deactivate
 conda activate $target
 
+# set proxy (required for the pip install below to reach the network)
+source ~/set_proxy_iota2.sh
 # Requirements
 pip install -r requirements-mmdc-sgld-test.txt
 
diff --git a/create-iota2-env.sh b/create-iota2-env.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f3b99ee1a40f098a0648e34b6892558a7dbde8c
--- /dev/null
+++ b/create-iota2-env.sh
@@ -0,0 +1,92 @@
+#!/usr/bin/env bash
+
+export python_version="3.9"
+export name="mmdc-iota2"
+if ! [ -z "$1" ]
+then
+    export name=$1
+fi
+
+source ~/set_proxy_iota2.sh
+if [ -z "$https_proxy" ]
+then
+    echo "Please set https_proxy environment variable before running this script"
+    exit 1
+fi
+
+export target=/work/scratch/$USER/virtualenv/$name
+
+if ! [ -z "$2" ]
+then
+    export target="$2/$name"
+fi
+
+echo "Installing $name in $target ..."
+
+if [ -d "$target" ]; then
+   echo "Cleaning previous conda env"
+   rm -rf $target
+fi
+
+# Create blank virtualenv
+module purge
+module load conda
+module load gcc
+conda activate
+conda create --yes --prefix $target python==${python_version} pip
+
+# Enter virtualenv
+conda deactivate
+conda activate $target
+
+# install mamba
+conda install mamba
+
+which python
+python --version
+
+# install libsvm
+mamba install -c conda-forge libsvm=325
+
+# Install iota2
+#mamba install iota2_develop=257da617 -c iota2
+mamba install iota2 -c iota2
+
+# clone the latest version of iota2
+rm -rf iota2_thirdparties/iota2
+git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2
+cd iota2_thirdparties/iota2
+git checkout issue#600/tile_exogenous_features
+git pull origin issue#600/tile_exogenous_features
+cd ../../
+pwd
+
+# create a backup
+cd  /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
+mv iota2 iota2_backup
+
+# create a symbolic link
+ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
+
+cd ~/src/MMDC/mmdc-singledate
+
+mamba install -c conda-forge pydantic
+
+# install missing dependencies
+pip install -r requirements-mmdc-iota2.txt
+
+# Install sensorsio
+rm -rf iota2_thirdparties/sensorsio
+git clone -b handle_local_srtm https://framagit.org/iota2-project/sensorsio.git iota2_thirdparties/sensorsio
+pip install iota2_thirdparties/sensorsio
+
+# Install torchutils
+rm -rf iota2_thirdparties/torchutils
+git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils
+pip install iota2_thirdparties/torchutils
+
+# Install the current project in edit mode
+pip install -e .[testing]
+
+# End
+conda deactivate
diff --git a/iota2_init.sh b/iota2_init.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4d1f386fd738ade3e88092fc7c9009e21d9af48a
--- /dev/null
+++ b/iota2_init.sh
@@ -0,0 +1,3 @@
+module purge
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..1374be032667395ef6155e5acae186735a34b252
--- /dev/null
+++ b/jobs/inference_single_date.pbs
@@ -0,0 +1,35 @@
+#!/bin/bash
+#PBS -N infer_sgdt
+#PBS -q qgpgpu
+#PBS -l select=1:ncpus=8:mem=92G:ngpus=1
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+module purge
+
+export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
+export MMDC_INIT=${SRCDIR}/mmdc_init.sh
+export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
+export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/
+export SCRATCH_DIR=/work/scratch/${USER}/
+
+cd ${SRCDIR}
+
+mkdir ${SCRATCH_DIR}/MMDC/inference/
+mkdir ${SCRATCH_DIR}/MMDC/inference/singledate
+
+source ${MMDC_INIT}
+
+python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
+    --s2_filename ${DATA_DIR}/T31TCJ/SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif \
+    --s1_asc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif \
+    --s1_desc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif \
+    --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \
+    --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \
+    --export_path ${SCRATCH_DIR}/MMDC/inference/singledate \
+    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \
+    --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \
+    --latent_space_size 16 \
+    --patch_size 128 \
+    --nb_lines 256 \
+    --sensors S2L2A S1FULL \
diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..0aaf57e27a722be7a9463c1067a8425e7c5f53ee
--- /dev/null
+++ b/jobs/inference_time_serie.pbs
@@ -0,0 +1,32 @@
+#!/bin/bash
+#PBS -N infer_sgdt
+#PBS -q qgpgpu
+#PBS -l select=1:ncpus=8:mem=92G:ngpus=1
+#PBS -l walltime=2:00:00
+
+# be sure no modules loaded
+module purge
+
+export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
+export MMDC_INIT=${SRCDIR}/mmdc_init.sh
+export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
+export DATA_DIR=/work/CESBIO/projects/MAESTRIA/training_dataset2/
+export SCRATCH_DIR=/work/scratch/${USER}/
+
+
+mkdir ${SCRATCH_DIR}/MMDC/inference/
+mkdir ${SCRATCH_DIR}/MMDC/inference/time_serie
+
+
+cd ${SRCDIR}
+source ${MMDC_INIT}
+
+python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \
+    --input_path ${DATA_DIR} \
+    --export_path ${SCRATCH_DIR}/MMDC/inference/time_serie \
+    --tile_list T33TUL \
+    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \
+    --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \
+    --days_gap 0 \
+    --patch_size 128 \
+    --nb_lines 256 \
diff --git a/jobs/iota2_aux_angles_test.pbs b/jobs/iota2_aux_angles_test.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..6eaacfe2b80aa45250bccc366d768aae11968d32
--- /dev/null
+++ b/jobs/iota2_aux_angles_test.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_angles.cfg
+    #-nb_parallel_tasks 1 \
+    #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_aux_full.pbs b/jobs/iota2_aux_full.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..2e438ed230ab02a57c2423491c4f0b270b4215b7
--- /dev/null
+++ b/jobs/iota2_aux_full.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_full.cfg
+    #-nb_parallel_tasks 1 \
+    #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_aux_test.pbs b/jobs/iota2_aux_test.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..3b3eed36098a0dadbac5933a6100d5321cdd6b7c
--- /dev/null
+++ b/jobs/iota2_aux_test.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_grid.cfg
+    #-nb_parallel_tasks 1 \
+    #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_clasif_mmdc_full.pbs b/jobs/iota2_clasif_mmdc_full.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..e3ec1e9731521f71acb8957c61d5e4f6c3840d5b
--- /dev/null
+++ b/jobs/iota2_clasif_mmdc_full.pbs
@@ -0,0 +1,120 @@
+#!/bin/bash
+#PBS -N iota2-class
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=10:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+
+ulimit -u 5000
+
+# export ROOT_PATH="$(dirname $( cd "$(dirname "${BASH_SOURCE[0]}")"; pwd -P ))"
+# export OTB_APPLICATION_PATH=$OTB_APPLICATION_PATH:/work/scratch/$USER/virtualenv/mmdc-iota2/lib/otb/applications//
+# export PATH=$PATH:$ROOT_PATH/bin:$OTB_APPLICATION_PATH
+# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$ROOT_PATH/lib:$OTB_APPLICATION_PATH #
+#
+# load modules
+module load conda/4.12.0
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+# Apply Iota2
+Iota2.py \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_mmdc_full.cfg \
+    -scheduler_type PBS \
+    -config_ressources /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_ressources.cfg
+    # -scheduler_type debug \
+    # -only_summary
+    # -nb_parallel_tasks 1 \
+
+
+# New
+# Functions availables
+#['__class__',
+#'__delattr__',
+#'__dict__',
+#'__dir__',
+#'__doc__',
+#'__eq__',
+#'__format__',
+#'__ge__',
+#'__getattribute__',
+#'__gt__',
+#'__hash__',
+#'__init__',
+#'__init_subclass__',
+#'__le__',
+#'__lt__',
+#'__module__',
+#'__ne__',
+#'__new__',
+#'__reduce__',
+#'__reduce_ex__',
+#'__repr__',
+#'__setattr__',
+#'__sizeof__',
+#'__str__',
+#'__subclasshook__',
+#'__weakref__',
+#'all_dates',
+#'allow_nans',
+#'binary_masks',
+#'concat_mode',
+#'dim_ts',
+#'enabled_gap',
+#'exogeneous_data_array',
+#'exogeneous_data_name',
+#'external_functions',
+#'fill_missing_dates',
+#'func_param',
+#'get_Sentinel1_ASC_vh_binary_masks',
+#'get_Sentinel1_ASC_vv_binary_masks',
+#'get_Sentinel1_DES_vh_binary_masks',
+#'get_Sentinel1_DES_vv_binary_masks',
+#'get_Sentinel2_binary_masks',
+#'get_filled_masks',
+#'get_filled_stack',
+#'get_interpolated_Sentinel1_ASC_vh',
+#'get_interpolated_Sentinel1_ASC_vv',
+#'get_interpolated_Sentinel1_DES_vh',
+#'get_interpolated_Sentinel1_DES_vv',
+#'get_interpolated_Sentinel2_B11',
+#'get_interpolated_Sentinel2_B12',
+#'get_interpolated_Sentinel2_B2',
+#'get_interpolated_Sentinel2_B3',
+#'get_interpolated_Sentinel2_B4',
+#'get_interpolated_Sentinel2_B5',
+#'get_interpolated_Sentinel2_B6',
+#'get_interpolated_Sentinel2_B7',
+#'get_interpolated_Sentinel2_B8',
+#'get_interpolated_Sentinel2_B8A',
+#'get_interpolated_Sentinel2_Brightness',
+#'get_interpolated_Sentinel2_NDVI',
+#'get_interpolated_Sentinel2_NDWI',
+#'get_interpolated_dates',
+#'get_interpolated_userFeatures_aspect',
+#'get_raw_Sentinel1_ASC_vh',
+#'get_raw_Sentinel1_ASC_vv',
+#'get_raw_Sentinel1_DES_vh', 'get_raw_Sentinel1_DES_vv',
+#'get_raw_Sentinel2_B11',
+#'get_raw_Sentinel2_B12',
+#'get_raw_Sentinel2_B2',
+#'get_raw_Sentinel2_B3',
+#'get_raw_Sentinel2_B4',
+#'get_raw_Sentinel2_B5',
+#'get_raw_Sentinel2_B6',
+#'get_raw_Sentinel2_B7',
+#'get_raw_Sentinel2_B8',
+#'get_raw_Sentinel2_B8A',
+#'get_raw_dates',
+#'get_raw_userFeatures_aspect',
+#'interpolated_data',
+#'interpolated_dates',
+#'missing_masks_values',
+#'missing_refl_values',
+#'out_data',
+#'process',
+#'raw_data',
+#'raw_dates',
+#'test_user_feature_with_fake_data']
diff --git a/jobs/iota2_external_feature_test.pbs b/jobs/iota2_external_feature_test.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..f8b6953c9eccb6981125153d07ae459d9d51a2e0
--- /dev/null
+++ b/jobs/iota2_external_feature_test.pbs
@@ -0,0 +1,45 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+# conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2-develop
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
+    -nb_parallel_tasks 1 \
+    # -only_summary
+
+
+
+# Functions availables
+# ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
+# '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__',
+# '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__',
+# '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__',
+# '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__',
+# '__weakref__', 'all_dates', 'allow_nans', 'binary_masks', 'concat_mode',
+# 'dim_ts', 'enabled_gap', 'exogeneous_data_array', 'exogeneous_data_name',
+# 'external_functions', 'fill_missing_dates', 'func_param',
+# 'get_Sentinel2_binary_masks', 'get_filled_masks', 'get_filled_stack',
+# 'get_interpolated_Sentinel2_B11', 'get_interpolated_Sentinel2_B12',
+# 'get_interpolated_Sentinel2_B2', 'get_interpolated_Sentinel2_B3',
+# 'get_interpolated_Sentinel2_B4', 'get_interpolated_Sentinel2_B5',
+# 'get_interpolated_Sentinel2_B6', 'get_interpolated_Sentinel2_B7',
+# 'get_interpolated_Sentinel2_B8', 'get_interpolated_Sentinel2_B8A',
+# 'get_interpolated_Sentinel2_Brightness', 'get_interpolated_Sentinel2_NDVI',
+# 'get_interpolated_Sentinel2_NDWI', 'get_interpolated_dates', 'get_raw_Sentinel2_B11',
+# 'get_raw_Sentinel2_B12', 'get_raw_Sentinel2_B2', 'get_raw_Sentinel2_B3',
+# 'get_raw_Sentinel2_B4', 'get_raw_Sentinel2_B5', 'get_raw_Sentinel2_B6',
+# 'get_raw_Sentinel2_B7', 'get_raw_Sentinel2_B8', 'get_raw_Sentinel2_B8A',
+# 'get_raw_dates', 'interpolated_data', 'interpolated_dates', 'missing_masks_values',
+# 'missing_refl_values', 'out_data', 'process', 'raw_data', 'raw_dates',
+# 'test_user_feature_with_fake_data']
diff --git a/jobs/iota2_tile_aux_features.pbs b/jobs/iota2_tile_aux_features.pbs
new file mode 100644
index 0000000000000000000000000000000000000000..2507bf3622735f451fd4f5ddca3bcf957378f1b6
--- /dev/null
+++ b/jobs/iota2_tile_aux_features.pbs
@@ -0,0 +1,18 @@
+#!/bin/bash
+#PBS -N iota2-tiler
+#PBS -l select=1:ncpus=8:mem=120G
+#PBS -l walltime=4:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_tiler.cfg \
+    -nb_parallel_tasks 1 \
+    -only_summary
diff --git a/requirements-mmdc-iota2.txt b/requirements-mmdc-iota2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..619064374381c3f90897b1041ece6b7f67ac77eb
--- /dev/null
+++ b/requirements-mmdc-iota2.txt
@@ -0,0 +1,3 @@
+config==0.5.1
+itk
+pydantic
diff --git a/requirements-mmdc-sgld.txt b/requirements-mmdc-sgld.txt
index 02defbac5b66055577c13f6a2bdbb69a8d055c06..c4b218d12589530ee4bc2a8bb5323b799f95346e 100644
--- a/requirements-mmdc-sgld.txt
+++ b/requirements-mmdc-sgld.txt
@@ -1,6 +1,7 @@
 affine
 auto-walrus
 black           # code formatting
+config
 findpeaks
 flake8          # code analysis
 geopandas
@@ -18,6 +19,8 @@ numpy
 pandas
 pre-commit      # hooks for applying linters on commit
 pudb            # debugger
+pydantic        # validate configurations in iota2
+pydantic-yaml   # read yml configs for construct the networks
 pytest          # tests
 python-dotenv   # loading env variables from .env file
 python-lsp-server[all]
diff --git a/setup.cfg b/setup.cfg
index 9c5469ae7b2d4c444e78501d06a55285fa81dc5e..3cce8fa609ddec17b800acd1c9e5742fad738506 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -68,6 +68,14 @@ testing =
     pytest-cov
 
 [options.entry_points]
+console_scripts =
+     mmdc_generate_dataset = bin.generate_mmdc_dataset:main
+     mmdc_visualize = bin.visualize_mmdc_ds:main
+     mmdc_split_dataset = bin.split_mmdc_dataset:main
+     mmdc_inference_singledate = bin.inference_mmdc_singledate:main
+     mmdc_inference_time_serie = bin.inference_mmdc_timeserie:main
+
+# mmdc_inference = bin.infer_mmdc_ds:main
 # Add here console scripts like:
 # console_scripts =
 #     script_name = mmdc-singledate.module:function
diff --git a/src/bin/generate_mmdc_dataset.py b/src/bin/generate_mmdc_dataset.py
index d801c56694537e2f86e830bfb7a10ba024b524c8..ada7c11b64e258f1b8f3a9cacd0c0b683b0eb1d6 100644
--- a/src/bin/generate_mmdc_dataset.py
+++ b/src/bin/generate_mmdc_dataset.py
@@ -81,7 +81,7 @@ def get_parser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-if __name__ == "__main__":
+def main():
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -124,3 +124,7 @@ if __name__ == "__main__":
             args.threshold,
         ),
     )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
new file mode 100644
index 0000000000000000000000000000000000000000..5df740b3fb860dcfbd91be4e7d79581306f888ff
--- /dev/null
+++ b/src/bin/inference_mmdc_singledate.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference code for MMDC
+"""
+
+
+# imports
+import argparse
+import logging
+import os
+
+import torch
+
+from mmdc_singledate.inference.components.inference_components import (
+    get_mmdc_full_model,
+    predict_mmdc_model,
+)
+from mmdc_singledate.inference.components.inference_utils import GeoTiffDataset
+from mmdc_singledate.inference.mmdc_tile_inference import (
+    MMDCProcess,
+    predict_single_date_tile,
+)
+from mmdc_singledate.inference.utils import get_scales
+
+
+def get_parser() -> argparse.ArgumentParser:
+    """
+    Generate argument parser for CLI
+    """
+    arg_parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        description="Compute the inference a selected number of latent spaces "
+        " and exported as GTiff for a given tile",
+    )
+
+    arg_parser.add_argument(
+        "--loglevel",
+        default="INFO",
+        choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
+        help="Logger level (default: INFO. Should be one of "
+        "(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+    )
+
+    arg_parser.add_argument(
+        "--s2_filename", type=str, help="s2 input file", required=False
+    )
+
+    arg_parser.add_argument(
+        "--s1_asc_filename", type=str, help="s1 asc input file", required=False
+    )
+
+    arg_parser.add_argument(
+        "--s1_desc_filename", type=str, help="s1 desc input file", required=False
+    )
+
+    arg_parser.add_argument(
+        "--srtm_filename", type=str, help="srtm input file", required=True
+    )
+
+    arg_parser.add_argument(
+        "--worldclim_filename", type=str, help="worldclim input file", required=True
+    )
+
+    arg_parser.add_argument(
+        "--export_path", type=str, help="export path ", required=True
+    )
+
+    arg_parser.add_argument(
+        "--model_checkpoint_path",
+        type=str,
+        help="model checkpoint path ",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--model_config_path",
+        type=str,
+        help="model configuration path ",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--latent_space_size",
+        type=int,
+        help="Number of channels in the output file",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--nb_lines",
+        type=int,
+        help="Number of Line to read in each window reading",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--patch_size",
+        type=int,
+        help="Number of Line to read in each window reading",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--sensors",
+        dest="sensors",
+        nargs="+",
+        help=(
+            "List of sensors S2L2A | S1FULL | S1ASC | S1DESC"
+            " (is sensitive to the order, S2L2A always first)"
+        ),
+        required=True,
+    )
+
+    return arg_parser
+
+
+def main():
+    """
+    Entry Point
+    """
+    # Parser arguments
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # Configure logging
+    numeric_level = getattr(logging, args.loglevel.upper(), None)
+    if not isinstance(numeric_level, int):
+        raise ValueError("Invalid log level: %s" % args.loglevel)
+
+    logging.basicConfig(
+        level=numeric_level,
+        datefmt="%y-%m-%d %H:%M:%S",
+        format="%(asctime)s :: %(levelname)s :: %(message)s",
+    )
+
+    # device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # model
+    mmdc_full_model = get_mmdc_full_model(
+        checkpoint=args.model_checkpoint_path,
+        mmdc_full_config_path=args.model_config_path,
+        inference_tile=True,
+    )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    mmdc_full_model.to(device)
+    # process class
+    mmdc_process = MMDCProcess(
+        count=args.latent_space_size,
+        nb_lines=args.nb_lines,
+        patch_size=args.patch_size,
+        process=predict_mmdc_model,
+        model=mmdc_full_model,
+    )
+
+    # create export filename
+    export_path = f"{args.export_path}/latent_singledate_{'_'.join(args.sensors)}.tif"
+
+    # GeoTiff
+    geotiff_dataclass = GeoTiffDataset(
+        s2_filename=args.s2_filename,
+        s1_asc_filename=args.s1_asc_filename,
+        s1_desc_filename=args.s1_desc_filename,
+        srtm_filename=args.srtm_filename,
+        wc_filename=args.worldclim_filename,
+    )
+
+    predict_single_date_tile(
+        input_data=geotiff_dataclass,
+        export_path=export_path,
+        sensors=args.sensors,
+        process=mmdc_process,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db518f2b3d1d7a4702e7522469f591160e3cb06
--- /dev/null
+++ b/src/bin/inference_mmdc_timeserie.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference code for MMDC
+"""
+
+
+# imports
+import argparse
+import logging
+import os
+from pathlib import Path
+
+from mmdc_singledate.inference.mmdc_tile_inference import (
+    inference_dataframe,
+    mmdc_tile_inference,
+)
+
+
+def get_parser() -> argparse.ArgumentParser:
+    """
+    Generate argument parser for CLI
+    """
+    arg_parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        description="Compute the inference a selected number of latent spaces "
+        " and exported as GTiff for a given tile",
+    )
+
+    arg_parser.add_argument(
+        "--loglevel",
+        default="INFO",
+        choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
+        help="Logger level (default: INFO. Should be one of "
+        "(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+    )
+
+    arg_parser.add_argument(
+        "--input_path", type=str, help="data input folder", required=True
+    )
+
+    arg_parser.add_argument(
+        "--export_path", type=str, help="export path ", required=True
+    )
+
+    arg_parser.add_argument(
+        "--days_gap",
+        type=int,
+        help="difference between S1 and S2 acquisitions in days",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--model_checkpoint_path",
+        type=str,
+        help="model checkpoint path ",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--model_config_path",
+        type=str,
+        help="model configuration path ",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--tile_list",
+        nargs="+",
+        type=str,
+        help="List of tiles to produce",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--patch_size",
+        type=int,
+        help="Sizes of the patches to make the inference",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--nb_lines",
+        type=int,
+        help="Number of Line to read in each window reading",
+        required=False,
+    )
+
+    return arg_parser
+
+
+def main():
+    """
+    Entry Point
+    """
+    # Parser arguments
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # Configure logging
+    numeric_level = getattr(logging, args.loglevel.upper(), None)
+    if not isinstance(numeric_level, int):
+        raise ValueError("Invalid log level: %s" % args.loglevel)
+
+    logging.basicConfig(
+        level=numeric_level,
+        datefmt="%y-%m-%d %H:%M:%S",
+        format="%(asctime)s :: %(levelname)s :: %(message)s",
+    )
+
+    if not os.path.isdir(args.export_path):
+        logging.info(
+            f"The directory is not present. Creating a new one.. " f"{args.export_path}"
+        )
+        Path(args.export_path).mkdir(parents=True, exist_ok=True)
+
+    # calculate time serie
+    tile_df = inference_dataframe(
+        samples_dir=args.input_path,
+        input_tile_list=args.tile_list,
+        days_gap=args.days_gap,
+    )
+    logging.info(f"nb of dates to infer : {tile_df.shape[0]}")
+    # entry point
+    mmdc_tile_inference(
+        inference_dataframe=tile_df,
+        export_path=Path(args.export_path),
+        model_checkpoint_path=Path(args.model_checkpoint_path),
+        model_config_path=Path(args.model_config_path),
+        patch_size=args.patch_size,
+        nb_lines=args.nb_lines,
+    )
+
+    logging.info("All the data has been exported successfully")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/bin/split_mmdc_dataset.py b/src/bin/split_mmdc_dataset.py
index de429bf5c5a539a2c76e8d0af1bc1a42401b7c10..b205a24971293c051f942134bfa3c3ede81d3f4c 100644
--- a/src/bin/split_mmdc_dataset.py
+++ b/src/bin/split_mmdc_dataset.py
@@ -141,16 +141,10 @@ def split_tile_rois(
             roi_intersect.writelines(sorted([f"{roi}\n" for roi in roi_list]))
 
 
-def main(
-    input_dir: Path, input_file: Path, test_percentage: int, random_state: int
-) -> None:
+def main():
     """
     Split the tiles/rois between train/val/test
     """
-    split_tile_rois(input_dir, input_file, test_percentage, random_state)
-
-
-if __name__ == "__main__":
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -171,10 +165,14 @@ if __name__ == "__main__":
     logging.info("test percentage selected : %s", args.test_percentage)
     logging.info("random state value :  %s", args.random_state)
 
-    # Go to main
-    main(
+    # Go to entry point
+    split_tile_rois(
         Path(args.tensors_dir),
         Path(args.roi_intersections_file),
         args.test_percentage,
         args.random_state,
     )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/bin/visualize_mmdc_ds.py b/src/bin/visualize_mmdc_ds.py
index 01da86c662bac0e91ad3cf8c654e0ec3d39d4e38..a69bc924abb3550c0e6e5e263a5b14e5004f72fa 100644
--- a/src/bin/visualize_mmdc_ds.py
+++ b/src/bin/visualize_mmdc_ds.py
@@ -71,34 +71,10 @@ def get_parser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-def main(
-    input_directory: Path,
-    input_tiles: Path,
-    patch_index: list[int] | None,
-    nb_patches: int | None,
-    export_path: Path,
-) -> None:
+def main():
     """
     Entry point for the visualization
     """
-    try:
-        export_visualization_graphs(
-            input_directory=input_directory,
-            input_tiles=input_tiles,
-            patch_index=patch_index,
-            nb_patches=nb_patches,
-            export_path=export_path,
-        )
-
-    except FileNotFoundError:
-        if not input_directory.exists():
-            print(f"Folder {input_directory} does not exist!!")
-
-        if not export_path.exists():
-            print(f"Folder {export_path} does not exist!!")
-
-
-if __name__ == "__main__":
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -126,13 +102,26 @@ if __name__ == "__main__":
     logging.info(" index patches : %s", args.patch_index)
     logging.info(" output directory : %s", args.export_path)
 
-    # Go to main
-    main(
-        input_directory=Path(args.input_path),
-        input_tiles=Path(args.input_tiles),
-        patch_index=args.patch_index,
-        nb_patches=args.nb_patches,
-        export_path=Path(args.export_path),
-    )
+    # Go to entry point
+
+    try:
+        export_visualization_graphs(
+            input_directory=Path(args.input_path),
+            input_tiles=Path(args.input_tiles),
+            patch_index=args.patch_index,
+            nb_patches=args.nb_patches,
+            export_path=Path(args.export_path),
+        )
+
+    except FileNotFoundError:
+        if not Path(args.input_path).exists():
+            print(f"Folder {args.input_path} does not exist!!")
+
+        if not Path(args.export_path).exists():
+            print(f"Folder {args.export_path} does not exist!!")
 
     logging.info("Visualization export finished")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/mmdc_singledate/inference/__init__.py b/src/mmdc_singledate/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe
--- /dev/null
+++ b/src/mmdc_singledate/inference/__init__.py
@@ -0,0 +1 @@
+#!/usr/bin/env python3
diff --git a/src/mmdc_singledate/inference/components/__init__.py b/src/mmdc_singledate/inference/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/__init__.py
@@ -0,0 +1 @@
+#!/usr/bin/env python3
diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
new file mode 100644
index 0000000000000000000000000000000000000000..42cdc604a0cbd1fc7d9bd47b8325bb49ef290189
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Functions components for get the
+latent spaces for a given tile
+"""
+# imports
+import logging
+from pathlib import Path
+from typing import Literal
+
+import torch
+from torch import nn
+
+from mmdc_singledate.inference.utils import get_mmdc_full_config
+from mmdc_singledate.models.torch.full import MMDCFullModule
+from mmdc_singledate.models.types import S1S2VAELatentSpace
+
+# define sensors variable for typing
+SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+@torch.no_grad()
+def get_mmdc_full_model(
+    checkpoint: str,
+    mmdc_full_config_path: str,
+    inference_tile: bool,
+) -> nn.Module:
+    """
+    Instantiate the network
+
+    """
+    if not Path(checkpoint).exists():
+        logger.info("No valid checkpoint path was given.")
+        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
+    else:
+        # read the config file
+        mmdc_full_config = get_mmdc_full_config(
+            mmdc_full_config_path, inference_tile=inference_tile
+        )
+        # Init the model passing the config
+        mmdc_full_model = MMDCFullModule(
+            mmdc_full_config,
+        )
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.debug(f"device detected : {device}")
+        logger.debug(f"nb_gpu's detected : {torch.cuda.device_count()}")
+        # load state_dict
+        lightning_checkpoint = torch.load(checkpoint, map_location=device)
+
+        # delete "model" from the loaded checkpoint
+        lightning_checkpoint_cleaning = {
+            key.split("model.")[-1]: item
+            for key, item in lightning_checkpoint["state_dict"].items()
+            if key.startswith("model.")
+        }
+
+        # load the state dict
+        mmdc_full_model.load_state_dict(lightning_checkpoint_cleaning)
+
+        # disble randomness, dropout, etc...
+        mmdc_full_model.eval()
+
+    return mmdc_full_model
+
+
+@torch.no_grad()
+def predict_mmdc_model(
+    model: nn.Module,
+    sensors: SENSORS,
+    s2_ref,
+    s2_mask,
+    s2_angles,
+    s1_back,
+    s1_vm,
+    s1_asc_angles,
+    s1_desc_angles,
+    worldclim,
+    srtm,
+) -> torch.Tensor:
+    """
+    This function apply the predict method to a
+    batch of mmdc data and get the latent space
+    by sensor
+
+        :args: model : mmdc_full model intance
+        :args: sensors : Sensor acquisition
+        :args: s2_ref :
+        :args: s2_angles :
+        :args: s1_back :
+        :args: s1_asc_angles :
+        :args: s1_desc_angles :
+        :args: worldclim :
+        :args: srtm :
+
+        :return: VAELatentSpace
+    """
+    # TODO check if the sensors are in the current supported list
+    # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    logger.debug(device)
+
+    prediction = model.predict(
+        s2_ref.to(device),
+        s2_mask.to(device),
+        s2_angles.to(device),
+        s1_back.to(device),
+        s1_vm.to(device),
+        s1_asc_angles.to(device),
+        s1_desc_angles.to(device),
+        worldclim.to(device),
+        srtm.to(device),
+    )
+
+    logger.debug("prediction.shape : %s", prediction[0].latent.sen2.mean.shape)
+    logger.debug("prediction.shape : %s", prediction[0].latent.sen2.logvar.shape)
+    logger.debug("prediction.shape : %s", prediction[0].latent.sen1.mean.shape)
+    logger.debug("prediction.shape : %s", prediction[0].latent.sen1.logvar.shape)
+
+    # init the output latent spaces as
+    # empty dataclass
+    latent_space = S1S2VAELatentSpace(sen1=None, sen2=None)
+
+    # fullfit with a matchcase
+    # match
+    # match sensors:
+    # Cases
+    if sensors == ["S2L2A", "S1FULL"]:
+        logger.debug("S2 captured & S1 full captured")
+
+        latent_space.sen2 = prediction[0].latent.sen2
+        latent_space.sen1 = prediction[0].latent.sen1
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+                latent_space.sen1.mean,
+                latent_space.sen1.logvar,
+            ),
+            1,
+        )
+
+        logger.debug(
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
+        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
+
+    elif sensors == ["S2L2A", "S1ASC"]:
+        # logger.debug("S2 captured & S1 asc captured")
+
+        latent_space.sen2 = prediction[0].latent.sen2
+        latent_space.sen1 = prediction[0].latent.sen1
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+                latent_space.sen1.mean,
+                latent_space.sen1.logvar,
+            ),
+            1,
+        )
+        logger.debug(
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
+        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
+
+    elif sensors == ["S2L2A", "S1DESC"]:
+        logger.debug("S2 captured & S1 desc captured")
+
+        latent_space.sen2 = prediction[0].latent.sen2
+        latent_space.sen1 = prediction[0].latent.sen1
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+                latent_space.sen1.mean,
+                latent_space.sen1.logvar,
+            ),
+            1,
+        )
+        logger.debug(
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
+        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
+
+    elif sensors == ["S2L2A"]:
+        # logger.debug("Only S2 captured")
+
+        latent_space.sen2 = prediction[0].latent.sen2
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+            ),
+            1,
+        )
+        logger.debug(
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
+        )
+
+    elif sensors == ["S1FULL"]:
+        # logger.debug("Only S1 full captured")
+
+        latent_space.sen1 = prediction[0].latent.sen1
+
+        latent_space_stack = torch.cat(
+            (latent_space.sen1.mean, latent_space.sen1.logvar),
+            1,
+        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
+
+    elif sensors == ["S1ASC"]:
+        # logger.debug("Only S1ASC captured")
+
+        latent_space.sen1 = prediction[0].latent.sen1
+
+        latent_space_stack = torch.cat(
+            (latent_space.sen1.mean, latent_space.sen1.logvar),
+            1,
+        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
+
+    elif sensors == ["S1DESC"]:
+        # logger.debug("Only S1DESC captured")
+
+        latent_space.sen1 = prediction[0].latent.sen1
+
+        latent_space_stack = torch.cat(
+            (latent_space.sen1.mean, latent_space.sen1.logvar),
+            1,
+        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
+    else:
+        raise ValueError(f"Invalid sensors given: {sensors}")
+
+    return latent_space_stack  # latent_space
+
+
+# Switch to match case in python 3.10 o superior
+# fullfit with a matchcase
+# for captor in sensors:
+#     # match
+#     match captor:
+#         # Cases
+#         case ["S2L2A", "S1FULL"]:
+#             print("S1 full captured & S2 captured")
+
+#         case ["S2L2A",  "S1ASC"]:
+#             print("S1 asc captured & S2 captured")
+
+#         case ["S2L2A", "S1DESC"]:
+#             print("S1 desc captured & S2 captured")
+
+#         case ["S2L2A"]:
+#             print("Only S2 captured")
+
+#         case ["S1FULL"]:
+#             print("Only S1 full captured")
+
+#         case ["S1ASC"]:
+#             print("Only S1ASC captured")
+
+#         case ["S1DESC"]:
+#             print("Only S1DESC captured")
diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef85dbdb889e4515bbd634e0f855ce91f5e22121
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference utility functions
+
+"""
+from collections.abc import Callable, Generator
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import rasterio as rio
+import torch
+from rasterio.windows import Window as rio_window
+
+from mmdc_singledate.datamodules.components.datamodule_components import (
+    get_worldclim_filenames,
+)
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1,
+    get_s1_acquisition_angles,
+    join_even_odd_s2_angles,
+    srtm_height_aspect,
+)
+
+# dataclasses
+Chunk = tuple[int, int, int, int]
+
+# hold the data
+
+
+@dataclass
+class GeoTiffDataset:
+
+    """
+    holds the TIFF filenames
+    stored on disk and their metadata
+    """
+
+    s2_filename: str
+    s1_asc_filename: str
+    s1_desc_filename: str
+    srtm_filename: str
+    wc_filename: str
+    metadata: dict = field(init=False)
+    s2_availabitity: bool | None = field(default=None, init=True)
+    s1_asc_availability: bool | None = field(default=None, init=True)
+    s1_desc_availability: bool | None = field(default=None, init=True)
+
+    def __post_init__(self) -> None:
+        # check if the data exists
+        if not Path(self.srtm_filename).is_file():
+            raise Exception(f"{self.srtm_filename} do not exists!")
+        if not Path(self.wc_filename).is_file():
+            raise Exception(f"{self.wc_filename} do not exists!")
+        if Path(self.s1_asc_filename).is_file():
+            self.s1_asc_availability = True
+        else:
+            self.s1_asc_availability = False
+        if Path(self.s1_desc_filename).is_file():
+            self.s1_desc_availability = True
+        else:
+            self.s1_desc_availability = False
+        if Path(self.s2_filename).is_file():
+            self.s2_availabitity = True
+        else:
+            self.s2_availabitity = False
+        # check if the filenames
+        # provide are in the same
+        # footprint and spatial resolution
+        # rio.open(
+        #    self.s1_asc_filename
+        # ) as s1_asc, rio.open(self.s1_desc_filename) as s1_desc,
+        # rio.open(self.s2_filename) as s2,
+        with rio.open(self.srtm_filename) as srtm, rio.open(self.wc_filename) as wc:
+            # recover the metadata
+            self.metadata = srtm.meta
+
+            # create a python set to check
+            crss = {
+                srtm.meta["crs"],
+                wc.meta["crs"],
+                # s2.meta["crs"],
+                # s1_asc.meta["crs"],
+                # s1_desc.meta["crs"],
+            }
+
+            heights = {
+                srtm.meta["height"],
+                wc.meta["height"],
+                # s2.meta["height"],
+                # s1_asc.meta["height"],
+                # s1_desc.meta["height"],
+            }
+
+            widths = {
+                srtm.meta["width"],
+                wc.meta["width"],
+                # s2.meta["width"],
+                # s1_asc.meta["width"],
+                # s1_desc.meta["width"],
+            }
+
+            # check crs
+            if len(crss) > 1:
+                raise Exception("Data should be in the same CRS")
+
+            # check height
+            if len(heights) > 1:
+                raise Exception("Data should be have the same height")
+
+            # check width
+            if len(widths) > 1:
+                raise Exception("Data should be have the same width")
+
+
+@dataclass
+class S2Components:
+    """
+    dataclass that holds the S2-related data
+    """
+
+    s2_reflectances: torch.Tensor
+    s2_angles: torch.Tensor
+    s2_mask: torch.Tensor
+
+
+@dataclass
+class S1Components:
+    """
+    dataclass for hold the s1 related data
+    """
+
+    s1_backscatter: torch.Tensor
+    s1_valmask: torch.Tensor
+    s1_lia_angles: torch.Tensor
+
+
+@dataclass
+class SRTMComponents:
+    """
+    dataclass for hold srtm related data
+    """
+
+    srtm: torch.Tensor
+
+
+@dataclass
+class WorldClimComponents:
+    """
+    dataclass for hold worldclim related data
+    """
+
+    worldclim: torch.Tensor
+
+
+# chunks logic
+def generate_chunks(
+    width: int, height: int, nlines: int, test_area: tuple[int, int] | None = None
+) -> list[Chunk]:
+    """Given the width and height of an image and a number of lines per chunk,
+    generate the list of chunks for the image. The chunks span all the width. A
+    chunk is encoded as 4 values: (x0, y0, width, height)
+    :param width: image width
+    :param height: image height
+    :param nlines: number of lines per chunk
+    :param test_area: generate only the chunks between (ya, yb)
+    :returns: a list of chunks
+    """
+    chunks = []
+    for iter_height in range(height // nlines - 1):
+        y0 = iter_height * nlines
+        y1 = y0 + nlines
+        if y1 > height:
+            y1 = height
+        append_chunk = (test_area is None) or (y0 > test_area[0] and y1 < test_area[1])
+        if append_chunk:
+            chunks.append((0, y0, width, y1 - y0))
+    return chunks
+
+
+# read data
+def read_img_tile(
+    filename: str,
+    rois: list[rio_window],
+    availabitity: bool,
+    sensor_func: Callable[[torch.Tensor, Any], Any],
+) -> Generator[Any, None, None]:
+    """
+    read a patch of abstract sensor data,
+    construct the auxiliary data and yield the patch
+    of data
+    """
+    # check if the filename exists
+    # if exist proceed to yield the data
+    if availabitity:
+        with rio.open(filename) as raster:
+            for roi in rois:
+                # read the patch as tensor
+                tensor = torch.tensor(raster.read(window=roi), requires_grad=False)
+                # compute some transformation to the tensor and return dataclass
+                sensor_data = sensor_func(tensor, availabitity, filename)
+                yield sensor_data
+    # if not exists create a zeros tensor and yield
+    else:
+        # print("Creating fake data")
+        for roi in rois:
+            null_data = torch.ones(roi.width, roi.height)
+            sensor_data = sensor_func(null_data, availabitity, filename)
+            yield sensor_data
+
+
+def read_s2_img_tile(
+    s2_tensor: torch.Tensor,
+    availabitity: bool,
+    *args: Any,
+    **kwargs: Any,
+) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of sentinel 2 data
+    construct the masks and yield the patch
+    of data
+    """
+    if availabitity:
+        # extract masks
+        cloud_mask = s2_tensor[11, ...].to(torch.uint8)
+        cloud_mask[cloud_mask > 0] = 1
+        sat_mask = s2_tensor[10, ...].to(torch.uint8)
+        edge_mask = s2_tensor[12, ...].to(torch.uint8)
+        # create the validity mask
+        mask = torch.logical_or(
+            cloud_mask, torch.logical_or(edge_mask, sat_mask)
+        ).unsqueeze(0)
+        angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...])
+        image_s2 = s2_tensor[:10, ...]
+    else:
+        image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1])
+        angles_s2 = torch.ones(6, s2_tensor.shape[0], s2_tensor.shape[1])
+        mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1])
+        # print("passing zero data")
+
+    return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
+
+
+# TODO get out the s1_lia computation from this function
+def read_s1_img_tile(
+    s1_tensor: torch.Tensor,
+    availability: bool,
+    s1_filename: str,
+    *args: Any,
+    **kwargs: Any,
+) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of S1, construct the associated datasets
+    and yield the result
+
+    inputs : s1_tensor read from read_img_tile
+    """
+    # if s1_tensor.shape[0] == 2:
+    if availability:
+        # get the vv,vh, vv/vh bands
+        img_s1 = torch.cat(
+            (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
+        )
+        # get the local incidence angle (lia)
+        s1_lia = get_s1_acquisition_angles(s1_filename)
+        # compute validity mask
+        s1_valmask = torch.ones(1, s1_tensor.shape[1], s1_tensor.shape[2])
+        # compute edge mask
+        s1_backscatter = apply_log_to_s1(img_s1)
+    else:
+        s1_backscatter = torch.ones(3, s1_tensor.shape[1], s1_tensor.shape[0])
+        s1_valmask = torch.zeros(1, s1_tensor.shape[1], s1_tensor.shape[0])
+        s1_lia = torch.ones(3)
+
+    return S1Components(
+        s1_backscatter=s1_backscatter,
+        s1_valmask=s1_valmask,
+        s1_lia_angles=s1_lia,
+    )
+
+
+def read_srtm_img_tile(
+    srtm_tensor: torch.Tensor,
+    *args: Any,
+    **kwargs: Any,
+) -> SRTMComponents:  # [torch.Tensor]:
+    """
+    read srtm patch
+    """
+    # read the patch as tensor
+    img_srtm = srtm_height_aspect(srtm_tensor)
+    return SRTMComponents(img_srtm)
+
+
+def read_worldclim_img_tile(
+    wc_tensor: torch.Tensor,
+    *args: Any,
+    **kwargs: Any,
+) -> WorldClimComponents:  # [torch.tensor]:
+    """
+    read worldclim subset
+    """
+    return WorldClimComponents(wc_tensor)
+
+
+def expand_worldclim_filenames(
+    wc_filename: str,
+) -> list[str]:
+    """
+    given the first worldclim filename, expand the filenames
+    to the others files
+    """
+    name_wc = [
+        get_worldclim_filenames(wc_filename, str(idx), True) for idx in range(1, 13, 1)
+    ]
+
+    name_wc_bio = get_worldclim_filenames(wc_filename, None, False)
+    name_wc.append(name_wc_bio)
+
+    return name_wc
+
+
+def concat_worldclim_components(
+    wc_filename: str,
+    rois: list[rio_window],
+    availabitity,
+    *args: Any,
+    **kwargs: Any,
+) -> Generator[Any, None, None]:
+    """
+    Composite function to apply the generic read_img_tile function
+    to the different worldclim files
+    """
+    # get all filenames
+    wc_filenames = expand_worldclim_filenames(wc_filename)
+    # read the tensors
+    wc_tensor = [
+        next(
+            read_img_tile(wc_filename, rois, availabitity, read_worldclim_img_tile)
+        ).worldclim
+        for idx, wc_filename in enumerate(wc_filenames)
+    ]
+    yield WorldClimComponents(torch.cat(wc_tensor))
diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..742f2892e94d794519dfb309ff1c20f6357d8db7
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
@@ -0,0 +1,512 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference API with Iota-2
+Inspired by:
+https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
+"""
+# imports
+import logging
+import os
+
+import numpy as np
+import torch
+from dateutil import parser as date_parser
+from iota2.configuration_files import read_config_file as rcf
+from torchutils import patches
+
+# from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+# from mmdc_singledate.models.types import S1S2VAELatentSpace
+# from mmdc_singledate.datamodules.components.datamodule_utils import (
+#     apply_log_to_s1,
+#     srtm_height_aspect,
+# )
+
+# from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
+# from .components.inference_utils import (
+#     GeoTiffDataset,
+#     S1Components,
+#     S2Components,
+#     SRTMComponents,
+#     WorldClimComponents,
+#     get_mmdc_full_config,
+#     read_s1_img_tile,
+#     read_s2_img_tile,
+#     read_srtm_img_tile,
+#     read_worldclim_img_tile,
+# )
+# from .utils import get_scales
+
+# from tqdm import tqdm
+# from pathlib import Path
+
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def read_s2_img_iota2(
+        self,
+):
+# ) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of sentinel 2 data
+    construct the masks and yield the patch
+    of data
+    """
+    # DONE Get the data in the same order as
+    # sentinel2.Sentinel2.GROUP_10M +
+    # sensorsio.sentinel2.GROUP_20M
+    # This data correspond to (Chunk_Size, Img_Size, nb_dates)
+    list_bands_s2 = [
+        self.get_interpolated_Sentinel2_B2(),
+        self.get_interpolated_Sentinel2_B3(),
+        self.get_interpolated_Sentinel2_B4(),
+        self.get_interpolated_Sentinel2_B8(),
+        self.get_interpolated_Sentinel2_B5(),
+        self.get_interpolated_Sentinel2_B6(),
+        self.get_interpolated_Sentinel2_B7(),
+        self.get_interpolated_Sentinel2_B8A(),
+        self.get_interpolated_Sentinel2_B11(),
+        self.get_interpolated_Sentinel2_B12(),
+    ]
+
+    # DONE Mask construction for S2
+    list_s2_mask = [
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+    ]
+
+    # DONE Add S2 Angles
+    list_bands_s2_angles = [
+        self.get_interpolated_userFeatures_ZENITH(),
+        self.get_interpolated_userFeatures_AZIMUTH(),
+        self.get_interpolated_userFeatures_EVEN_ZENITH(),
+        self.get_interpolated_userFeatures_ODD_ZENITH(),
+        self.get_interpolated_userFeatures_EVEN_AZIMUTH(),
+        self.get_interpolated_userFeatures_ODD_AZIMUTH(),
+    ]
+
+    # extend the list
+    list_bands_s2.extend(list_s2_mask)
+    list_bands_s2.extend(list_bands_s2_angles)
+
+    bands_s2 = torch.Tensor(
+        list_bands_s2, dtype=torch.torch.float32, device=device
+    ).permute(-1, 0, 1, 2)
+
+    return read_s2_img_tile(bands_s2)
+
+
+def read_s1_img_iota2(
+    self,
+):
+# ) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of s1 construct the datasets associated
+    and yield the result
+    """
+    # TODO Manage S1 ASC and S1 DESC ?
+    list_bands_s1 = [
+        self.get_interpolated_Sentinel1_ASC_vh(),
+        self.get_interpolated_Sentinel1_ASC_vv(),
+        self.get_interpolated_Sentinel1_ASC_vh()
+        / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
+        self.get_interpolated_Sentinel1_DES_vh(),
+        self.get_interpolated_Sentinel1_DES_vv(),
+        self.get_interpolated_Sentinel1_DES_vh()
+        / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
+    ]
+
+    bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+    bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+    # get the vv,vh, vv/vh bands
+    # img_s1 = torch.cat(
+    # (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
+    # )
+
+    list_bands_s1_lia = [
+        self.get_interpolated_,
+        self.get_interpolated_,
+        self.get_interpolated_,
+    ]
+
+    # s1_lia = get_s1_acquisition_angles(s1_filename)
+    # compute validity mask
+    s1_valmask = torch.ones(img_s1.shape)
+    # compute edge mask
+    s1_backscatter = apply_log_to_s1(img_s1)
+
+    return read_s1_img_tile()
+
+
+def read_srtm_img_iota2(
+    self,
+):
+# ) -> SRTMComponents:
+    """
+    read srtm patch from Iota2 API
+    """
+    # read the patch as numpy array
+
+    list_bands_srtm = [
+        self.get_interpolated_userFeatures_elevation(),
+        self.get_interpolated_userFeatures_slope(),
+        self.get_interpolated_userFeatures_aspect(),
+    ]
+    # convert to tensor and permute the bands
+    srtm_tensor = torch.Tensor(list_bands_srtm).permute(-1, 0, 1, 2)
+
+    return read_srtm_img_tile(
+        srtm_height_aspect(
+            srtm_tensor
+        )
+    )
+
+
+
+def read_worldclim_img_iota2(
+    self,
+):
+# ) -> WorldClimComponents:  # [torch.tensor]:
+    """
+    read worldclim patch from Iota2 API
+    """
+    wc_tensor = torch.Tensor(
+        self.get_interpolated_userFeatures_worldclim()
+    ).permute(-1, 0, 1, 2)
+    return read_worldclim_img_tile(wc_tensor)
+
+
+def predictc_iota2(
+    self,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+
+    # retrieve the metadata
+    # meta = input_data.metadata.copy()
+    # get nb outputs
+    # meta.update({"count": process.count})
+    # calculate rois
+    # chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+    # get the windows from the chunks
+    # rois = [rio.windows.Window(*chunk) for chunk in chunks]
+    # logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+    # init the dataset
+    logger.debug("Reading S2 data")
+    s2_data = read_s2_img_iota2()
+    logger.debug("Reading S1 ASC data")
+    s1_asc_data = read_s1_img_iota2()
+    logger.debug("Reading S1 DESC data")
+    s1_desc_data = read_s1_img_iota2()
+    logger.debug("Reading SRTM data")
+    srtm_data = read_srtm_img_iota2()
+    logger.debug("Reading WorldClim data")
+    worldclim_data = read_worldclim_img_iota2()
+
+    # Concat S1 Data
+    s1_backscatter = torch.cat((s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0)
+    # Multiply the masks
+    s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask)
+    # keep the sizes for recover the original size
+    s2_mask_patch_size = patches.patchify(s2.s2_mask, process.patch_size).shape
+    s2_s2_mask_shape = s2.s2_mask.shape  # [1, 1024, 10980]
+
+    # reshape the data
+    s2_refl_patch = patches.flatten2d(
+        patches.patchify(s2_data.s2_reflectances, process.patch_size)
+    )
+    s2_ang_patch = patches.flatten2d(
+        patches.patchify(s2_data.s2_angles, process.patch_size)
+    )
+    s2_mask_patch = patches.flatten2d(
+        patches.patchify(s2_data.s2_mask, process.patch_size)
+    )
+    s1_patch = patches.flatten2d(patches.patchify(s1_backscatter, process.patch_size))
+    s1_valmask_patch = patches.flatten2d(
+        patches.patchify(s1_valmask, process.patch_size)
+    )
+    srtm_patch = patches.flatten2d(patches.patchify(srtm_data.srtm, process.patch_size))
+    wc_patch = patches.flatten2d(
+        patches.patchify(worldclim_data.worldclim, process.patch_size)
+    )
+    # Expand the angles to fit the sizes
+    s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat(
+        s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+    )
+    s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat(
+        s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+    )
+
+    logger.debug("s2_patches", s2_refl_patch.shape)
+    logger.debug("s2_angles_patches", s2_ang_patch.shape)
+    logger.debug("s2_mask_patch", s2_mask_patch.shape)
+    logger.debug("s1_patch", s1_patch.shape)
+    logger.debug("s1_valmask_patch", s1_valmask_patch.shape)
+    logger.debug(
+        "s1_lia_asc patches",
+        s1_asc.s1_lia_angles.shape,
+        s1asc_lia_patch.shape,
+    )
+    logger.debug(
+        "s1_lia_desc patches",
+        s1_desc.s1_lia_angles.shape,
+        s1desc_lia_patch.shape,
+    )
+    logger.debug("srtm patches", srtm_patch.shape)
+    logger.debug("wc patches", wc_patch.shape)
+
+    # apply predict function
+    # should return a s1s2vaelatentspace
+    # object
+    pred_vaelatentspace = process.process(
+        process.model,
+        sensors,
+        s2_refl_patch,
+        s2_mask_patch,
+        s2_ang_patch,
+        s1_patch,
+        s1_valmask_patch,
+        s1asc_lia_patch,
+        s1desc_lia_patch,
+        wc_patch,
+        srtm_patch,
+    )
+    logger.debug(type(pred_vaelatentspace))
+    logger.debug("latent space sizes : ", pred_vaelatentspace.shape)
+
+    # unpatchify
+    pred_vaelatentspace_unpatchify = patches.unpatchify(
+        patches.unflatten2d(
+            pred_vaelatentspace,
+            s2_mask_patch_size[0],
+            s2_mask_patch_size[1],
+        )
+    )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
+
+    logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
+    logger.debug("process.count : ", process.count)
+    # check the pred and the ask are the same
+    assert process.count == pred_vaelatentspace_unpatchify.shape[0]
+
+    predition_array = np.array(pred_vaelatentspace_unpatchify[0, ...])
+
+    logger.debug(("Export tile", f"filename :{export_path}"))
+
+    return predition_array, labels
+
+
+def acquisition_dataframe(
+    self,
+) -> pd.DataFrame:
+    """
+    Construct a DataFrame with the acquisition dates
+    for the predictions
+    """
+    # use Iota2 API
+    # acquisition_dates = self.get_raw_dates()
+
+    # Get the acquisition dates
+    acquisition_dates = {
+        "Sentinel1_DES_vv": ["20151231"],
+        "Sentinel1_DES_vh": ["20151231"],
+        "Sentinel1_ASC_vv": ["20170518", "20170519"],
+        "Sentinel1_ASC_vh": ["20170518", "20170519"],
+        "Sentinel2": ["20151130", "20151203"],
+    }
+
+    # get the acquisition dates for each sensor-mode
+    sentinel2_acquisition_dates = pd.DataFrame(
+        {"Sentinel2": acquisition_dates["Sentinel2"]}, dtype=str
+    )
+    sentinel1_asc_acquisition_dates = pd.DataFrame(
+        {
+            "Sentinel1_ASC_vv": acquisition_dates["Sentinel1_ASC_vv"],
+            "Sentinel1_ASC_vh": acquisition_dates["Sentinel1_ASC_vh"],
+        },
+        dtype=str,
+    )
+    sentinel1_des_acquisition_dates = pd.DataFrame(
+        {
+            "Sentinel1_DES_vv": acquisition_dates["Sentinel1_DES_vv"],
+            "Sentinel1_DES_vh": acquisition_dates["Sentinel1_DES_vh"],
+        },
+        dtype=str,
+    )
+
+    # convert acquisition dates from str to datetime objects
+    sentinel2_acquisition_dates["Sentinel2"] = sentinel2_acquisition_dates["Sentinel2"].apply(
+        lambda x: date_parser.parse(x)
+    )
+    sentinel1_asc_acquisition_dates["Sentinel1_ASC_vv"] = sentinel1_asc_acquisition_dates[
+        "Sentinel1_ASC_vv"
+        ].apply(lambda x: date_parser.parse(x))
+    sentinel1_asc_acquisition_dates["Sentinel1_ASC_vh"] = sentinel1_asc_acquisition_dates[
+        "Sentinel1_ASC_vh"
+        ].apply(lambda x: date_parser.parse(x))
+    sentinel1_des_acquisition_dates["Sentinel1_DES_vv"] = sentinel1_des_acquisition_dates[
+        "Sentinel1_DES_vv"
+        ].apply(lambda x: date_parser.parse(x))
+    sentinel1_des_acquisition_dates["Sentinel1_DES_vh"] = sentinel1_des_acquisition_dates[
+        "Sentinel1_DES_vh"
+        ].apply(lambda x: date_parser.parse(x))
+
+    # merge the dataframes
+    sensors_join = pd.concat([
+        sentinel2_acquisition_dates,
+        sentinel1_asc_acquisition_dates,
+        sentinel1_des_acquisition_dates,
+        ]
+    )
+
+    # get acquisition availability
+    sensors_join["s2_availability"] = pd.isna(sensors_join["Sentinel2"])
+    sensors_join["s1_availability_asc"] = pd.isna(sensors_join["Sentinel1_ASC_vv"])
+    sensors_join["s1_availability_desc"] = pd.isna(sensors_join["Sentinel1_DES_vv"])
+    # return
+
+
+def apply_mmdc_full_mode(
+    # checkpoint_path: str, # read from the cfg file
+    # checkpoint_epoch: int = 100, # read from the cfg file
+    # patch_size: int = 256, # read from the cfg file
+):
+    """
+    Entry Point
+    """
+    # Set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # read the configuration file
+    config = rcf("configs/iota2/externalfeatures_with_userfeatures.cfg")
+    # return a list with a tuple of functions and arguments
+    mmdc_parametrs = config.getParam("external_features", "functions")
+    padding_size_x = config.getParam("python_data_managing", "padding_size_x")
+    padding_size_y = config.getParam("python_data_managing", "padding_size_y")
+
+    checkpoint_path = parametrs[0][1]  # TODO fix this parameters
+
+    # Get the acquisition dates
+    # dict_dates  = {
+    # 'Sentinel1_DES_vv': ['20151231'],
+    # 'Sentinel1_DES_vh': ['20151231'],
+    # 'Sentinel1_ASC_vv': ['20170518', '20170519'],
+    # 'Sentinel1_ASC_vh': ['20170518', '20170519'],
+    # 'Sentinel2': ['20151130', '20151203']}
+
+    # model
+    mmdc_full_model = get_mmdc_full_model(
+        checkpoint=model_checkpoint_path,
+        mmdc_full_config_path=model_config_path,
+        inference_tile=False,
+    )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    mmdc_full_model.to(device)
+
+    # process class
+    mmdc_process = MMDCProcess(
+        count=latent_space_size,
+        nb_lines=nb_lines,
+        patch_size=patch_size,
+        process=predict_mmdc_model,
+        model=mmdc_full_model,
+    )
+    # create export filename
+    export_path = f"{export_path}/latent_singledate_{'_'.join(sensors)}.tif"
+
+    # # How to manage the S1 acquisition dates to construct the validity mask
+    # # in time
+
+    # # TODO Read SRTM data
+    # # TODO Read Worldclim Data
+
+    # # with torch.no_grad():
+    # # Permute dimensions to fit patchify
+    # # Shape before permutation is C,H,W,D. D being the dates
+    # # Shape after permutation is D,C,H,W
+    # bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
+    # bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+    # bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+    # # TODO Mask construction for S1
+    # # reuse the build_s1_image_and_masks function from the datamodule components?
+
+    # # Replace nan by 0
+    # bands_s1 = bands_s1.nan_to_num()
+    # # These dimensions are useful for unpatchify
+    # # Keep the dimensions of height and width found in chunk
+    # band_h, band_w = bands_s2.shape[-2:]
+    # # Keep the number of patches of patch_size
+    # # in rows and cols found in chunk
+    # h, w = patches.patchify(
+    #     bands_s2[0, ...], patch_size=patch_size, margin=patch_margin
+    # ).shape[:2]
+
+    # # TODO Apply patchify
+
+    # # Get the model
+    # mmdc_full_model = get_mmdc_full_model(
+    #     os.path.join(checkpoint_path, checkpoint_filename)
+    # )
+
+    # # apply model
+    # # latent_variable = mmdc_full_model.predict(
+    # #     s2_x,
+    # #     s2_m,
+    # #     s2_angles_x,
+    # #     s1_x,
+    # #     s1_vm,
+    # #     s1_asc_angles_x,
+    # #     s1_desc_angles_x,
+    # #     worldclim_x,
+    # #     srtm_x,
+    # # )
+
+    # # TODO unpatchify
+
+    # # TODO crop padding
+
+    # # TODO Depending of the configuration return a unique latent variable
+    # # or a stack of laten variables
+    # # coef = pass
+    # # labels = pass
+    # # return coef, labels
+
+
+def test_mnt(self, arg1):
+    """
+    compute the Soil Composition Index
+    """
+    print(dir(self))
+    print("raw dates :", self.get_raw_dates())
+    print("interpolated dates :",self.get_interpolated_dates())
+    print("interpolated dates :",self.interpolated_dates)
+    print("exogeneous_data_name:",self.exogeneous_data_name)
+    print("fill_missing_dates:",self.fill_missing_dates)
+    print(":",None)
+
+    mnt =  self.get_Sentinel2_binary_masks() #self.get_raw_userFeatures_mnt_band_0()#self.get_raw_userFeatures_mnt2()
+    # mnt =  self.get_raw_userFeatures_mnt2_band_0()
+    # mnt =  self.get_raw_userFeatures_slope()
+    print(f"type mnt : {type(mnt)}")
+    print(f"size mnt : {mnt.shape}")
+    labels = [f"mnt_{cpt}" for cpt in range(mnt.shape[-1])]
+    print(f"labels mnt : {labels}")
+
+    # s2_b5 =  self.get_interpolated_Sentinel2_B5()
+    # print(f"type s2_b5 : {type(s2_b5)}") # 4, 3, 2  # 26,50,2
+    # print(f"size s2_b5 : {s2_b5.shape}") # 4, 3, 1  # 26,50,2
+
+    return mnt, labels
diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..307e3007405badd2b7fddfc05a5e3fbbb159e300
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -0,0 +1,442 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inferences API with Rasterio
+"""
+
+# imports
+import logging
+from collections.abc import Callable
+from dataclasses import dataclass
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import rasterio as rio
+import torch
+from torchutils import patches
+from tqdm import tqdm
+
+from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df
+
+from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
+from .components.inference_utils import (
+    GeoTiffDataset,
+    concat_worldclim_components,
+    generate_chunks,
+    read_img_tile,
+    read_s1_img_tile,
+    read_s2_img_tile,
+    read_srtm_img_tile,
+)
+from .utils import get_scales
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def inference_dataframe(
+    samples_dir: str,
+    input_tile_list: list[str] | None,
+    days_gap: int,
+):
+    """
+    Read the input directory and create
+    a dataframe with the occurrences
+    for be infered
+    """
+
+    # Create the Tile time serie
+
+    tile_time_serie_df = prepare_data_df(
+        samples_dir=samples_dir, input_tile_list=input_tile_list, days_gap=days_gap
+    )
+
+    tile_time_serie_df["patch_s2_availability"] = tile_time_serie_df["patch_s2"].apply(
+        lambda x: Path(x).exists()
+    )
+
+    tile_time_serie_df["patchasc_s1_availability"] = tile_time_serie_df[
+        "patchasc_s1"
+    ].apply(lambda x: Path(x).exists())
+
+    tile_time_serie_df["patchdesc_s1_availability"] = tile_time_serie_df[
+        "patchdesc_s1"
+    ].apply(lambda x: Path(x).exists())
+
+    return tile_time_serie_df
+
+
+# functions and classes
+@dataclass
+class MMDCProcess:
+    """
+    Class to hold
+    """
+
+    count: int
+    nb_lines: int
+    patch_size: int
+    model: torch.nn
+    process: Callable[
+        [
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+        ],
+        torch.tensor,
+    ]
+
+
+def predict_single_date_tile(
+    input_data: GeoTiffDataset,
+    export_path: Path,
+    sensors: list[str],
+    process: MMDCProcess,
+):
+    """
+    Predict a tile of data
+    """
+    # check that at least one dataset exists
+    if (
+        Path(input_data.s2_filename).exists()
+        or Path(input_data.s1_asc_filename).exists()
+        or Path(input_data.s1_desc_filename).exists()
+    ):
+        # retrieve the metadata
+        meta = input_data.metadata.copy()
+        # get nb outputs
+        meta.update({"count": process.count})
+        # calculate rois
+        chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+        # get the windows from the chunks
+        rois = [rio.windows.Window(*chunk) for chunk in chunks]
+        logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+        # init the dataset
+        logger.debug("Reading S2 data")
+        s2_data = read_img_tile(
+            filename=input_data.s2_filename,
+            rois=rois,
+            availabitity=input_data.s2_availabitity,
+            sensor_func=read_s2_img_tile,
+        )
+        logger.debug("Reading S1 ASC data")
+        s1_asc_data = read_img_tile(
+            filename=input_data.s1_asc_filename,
+            rois=rois,
+            availabitity=input_data.s1_asc_availability,
+            sensor_func=read_s1_img_tile,
+        )
+        logger.debug("Reading S1 DESC data")
+        s1_desc_data = read_img_tile(
+            filename=input_data.s1_desc_filename,
+            rois=rois,
+            availabitity=input_data.s1_desc_availability,
+            sensor_func=read_s1_img_tile,
+        )
+        logger.debug("Reading SRTM data")
+        srtm_data = read_img_tile(
+            filename=input_data.srtm_filename,
+            rois=rois,
+            availabitity=True,
+            sensor_func=read_srtm_img_tile,
+        )
+        logger.debug("Reading WorldClim data")
+        worldclim_data = concat_worldclim_components(
+            wc_filename=input_data.wc_filename, rois=rois, availabitity=True
+        )
+        logger.debug("input_data.s2_availabitity=%s", input_data.s2_availabitity)
+        logger.debug("s2_data=%s", s2_data)
+        logger.debug("s1_asc_data=%s", s1_asc_data)
+        logger.debug("s1_desc_data=%s", s1_desc_data)
+        logger.debug("srtm_data=%s", srtm_data)
+        logger.debug("worldclim_data=%s", worldclim_data)
+        logger.debug("Export Init")
+
+        with rio.open(export_path, "w", **meta) as prediction:
+            # iterate over the windows
+            for roi, s2, s1_asc, s1_desc, srtm, wc in tqdm(
+                zip(
+                    rois,
+                    s2_data,
+                    s1_asc_data,
+                    s1_desc_data,
+                    srtm_data,
+                    worldclim_data,
+                ),
+                total=len(rois),
+                position=1,
+            ):
+                # export one tile
+                logger.debug(" original size : %s", s2.s2_reflectances.shape)
+                logger.debug(" original size : %s", s2.s2_angles.shape)
+                logger.debug(" original size : %s", s2.s2_mask.shape)
+                logger.debug(" original size : %s", s1_asc.s1_backscatter.shape)
+                logger.debug(" original size : %s", s1_asc.s1_valmask.shape)
+                logger.debug(" original size : %s", s1_asc.s1_lia_angles.shape)
+                logger.debug(" original size : %s", s1_desc.s1_backscatter.shape)
+                logger.debug(" original size : %s", s1_desc.s1_valmask.shape)
+                logger.debug(" original size : %s", s1_desc.s1_lia_angles.shape)
+                logger.debug(" original size : %s", srtm.srtm.shape)
+                logger.debug(" original size : %s", wc.worldclim.shape)
+
+                # Concat S1 Data
+                s1_backscatter = torch.cat(
+                    (s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0
+                )
+                # The validity mask of S1 is unique
+                # Multiply the masks
+                s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask)
+
+                # keep the sizes for recover the original size
+                s2_mask_patch_size = patches.patchify(
+                    s2.s2_mask, process.patch_size
+                ).shape
+                s2_s2_mask_shape = s2.s2_mask.shape  # [1, 1024, 10980]
+
+                logger.debug("s2_mask_patch_size : %s", s2_mask_patch_size)
+                logger.debug("s2_s2_mask_shape: %s", s2_s2_mask_shape)
+                # reshape the data
+                s2_refl_patch = patches.flatten2d(
+                    patches.patchify(s2.s2_reflectances, process.patch_size)
+                )
+                s2_ang_patch = patches.flatten2d(
+                    patches.patchify(s2.s2_angles, process.patch_size)
+                )
+                s2_mask_patch = patches.flatten2d(
+                    patches.patchify(s2.s2_mask, process.patch_size)
+                )
+                s1_patch = patches.flatten2d(
+                    patches.patchify(s1_backscatter, process.patch_size)
+                )
+                s1_valmask_patch = patches.flatten2d(
+                    patches.patchify(s1_valmask, process.patch_size)
+                )
+                srtm_patch = patches.flatten2d(
+                    patches.patchify(srtm.srtm, process.patch_size)
+                )
+                wc_patch = patches.flatten2d(
+                    patches.patchify(wc.worldclim, process.patch_size)
+                )
+                # Expand the angles to fit the sizes
+                s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat(
+                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+                )
+                # torch.flatten(start_dim=0, end_dim=1)
+                s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat(
+                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+                )
+
+                logger.debug("s2_patches %s", s2_refl_patch.shape)
+                logger.debug("s2_angles_patches %s", s2_ang_patch.shape)
+                logger.debug("s2_mask_patch %s", s2_mask_patch.shape)
+                logger.debug("s1_patch %s", s1_patch.shape)
+                logger.debug("s1_valmask_patch %s", s1_valmask_patch.shape)
+                logger.debug(
+                    "s1_lia_asc patches %s %s",
+                    s1_asc.s1_lia_angles.shape,
+                    s1asc_lia_patch.shape,
+                )
+                logger.debug(
+                    "s1_lia_desc patches %s %s",
+                    s1_desc.s1_lia_angles.shape,
+                    s1desc_lia_patch.shape,
+                )
+                logger.debug("srtm patches %s", srtm_patch.shape)
+                logger.debug("wc patches %s", wc_patch.shape)
+
+                # apply predict function
+                # should return a s1s2vaelatentspace
+                # object
+                pred_vaelatentspace = process.process(
+                    process.model,
+                    sensors,
+                    s2_refl_patch,
+                    s2_mask_patch,
+                    s2_ang_patch,
+                    s1_patch,
+                    s1_valmask_patch,
+                    s1asc_lia_patch,
+                    s1desc_lia_patch,
+                    wc_patch,
+                    srtm_patch,
+                )
+                logger.debug(type(pred_vaelatentspace))
+                logger.debug("latent space sizes : %s", pred_vaelatentspace.shape)
+
+                # unpatchify
+                pred_vaelatentspace_unpatchify = patches.unpatchify(
+                    patches.unflatten2d(
+                        pred_vaelatentspace,
+                        s2_mask_patch_size[0],
+                        s2_mask_patch_size[1],
+                    )
+                )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
+
+                logger.debug("pred_tensor : %s", pred_vaelatentspace_unpatchify.shape)
+                logger.debug("process.count : %s", process.count)
+                # check the pred and the ask are the same
+                assert process.count == pred_vaelatentspace_unpatchify.shape[0]
+
+                prediction.write(
+                    np.array(pred_vaelatentspace_unpatchify[0, ...]),
+                    window=roi,
+                    indexes=process.count,
+                )
+                logger.debug(("Export tile", f"filename :{export_path}"))
+
+
+def estimate_nb_latent_spaces(
+    patch_s2_availability: bool,
+    patchasc_s1_availability: bool,
+    patchdesc_s1_availability: bool,
+) -> int:
+    """
+    Given the availability of the input data
+    estimate the number of latent variables
+    """
+
+    if (patch_s2_availability and patchasc_s1_availability) or (
+        patch_s2_availability and patchdesc_s1_availability
+    ):
+        nb_latent_spaces = 8 * 2
+
+    else:
+        nb_latent_spaces = 4 * 2
+
+    return nb_latent_spaces
+
+
+def determinate_sensor_list(
+    patch_s2_availability: bool,
+    patchasc_s1_availability: bool,
+    patchdesc_s1_availability: bool,
+) -> list:
+    """
+    Convert availability dataframe to list of sensors
+    """
+    sensors_availability = [
+        patch_s2_availability,
+        patchasc_s1_availability,
+        patchdesc_s1_availability,
+    ]
+    if sensors_availability == [True, True, True]:
+        sensors = ["S2L2A", "S1FULL"]
+    elif sensors_availability == [True, True, False]:
+        sensors = ["S2L2A", "S1ASC"]
+    elif sensors_availability == [True, False, True]:
+        sensors = ["S2L2A", "S1DESC"]
+    elif sensors_availability == [False, True, True]:
+        sensors = ["S1FULL"]
+    elif sensors_availability == [True, False, False]:
+        sensors = ["S2L2A"]
+    elif sensors_availability == [False, True, False]:
+        sensors = ["S1ASC"]
+    elif sensors_availability == [False, False, True]:
+        sensors = ["S1DESC"]
+    else:
+        raise Exception(
+            f"The sensor {sensors_availability} is not in the available list"
+        )
+    return sensors
+
+
+def mmdc_tile_inference(
+    inference_dataframe: pd.DataFrame,
+    export_path: Path,
+    model_checkpoint_path: Path,
+    model_config_path: Path,
+    patch_size: int = 256,
+    nb_lines: int = 1024,
+) -> None:
+    """
+    Entry point
+
+        :args: input_path : full path to the export
+        :args: export_path : full path to the export
+        :args: sensor: list of sensors data input
+        :args: tile_list: list of tiles,
+        :args: days_gap : days between s1 and s2 acquisitions,
+        :args: latent_space_size : latent space output par sensor,
+        :args: nb_lines : number of lines to read every at time 1024,
+
+        :return : None
+
+    """
+    # device
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # TODO Uncomment for test purpuse
+    # inference_dataframe = inference_dataframe.head()
+    # logger.debug(inference_dataframe.shape)
+
+    # instance the model and get the pred func
+    model = get_mmdc_full_model(
+        checkpoint=model_checkpoint_path,
+        mmdc_full_config_path=model_config_path,
+        inference_tile=True,
+    )
+    scales = get_scales()
+    model.set_scales(scales)
+    # move to device
+    model.to(device)
+
+    # iterate over the dates in the time serie
+    for tuile, df_row in tqdm(
+        inference_dataframe.iterrows(), total=inference_dataframe.shape[0], position=0
+    ):
+        # estimate nb latent spaces
+        latent_space_size = estimate_nb_latent_spaces(
+            df_row["patch_s2_availability"],
+            df_row["patchasc_s1_availability"],
+            df_row["patchdesc_s1_availability"],
+        )
+        sensor = determinate_sensor_list(
+            df_row["patch_s2_availability"],
+            df_row["patchasc_s1_availability"],
+            df_row["patchdesc_s1_availability"],
+        )
+        # export path
+        date_ = df_row["date"].strftime("%Y-%m-%d")
+        export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif"
+        logger.debug(tuile, df_row)
+        logger.debug(latent_space_size)
+        # define process
+        mmdc_process = MMDCProcess(
+            count=latent_space_size,
+            nb_lines=nb_lines,
+            patch_size=patch_size,
+            model=model,
+            process=predict_mmdc_model,
+        )
+        # get the input data
+        mmdc_input_data = GeoTiffDataset(
+            s2_filename=df_row["patch_s2"],
+            s1_asc_filename=df_row["patchasc_s1"],
+            s1_desc_filename=df_row["patchdesc_s1"],
+            srtm_filename=df_row["srtm_filename"],
+            wc_filename=df_row["worldclim_filename"],
+            s2_availabitity=df_row["patch_s2_availability"],
+            s1_asc_availability=df_row["patchasc_s1_availability"],
+            s1_desc_availability=df_row["patchdesc_s1_availability"],
+        )
+        # predict tile
+        predict_single_date_tile(
+            input_data=mmdc_input_data,
+            export_path=export_file,
+            sensors=sensor,
+            process=mmdc_process,
+        )
+
+        logger.debug("Export Finish !!!")
diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1110eccfc79e703032a87b3c673e992252594453
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inferences API with Iota-2
+Inspired by:
+https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
+"""
+
+from typing import Any
+
+import torch
+from torchutils import patches
+
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1,
+    srtm_height_aspect,
+    join_even_odd_s2_angles
+)
+
+
+def apply_mmdc_full_mode(
+    self,
+    checkpoint_path: str,
+    checkpoint_epoch: int = 100,
+    patch_size: int = 256,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+    # How manage the S1 acquisition dates for construct the mask validity
+    # in time
+
+    # DONE Get the data in the same order as
+    # sentinel2.Sentinel2.GROUP_10M +
+    # sensorsio.sentinel2.GROUP_20M
+    # This data correspond to (Chunk_Size, Img_Size, nb_dates)
+    list_bands_s2 = [
+        self.get_interpolated_Sentinel2_B2(),
+        self.get_interpolated_Sentinel2_B3(),
+        self.get_interpolated_Sentinel2_B4(),
+        self.get_interpolated_Sentinel2_B8(),
+        self.get_interpolated_Sentinel2_B5(),
+        self.get_interpolated_Sentinel2_B6(),
+        self.get_interpolated_Sentinel2_B7(),
+        self.get_interpolated_Sentinel2_B8A(),
+        self.get_interpolated_Sentinel2_B11(),
+        self.get_interpolated_Sentinel2_B12(),
+    ]
+
+    # TODO Masks contruction for S2
+    list_s2_mask = self.get_Sentinel2_binary_masks()
+
+    # TODO Manage S1 ASC and S1 DESC ?
+    list_bands_s1 = [
+        self.get_interpolated_Sentinel1_ASC_vh(),
+        self.get_interpolated_Sentinel1_ASC_vv(),
+        self.get_interpolated_Sentinel1_ASC_vh()
+        / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
+        self.get_interpolated_Sentinel1_DES_vh(),
+        self.get_interpolated_Sentinel1_DES_vv(),
+        self.get_interpolated_Sentinel1_DES_vh()
+        / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
+    ]
+
+    with torch.no_grad():
+        # Permute dimensions to fit patchify
+        # Shape before permutation is C,H,W,D. D being the dates
+        # Shape after permutation is D,C,H,W
+        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
+        bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+        bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+
+        # TODO Masks contruction for S1
+        # build_s1_image_and_masks function for datamodules components datamodule components ?
+
+        # Replace nan by 0
+        bands_s1 = bands_s1.nan_to_num()
+        # These dimensions are useful for unpatchify
+        # Keep the dimensions of height and width found in chunk
+        band_h, band_w = bands_s2.shape[-2:]
+        # Keep the number of patches of patch_size
+        # in rows and cols found in chunk
+        h, w = patches.patchify(bands_s2[0, ...],
+                                patch_size=patch_size,
+                                margin=patch_margin).shape[:2]
+
+
+    # TODO Apply patchify
+
+
+    # Get the model
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+
+    # apply model
+    latent_variable = mmdc_full_model.predict(
+        s2_x,
+        s2_m,
+        s2_angles_x,
+        s1_x,
+        s1_vm,
+        s1_asc_angles_x,
+        s1_desc_angles_x,
+        worldclim_x,
+        srtm_x,
+    )
+
+    # TODO unpatchify
+
+    # TODO crop padding
+
+    # TODO Depending of the configuration return a unique latent variable
+    # or a stack of laten variables
+    coef = None  # TODO: stack of latent variables ("pass" here was a syntax error)
+    labels = None  # TODO: corresponding band labels
+    return coef, labels
diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c24cb634838ce955c8dec096e4557113993e4bb
--- /dev/null
+++ b/src/mmdc_singledate/inference/utils.py
@@ -0,0 +1,240 @@
+#!/usr/bin/env python3
+
+from dataclasses import fields
+
+import pydantic
+import torch
+import yaml
+from config import Config
+
+from mmdc_singledate.datamodules.types import (
+    MMDCDataChannels,
+    MMDCDataStats,
+    MMDCShiftScales,
+    ShiftScale,
+)
+from mmdc_singledate.models.types import (  # S1S2VAELatentSpace,
+    ConvnetParams,
+    MLPConfig,
+    MMDCDataUse,
+    ModularEmbeddingConfig,
+    MultiVAEConfig,
+    MultiVAELossWeights,
+    TranslationLossWeights,
+    UnetParams,
+)
+
+
+# TODO Complexify the scales used for the inference
+def get_scales():
+    """
+    Read Scales for Inference
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dataset_dir = "/work/CESBIO/projects/MAESTRIA/training_dataset2/tftwmwc3"
+    tile = "T31TEM"
+    patch_size = "256x256"
+    roi = 1
+    stats = MMDCDataStats(
+        torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s2_{patch_size}_{roi}.pth"),
+        torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s1_{patch_size}_{roi}.pth"),
+        torch.load(
+            f"{dataset_dir}/{tile}/{tile}_stats_worldclim_{patch_size}_{roi}.pth"
+        ),
+        torch.load(f"{dataset_dir}/{tile}/{tile}_stats_srtm_{patch_size}_{roi}.pth"),
+    )
+    scale_regul = torch.nn.Threshold(1e-10, 1.0)
+    shift_scale_s2 = ShiftScale(
+        stats.sen2.median.to(device),
+        scale_regul((stats.sen2.qmax - stats.sen2.qmin) / 2.0).to(device),
+    )
+    shift_scale_s1 = ShiftScale(
+        stats.sen1.median.to(device),
+        scale_regul((stats.sen1.qmax - stats.sen1.qmin) / 2.0).to(device),
+    )
+    shift_scale_wc = ShiftScale(
+        stats.worldclim.median.to(device),
+        scale_regul((stats.worldclim.qmax - stats.worldclim.qmin) / 2.0).to(device),
+    )
+    shift_scale_srtm = ShiftScale(
+        stats.srtm.median.to(device),
+        scale_regul((stats.srtm.qmax - stats.srtm.qmin) / 2.0).to(device),
+    )
+
+    return MMDCShiftScales(
+        shift_scale_s2, shift_scale_s1, shift_scale_wc, shift_scale_srtm
+    )
+
+
+# DONE add pydantic validation
+def parse_from_yaml(path_to_yaml: str):
+    """
+    Read a yaml file and return a dict
+    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
+    """
+    with open(path_to_yaml) as f:
+        config = yaml.safe_load(f)
+
+    return config
+
+
+def parse_multivaeconfig_to_pydantic(cfg: dict):
+    """
+    Convert a MultiVAEConfig to a pydantic version
+    for validate the values
+        :args: cfg : configuration as dict
+    """
+    # parse the field of the class to a dict
+    multivae_fields = {
+        _field.name: (_field.type, ...) for _field in fields(MultiVAEConfig)
+    }
+    # create pydantic model from a dataclass
+    pydantic_multivaefields = pydantic.create_model("MultiVAEConfig", **multivae_fields)
+    # apply the convertion
+    multivaeconfig_pydantic = pydantic_multivaefields(**cfg["model"]["config"])
+
+    return multivaeconfig_pydantic
+
+
+# DONE add pydantic validation
+def parse_from_cfg(path_to_cfg: str):
+    """
+    Read a cfg as config file validate the values
+    and return a dict
+    """
+    with open(path_to_cfg, encoding="UTF-8") as i2_config:
+        cfg = Config(i2_config)
+        cfg_dict = cfg.as_dict()
+
+    return cfg_dict
+
+
+def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool) -> MultiVAEConfig:
+    """
+    Get the parameters for instantiate the mmdc_full module
+    based on the target inference mode
+    """
+
+    # read the yaml hydra config
+    if inference_tile:
+        cfg = parse_from_yaml(path_to_cfg)  # "configs/model/mmdc_full.yaml")
+        cfg = parse_multivaeconfig_to_pydantic(cfg)
+
+    else:
+        cfg = parse_from_cfg(path_to_cfg)
+        cfg = parse_multivaeconfig_to_pydantic(cfg)
+
+    mmdc_full_config = MultiVAEConfig(
+        data_sizes=MMDCDataChannels(
+            sen1=cfg.data_sizes.sen1,
+            s1_angles=cfg.data_sizes.s1_angles,
+            sen2=cfg.data_sizes.sen2,
+            s2_angles=cfg.data_sizes.s2_angles,
+            srtm=cfg.data_sizes.srtm,
+            w_c=cfg.data_sizes.w_c,
+        ),
+        embeddings=ModularEmbeddingConfig(
+            s1_angles=MLPConfig(
+                hidden_layers=cfg.embeddings.s1_angles.hidden_layers,
+                out_channels=cfg.embeddings.s1_angles.out_channels,
+            ),
+            s1_srtm=UnetParams(
+                out_channels=cfg.embeddings.s1_srtm.out_channels,
+                encoder_sizes=cfg.embeddings.s1_srtm.encoder_sizes,
+                kernel_size=cfg.embeddings.s1_srtm.kernel_size,
+                tail_layers=cfg.embeddings.s1_srtm.tail_layers,
+            ),
+            s2_angles=ConvnetParams(
+                out_channels=cfg.embeddings.s2_angles.out_channels,
+                sizes=cfg.embeddings.s2_angles.sizes,
+                kernel_sizes=cfg.embeddings.s2_angles.kernel_sizes,
+            ),
+            s2_srtm=UnetParams(
+                out_channels=cfg.embeddings.s2_srtm.out_channels,
+                encoder_sizes=cfg.embeddings.s2_srtm.encoder_sizes,
+                kernel_size=cfg.embeddings.s2_srtm.kernel_size,
+                tail_layers=cfg.embeddings.s2_srtm.tail_layers,
+            ),
+            w_c=ConvnetParams(
+                out_channels=cfg.embeddings.w_c.out_channels,
+                sizes=cfg.embeddings.w_c.sizes,
+                kernel_sizes=cfg.embeddings.w_c.kernel_sizes,
+            ),
+        ),
+        s1_encoder=UnetParams(
+            out_channels=cfg.s1_encoder.out_channels,
+            encoder_sizes=cfg.s1_encoder.encoder_sizes,
+            kernel_size=cfg.s1_encoder.kernel_size,
+            tail_layers=cfg.s1_encoder.tail_layers,
+        ),
+        s2_encoder=UnetParams(
+            out_channels=cfg.s2_encoder.out_channels,
+            encoder_sizes=cfg.s2_encoder.encoder_sizes,
+            kernel_size=cfg.s2_encoder.kernel_size,
+            tail_layers=cfg.s2_encoder.tail_layers,
+        ),
+        s1_decoder=ConvnetParams(
+            out_channels=cfg.s1_decoder.out_channels,
+            sizes=cfg.s1_decoder.sizes,
+            kernel_sizes=cfg.s1_decoder.kernel_sizes,
+        ),
+        s2_decoder=ConvnetParams(
+            out_channels=cfg.s2_decoder.out_channels,
+            sizes=cfg.s2_decoder.sizes,
+            kernel_sizes=cfg.s2_decoder.kernel_sizes,
+        ),
+        s1_enc_use=MMDCDataUse(
+            s1_angles=cfg.s1_enc_use.s1_angles,
+            s2_angles=cfg.s1_enc_use.s2_angles,
+            srtm=cfg.s1_enc_use.srtm,
+            w_c=cfg.s1_enc_use.w_c,
+        ),
+        s2_enc_use=MMDCDataUse(
+            s1_angles=cfg.s2_enc_use.s1_angles,
+            s2_angles=cfg.s2_enc_use.s2_angles,
+            srtm=cfg.s2_enc_use.srtm,
+            w_c=cfg.s2_enc_use.w_c,
+        ),
+        s1_dec_use=MMDCDataUse(
+            s1_angles=cfg.s1_dec_use.s1_angles,
+            s2_angles=cfg.s1_dec_use.s2_angles,
+            srtm=cfg.s1_dec_use.srtm,
+            w_c=cfg.s1_dec_use.w_c,
+        ),
+        s2_dec_use=MMDCDataUse(
+            s1_angles=cfg.s2_dec_use.s1_angles,
+            s2_angles=cfg.s2_dec_use.s2_angles,
+            srtm=cfg.s2_dec_use.srtm,
+            w_c=cfg.s2_dec_use.w_c,
+        ),
+        loss_weights=MultiVAELossWeights(
+            sen1=TranslationLossWeights(
+                nll=cfg.loss_weights.sen1.nll,
+                lpips=cfg.loss_weights.sen1.lpips,
+                sam=cfg.loss_weights.sen1.sam,
+                gradients=cfg.loss_weights.sen1.gradients,
+            ),
+            sen2=TranslationLossWeights(
+                nll=cfg.loss_weights.sen2.nll,
+                lpips=cfg.loss_weights.sen2.lpips,
+                sam=cfg.loss_weights.sen2.sam,
+                gradients=cfg.loss_weights.sen2.gradients,
+            ),
+            s1_s2=TranslationLossWeights(
+                nll=cfg.loss_weights.s1_s2.nll,
+                lpips=cfg.loss_weights.s1_s2.lpips,
+                sam=cfg.loss_weights.s1_s2.sam,
+                gradients=cfg.loss_weights.s1_s2.gradients,
+            ),
+            s2_s1=TranslationLossWeights(
+                nll=cfg.loss_weights.s2_s1.nll,
+                lpips=cfg.loss_weights.s2_s1.lpips,
+                sam=cfg.loss_weights.s2_s1.sam,
+                gradients=cfg.loss_weights.s2_s1.gradients,
+            ),
+            forward=cfg.loss_weights.forward,
+            cross=cfg.loss_weights.cross,
+            latent=cfg.loss_weights.latent,
+        ),
+    )
+    return mmdc_full_config
diff --git a/test/full_module_config.json b/test/full_module_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..db4ed17229fcb4b9e5c9a88aa23e791825ea80dd
--- /dev/null
+++ b/test/full_module_config.json
@@ -0,0 +1,197 @@
+{
+  "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullLitModule",
+  "lr": 0.0001,
+  "model": {
+    "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullModule",
+    "config": {
+      "_target_": "mmdc_singledate.models.types.MultiVAEConfig",
+      "data_sizes": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataChannels",
+        "sen1": 6,
+        "s1_angles": 6,
+        "sen2": 10,
+        "s2_angles": 6,
+        "srtm": 4,
+        "w_c": 103
+      },
+      "embeddings": {
+        "_target_": "mmdc_singledate.models.types.ModularEmbeddingConfig",
+        "s1_angles": {
+          "_target_": "mmdc_singledate.models.types.MLPConfig",
+          "hidden_layers": [
+            9,
+            12,
+            9
+          ],
+          "out_channels": 3
+        },
+        "s1_srtm": {
+          "_target_": "mmdc_singledate.models.types.UnetParams",
+          "out_channels": 4,
+          "encoder_sizes": [
+            32,
+            64,
+            128
+          ],
+          "kernel_size": 3,
+          "tail_layers": 3
+        },
+        "s2_angles": {
+          "_target_": "mmdc_singledate.models.types.ConvnetParams",
+          "out_channels": 3,
+          "sizes": [
+            16,
+            8
+          ],
+          "kernel_sizes": [
+            1,
+            1,
+            1
+          ]
+        },
+        "s2_srtm": {
+          "_target_": "mmdc_singledate.models.types.UnetParams",
+          "out_channels": 4,
+          "encoder_sizes": [
+            32,
+            64,
+            128
+          ],
+          "kernel_size": 3,
+          "tail_layers": 3
+        },
+        "w_c": {
+          "_target_": "mmdc_singledate.models.types.ConvnetParams",
+          "out_channels": 3,
+          "sizes": [
+            64,
+            32,
+            16
+          ],
+          "kernel_sizes": [
+            1,
+            1,
+            1,
+            1
+          ]
+        }
+      },
+      "s1_encoder": {
+        "_target_": "mmdc_singledate.models.types.UnetParams",
+        "out_channels": 4,
+        "encoder_sizes": [
+          64,
+          128,
+          256,
+          512
+        ],
+        "kernel_size": 3,
+        "tail_layers": 3
+      },
+      "s2_encoder": {
+        "_target_": "mmdc_singledate.models.types.UnetParams",
+        "out_channels": 4,
+        "encoder_sizes": [
+          64,
+          128,
+          256,
+          512
+        ],
+        "kernel_size": 3,
+        "tail_layers": 3
+      },
+      "s1_decoder": {
+        "_target_": "mmdc_singledate.models.types.ConvnetParams",
+        "out_channels": 0,
+        "sizes": [
+          64,
+          32,
+          16
+        ],
+        "kernel_sizes": [
+          3,
+          3,
+          3,
+          3
+        ]
+      },
+      "s2_decoder": {
+        "_target_": "mmdc_singledate.models.types.ConvnetParams",
+        "out_channels": 0,
+        "sizes": [
+          64,
+          32,
+          16
+        ],
+        "kernel_sizes": [
+          3,
+          3,
+          3,
+          3
+        ]
+      },
+      "s1_enc_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "s2_enc_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "s1_dec_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "s2_dec_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "loss_weights": {
+        "_target_": "mmdc_singledate.models.types.MultiVAELossWeights",
+        "sen1": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "sen2": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "s1_s2": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "s2_s1": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "forward": 1,
+        "cross": 1,
+        "latent": 1
+      }
+    }
+  }
+}
diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d8a5f4fcb26dd62acc4162b8e3a9f0dc045c845
--- /dev/null
+++ b/test/test_mmdc_inference.py
@@ -0,0 +1,453 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+import os
+from datetime import datetime
+from pathlib import Path
+
+import pandas as pd
+import pytest
+import rasterio as rio
+import torch
+
+from mmdc_singledate.inference.components.inference_components import (
+    get_mmdc_full_model,
+    predict_mmdc_model,
+)
+from mmdc_singledate.inference.components.inference_utils import (
+    GeoTiffDataset,
+    generate_chunks,
+)
+from mmdc_singledate.inference.mmdc_tile_inference import (  # inference_dataframe,
+    MMDCProcess,
+    mmdc_tile_inference,
+    predict_single_date_tile,
+)
+
+from .utils import get_scales, setup_data
+
+# from mmdc_singledate.models.types import VAELatentSpace
+
+
+# useful variables: location of the test tile and the names of the input
+# rasters (S2 L2A, S1 ascending, S1 descending, SRTM DEM, WorldClim climate)
+dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif"
+s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif"
+srtm_filename_variable = "srtm_T31TCJ_roi_0.tif"
+wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif"
+
+
+# sensor combinations exercised by the parametrized prediction tests
+sensors_test = [
+    (["S2L2A", "S1FULL"]),
+    (["S2L2A", "S1ASC"]),
+    (["S2L2A", "S1DESC"]),
+    (["S2L2A"]),
+    (["S1FULL"]),
+    (["S1ASC"]),
+    (["S1DESC"]),
+]
+
+
+# (s2, s1_asc, s1_desc, srtm, wc) filename tuples; an empty string simulates
+# a missing acquisition for that sensor (see test_predict_single_date_tile_no_data)
+datasets = [
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        "",
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        "",
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        "",
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+]
+
+
+# @pytest.mark.skip(reason="Check Time Serie generation")
+def test_GeoTiffDataset():
+    """ """
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
+        s1_asc_filename=os.path.join(
+            dataset_dir,
+            s1_asc_filename_variable,
+        ),
+        s1_desc_filename=os.path.join(
+            dataset_dir,
+            s1_desc_filename_variable,
+        ),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
+    )
+    # print GetTiffDataset
+    print(input_data)
+    # Check Data Existence
+    assert input_data.s1_asc_availability == True
+    assert input_data.s1_desc_availability == True
+    assert input_data.s2_availabitity == True
+
+
+# @pytest.mark.skip(reason="Check Time Serie generation")
+def test_generate_chunks():
+    """
+    test chunk functionality
+    """
+    chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+
+    assert type(chunks) == list
+    assert chunks[0] == (0, 0, 10980, 1024)
+
+
+@pytest.mark.skip(reason="Check Time Serie generation")
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_predict_mmdc_model(sensors):
+    """ """
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # checkpoints
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+    (
+        s2_x,
+        s2_m,
+        s2_a,
+        s1_x,
+        s1_vm,
+        s1_a_asc,
+        s1_a_desc,
+        srtm_x,
+        wc_x,
+        device,
+    ) = setup_data()
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    # move to device
+    mmdc_full_model.to(device)
+
+    pred = predict_mmdc_model(
+        mmdc_full_model,
+        sensors,
+        s2_x,
+        s2_m,
+        s2_a,
+        s1_x,
+        s1_vm,
+        s1_a_asc,
+        s1_a_desc,
+        wc_x,
+        srtm_x,
+    )
+    print(pred.shape)
+
+    assert type(pred) == torch.Tensor
+
+
+# TODO add cases with no data
+# @pytest.mark.skip(reason="Dissable for faster testing")
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_predict_single_date_tile(sensors):
+    """ """
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif"
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
+        s1_asc_filename=os.path.join(
+            dataset_dir,
+            s1_asc_filename_variable,
+        ),
+        s1_desc_filename=os.path.join(
+            dataset_dir,
+            s1_desc_filename_variable,
+        ),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
+    )
+
+    nb_bands = len(sensors) * 8
+    print("nb_bands  :", nb_bands)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # checkpoints
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"  # "last.ckpt"
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename),
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json",
+        inference_tile=True,
+    )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    # move to device
+    mmdc_full_model.to(device)
+    assert mmdc_full_model.training == False
+
+    # init the process
+    process = MMDCProcess(
+        count=nb_bands,  # 4,
+        nb_lines=256,
+        patch_size=128,
+        model=mmdc_full_model,
+        process=predict_mmdc_model,
+    )
+
+    # entry point pred function
+    predict_single_date_tile(
+        input_data=input_data,
+        export_path=export_path,
+        sensors=sensors,
+        process=process,
+    )
+
+    assert Path(export_path).exists() == True
+
+    with rio.open(export_path) as predited_raster:
+        predicted_metadata = predited_raster.meta.copy()
+        assert predicted_metadata["count"] == nb_bands
+
+
+@pytest.mark.parametrize(
+    "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable",
+    datasets,
+)
+def test_predict_single_date_tile_no_data(
+    s2_filename_variable,
+    s1_asc_filename_variable,
+    s1_desc_filename_variable,
+    srtm_filename_variable,
+    wc_filename_variable,
+):
+    """ """
+    sensors = ["S2L2A", "S1ASC"]
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif"
+    dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
+        s1_asc_filename=os.path.join(
+            dataset_dir,
+            s1_asc_filename_variable,
+        ),
+        s1_desc_filename=os.path.join(
+            dataset_dir,
+            s1_desc_filename_variable,
+        ),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
+    )
+
+    nb_bands = len(sensors) * 8
+    print("nb_bands  :", nb_bands)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # checkpoints
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"  # "last.ckpt"
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename),
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/test/full_module_config.json",
+        inference_tile=True,
+    )
+    # move to device
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    mmdc_full_model.to(device)
+    assert mmdc_full_model.training == False
+
+    # init the process
+    process = MMDCProcess(
+        count=nb_bands,  # 4,
+        nb_lines=256,
+        patch_size=128,
+        model=mmdc_full_model,
+        process=predict_mmdc_model,
+    )
+
+    # entry point pred function
+    predict_single_date_tile(
+        input_data=input_data,
+        export_path=export_path,
+        sensors=sensors,
+        process=process,
+    )
+
+    assert Path(export_path).exists() == True
+
+    with rio.open(export_path) as predited_raster:
+        predicted_metadata = predited_raster.meta.copy()
+        assert predicted_metadata["count"] == nb_bands
+
+
+# @pytest.mark.skip(reason="Check Time Serie generation")
+def test_mmdc_tile_inference():
+    """
+    Test the inference code in a tile
+    """
+    # feed parameters
+    input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/")
+    export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/")
+    model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt"
+    model_config_path = (
+        "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml"
+    )
+    input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"]
+
+    # dataframe with input data
+    tile_df = pd.DataFrame(
+        {
+            "date": [
+                datetime.strptime("2018-01-11", "%Y-%m-%d"),
+                datetime.strptime("2018-01-14", "%Y-%m-%d"),
+                datetime.strptime("2018-01-19", "%Y-%m-%d"),
+                datetime.strptime("2018-01-21", "%Y-%m-%d"),
+                datetime.strptime("2018-01-24", "%Y-%m-%d"),
+            ],
+            "patch_s2": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180111-100351-463_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180114-101347-463_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2A_20180119-101331-457_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180121-100542-770_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180124-101352-753_L2A_T33TUL_D_V1-4_roi_0.tif",
+            ],
+            "patchasc_s1": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180108T165831_20180108T165856_020066_022322_3008_36.29396013100586_14.593286735562288_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180114T165747_20180114T165812_009170_0106A2_AFD8_36.27898252884488_14.533603849909781_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180120T165830_20180120T165855_020241_0228B0_95C7_36.29306272148375_14.591766979064303_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+            ],
+            "patchdesc_s1": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180113T051001_20180113T051026_009148_0105F4_218B_45.70533932879138_164.47774901324672_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180124T051848_20180124T051913_020292_022A63_177C_35.81756575888757_165.03806323648553_roi_0.tif",
+            ],
+            "Tuiles": [
+                "T33TUL",
+                "T33TUL",
+                "T33TUL",
+                "T33TUL",
+                "T33TUL",
+            ],
+            "patch": [
+                0,
+                0,
+                0,
+                0,
+                0,
+            ],
+            "roi": [
+                0,
+                0,
+                0,
+                0,
+                0,
+            ],
+            "worldclim_filename": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+            ],
+            "srtm_filename": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+            ],
+            "patch_s2_availability": [
+                True,
+                True,
+                True,
+                True,
+                True,
+            ],
+            "patchasc_s1_availability": [
+                True,
+                True,
+                False,
+                True,
+                False,
+            ],
+            "patchdesc_s1_availability": [
+                True,
+                True,
+                False,
+                True,
+                False,
+            ],
+        }
+    )
+    # inference_dataframe(
+    #     samples_dir=input_path,
+    #     input_tile_list=input_tile_list,
+    #     days_gap=5,
+    # )
+    # # just test the df
+    print(
+        tile_df.loc[
+            :,
+            [
+                "patch_s2_availability",
+                "patchasc_s1_availability",
+                "patchdesc_s1_availability",
+            ],
+        ]
+    )
+
+    # apply prediction
+    mmdc_tile_inference(
+        inference_dataframe=tile_df,
+        export_path=export_path,
+        model_checkpoint_path=model_checkpoint_path,
+        model_config_path=model_config_path,
+        patch_size=128,
+        nb_lines=256,
+    )
+
+    assert Path(export_path / "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-14_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-19_S2L2A.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-21_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-24_S2L2A.tif").exists()
+
+    with rio.open(
+        os.path.join(export_path, "latent_tile_infer_2018-01-19_S2L2A.tif")
+    ) as raster:
+        metadata = raster.meta.copy()
+    assert metadata["count"] == 8
+
+    with rio.open(
+        os.path.join(export_path, "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif")
+    ) as raster:
+        metadata = raster.meta.copy()
+    assert metadata["count"] == 16