diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml
index a2c4d6137f83fdb1b9d57a88d66a191aecf503c5..922a6a34164749365183cbabe91d52bb2ea72047 100644
--- a/configs/callbacks/default.yaml
+++ b/configs/callbacks/default.yaml
@@ -17,12 +17,12 @@ model_checkpoint_every_10:
   dirpath: ${output_dir}/checkpoints/${run_id}
   filename: "epoch_{epoch:03d}"
 
-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 10 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 10 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
 
 model_summary:
   _target_: pytorch_lightning.callbacks.RichModelSummary
diff --git a/configs/callbacks/mmdc_full_callback.yaml b/configs/callbacks/mmdc_full_callback.yaml
index cc8921c2dd2b39ce91ad481952727af90bcef952..60019d3c1ba48e2607d3397cc06c4125262554ba 100644
--- a/configs/callbacks/mmdc_full_callback.yaml
+++ b/configs/callbacks/mmdc_full_callback.yaml
@@ -36,13 +36,13 @@ stochastic_weight_averaging:
   annealing_epochs: 5
   annealing_strategy: 'cos'
 
-
-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 100 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 100 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
 
 lr_monitor:
   _target_: pytorch_lightning.callbacks.LearningRateMonitor
@@ -54,10 +54,6 @@ model_summary:
 rich_progress_bar:
   _target_: pytorch_lightning.callbacks.RichProgressBar
 
-timer:
-  _target_: pytorch_lightning.callbacks.Timer
-  duration: 00:72:00:00
-
 evolution_plot:
   _target_: mmdc_singledate.callbacks.eval_callbacks.PlotEvolutionCallback
   save_dir: ${hydra:run.dir}
diff --git a/configs/callbacks/mmdc_full_callback_freeze.yaml b/configs/callbacks/mmdc_full_callback_freeze.yaml
index f759f7f90b0bd1ed5f8fae2abe2bd905f9b301ae..2ae4fb136bf50aca9e69b6fe238b27056dbc6881 100644
--- a/configs/callbacks/mmdc_full_callback_freeze.yaml
+++ b/configs/callbacks/mmdc_full_callback_freeze.yaml
@@ -37,12 +37,12 @@ stochastic_weight_averaging:
   annealing_strategy: 'cos'
 
 
-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 100 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 100 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
 
 lr_monitor:
   _target_: pytorch_lightning.callbacks.LearningRateMonitor
@@ -54,10 +54,6 @@ model_summary:
 rich_progress_bar:
   _target_: pytorch_lightning.callbacks.RichProgressBar
 
-timer:
-  _target_: pytorch_lightning.callbacks.Timer
-  duration: 00:72:00:00
-
 evolution_plot:
   _target_: mmdc_singledate.callbacks.eval_callbacks.PlotEvolutionCallback
   save_dir: ${hydra:run.dir}
diff --git a/configs/callbacks/mmdc_full_experts_callback.yaml b/configs/callbacks/mmdc_full_experts_callback.yaml
index 4e9d61951322e703501415f48630361785a52d84..f99965e132bc3bb7fbf3526f90e628a200c7dacc 100644
--- a/configs/callbacks/mmdc_full_experts_callback.yaml
+++ b/configs/callbacks/mmdc_full_experts_callback.yaml
@@ -37,12 +37,12 @@ stochastic_weight_averaging:
   annealing_strategy: 'cos'
 
 
-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 100 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 100 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
 
 lr_monitor:
   _target_: pytorch_lightning.callbacks.LearningRateMonitor
@@ -54,10 +54,6 @@ model_summary:
 rich_progress_bar:
   _target_: pytorch_lightning.callbacks.RichProgressBar
 
-timer:
-  _target_: pytorch_lightning.callbacks.Timer
-  duration: 00:72:00:00
-
 evolution_plot:
   _target_: mmdc_singledate.callbacks.eval_callbacks.PlotEvolutionCallback
   save_dir: ${hydra:run.dir}
diff --git a/configs/experiment/mmdc_full_poe_no_mask.yaml b/configs/experiment/mmdc_full_poe_no_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b85964735ec7fa0ee42a89605e7ff13711eac4e2
--- /dev/null
+++ b/configs/experiment/mmdc_full_poe_no_mask.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=mmdc_full_poe_no_mask
+
+defaults:
+  - override /datamodule: mmdc_datamodule.yaml
+  - override /model: mmdc_full_poe_no_mask.yaml
+  - override /callbacks: mmdc_full_experts_callback.yaml
+  - override /trainer: mmdc_full_trainer.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["mmdc_poe_no_mask"]
+
+
+name: "mmdc_poe_no_mask"
diff --git a/configs/model/mmdc_full.yaml b/configs/model/mmdc_full.yaml
index 3da4b2b2b5421e84abba2a3e7b8da466c64eb304..54469f70cceaeae894d38e48449df1dd21337113 100644
--- a/configs/model/mmdc_full.yaml
+++ b/configs/model/mmdc_full.yaml
@@ -1,7 +1,7 @@
 
 _target_: mmdc_singledate.models.lightning.full.MMDCFullLitModule
 lr: 0.0001
-resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-01-25_08-25-45/last.ckpt
+resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-29_14-05-22/last.ckpt
 
 #load_weights_from_checkpoint:
 #  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-01-23_11-39-06/last.ckpt
@@ -53,13 +53,13 @@ model:
       _target_: mmdc_singledate.models.datatypes.MultiVAEEncDecConfig
       s1_encoder:
         _target_: mmdc_singledate.models.datatypes.UnetParams
-        out_channels: 6
+        out_channels: 12
         encoder_sizes: [64, 128, 256, 512]
         kernel_size: 3
         tail_layers: 3
       s2_encoder:
         _target_: mmdc_singledate.models.datatypes.UnetParams
-        out_channels: 6
+        out_channels: 12
         encoder_sizes: [64, 128, 256, 512]
         kernel_size: 3
         tail_layers: 3
diff --git a/configs/model/mmdc_full_mask.yaml b/configs/model/mmdc_full_mask.yaml
index d3b02e016c522ee58f00ce63df9197ef16cf9c09..6148b5bad8249a0c12296e7dde42a3ae8065c4b4 100644
--- a/configs/model/mmdc_full_mask.yaml
+++ b/configs/model/mmdc_full_mask.yaml
@@ -1,18 +1,20 @@
 
 _target_: mmdc_singledate.models.lightning.full.MMDCFullLitModule
 lr: 0.0001
-resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_mask/2024-01-26_16-15-20/last.ckpt
-#load_weights_from_checkpoint:
-#  path: /work/scratch/data/kalinie/MMDC/good_checkpoints/pretrained_baseline_50epochs/epoch_050.ckpt
-#  selected_layers: all
-load_weights_from_checkpoint:
-  path: null
+resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_mask/2024-03-08_11-09-43/last.ckpt
+# load_weights_from_checkpoint:
+#   path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-27_14-47-18/last.ckpt
+#   selected_layers: "all"
+#   freeze_epochs: null
 
 masking:
   _target_: mmdc_singledate.models.components.masking.DataMasking
-  epochs_thr_list: [25, 50, 75]
-  masked_perc_list: [0, 50, 100]
-  strategy_list: ["random", "one_satellite", "one_satellite"]
+  # epochs_thr_list: [25, 50, 75]
+  # masked_perc_list: [0, 50, 100]
+  epochs_thr_list: [0, 25, 50]
+  masked_perc_list: [0, 25, 50]
+  strategy_list: ["random", "random", "random"]
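+  # assumed pairing: 0% masking from epoch 0, 25% from epoch 25, 50% from epoch 50, "random" throughout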
 
 model:
   _target_: mmdc_singledate.models.torch.full.MMDCFullModule
diff --git a/configs/model/mmdc_full_moe_no_mask.yaml b/configs/model/mmdc_full_moe_no_mask.yaml
index fc3a13242bdf8afe009e2b7190a290e3a7b9451d..2208ff1a7fb4da81831a4fd45b81bd7118eeff9b 100644
--- a/configs/model/mmdc_full_moe_no_mask.yaml
+++ b/configs/model/mmdc_full_moe_no_mask.yaml
@@ -2,6 +2,12 @@
 _target_: mmdc_singledate.models.lightning.full_experts.MMDCFullExpertsLitModule
 lr: 0.0001
 resume_from_checkpoint: null
+load_weights_from_checkpoint:
+  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-27_14-47-18/last.ckpt
+  selected_layers: "all"
+  freeze_epochs: null
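+# warm start: load every layer from the checkpoint; freeze_epochs: null is assumed
+# to mean nothing is frozen, so all weights are fine-tuned from the first epoch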
 
 masking: null
 
@@ -98,17 +102,9 @@ model:
           dem: False
           meteo: False
     loss_weights:
-      _target_: mmdc_singledate.models.datatypes.MultiVAELossWeights
-      sen1:
-        _target_: mmdc_singledate.models.datatypes.TranslationLossWeights
-        nll: 1.0
-      sen2:
-        _target_: mmdc_singledate.models.datatypes.TranslationLossWeights
-        nll: 1.0
       forward: 1.0
-      cross: 1.0
       latent:
         _target_: mmdc_singledate.models.datatypes.LatentLossWeights
         latent: 1.0
-        latent_cov: 0.1
+        latent_cov: 0.0
         latent_max_var: False
diff --git a/configs/model/mmdc_full_poe_mask.yaml b/configs/model/mmdc_full_poe_mask.yaml
index d7fe3d0641963a39197671c4efa8d0fba65459d9..0cee823846bd2ca03b3a12184c50ce896f65c612 100644
--- a/configs/model/mmdc_full_poe_mask.yaml
+++ b/configs/model/mmdc_full_poe_mask.yaml
@@ -2,7 +2,10 @@
 _target_: mmdc_singledate.models.lightning.full_experts.MMDCFullExpertsLitModule
 lr: 0.0001
 resume_from_checkpoint: null
-#load_weights_from_checkpoint: /work/scratch/data/kalinie/MMDC/good_checkpoints/pretrained_no_latent/last.ckpt
+load_weights_from_checkpoint:
+  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_mask/2024-02-27_14-47-18/last.ckpt
+  selected_layers: "all"
+  freeze_epochs: null
 
 
 masking:
diff --git a/configs/model/mmdc_full_poe_no_mask.yaml b/configs/model/mmdc_full_poe_no_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e8894772a5416872bc95226d53f77cebdd2463b3
--- /dev/null
+++ b/configs/model/mmdc_full_poe_no_mask.yaml
@@ -0,0 +1,113 @@
+
+_target_: mmdc_singledate.models.lightning.full_experts.MMDCFullExpertsLitModule
+lr: 0.0001
+resume_from_checkpoint: null
+load_weights_from_checkpoint:
+  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-27_14-47-18/last.ckpt
+  selected_layers: "all"
+  freeze_epochs: null
+
+
+masking: null
+
+model:
+  _target_: mmdc_singledate.models.torch.full_experts.MMDCFullExpertsModule
+  config:
+    _target_: mmdc_singledate.models.datatypes.MultiVAEConfig
+    experts_strategy: "PoE"  # expert fusion strategy: "average", "MoE", or "PoE"
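+    # PoE (product of experts) is assumed here to fuse the per-modality latent
+    # Gaussians by multiplying them; MoE instead mixes or alternates the experts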
+    data_sizes:
+      _target_: mmdc_singledate.models.datatypes.MMDCDataChannels
+      sen1: 6
+      s1_angles: 2
+      sen2: 10
+      s2_angles: 6
+      dem: 4
+      meteo: 48
+    embeddings:
+      _target_: mmdc_singledate.models.datatypes.ModularEmbeddingConfig
+      s1_angles:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 3
+        sizes: [16, 8]
+        kernel_sizes: [1, 1, 1]
+      s1_dem:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 4
+        encoder_sizes: [32, 64, 128]
+        kernel_size: 3
+        tail_layers: 3
+      s2_angles:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 3
+        sizes: [16, 8]
+        kernel_sizes: [1, 1, 1]
+      s2_dem:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 4
+        encoder_sizes: [32, 64, 128]
+        kernel_size: 3
+        tail_layers: 3
+      meteo:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 3
+        sizes: [32, 32, 16]
+        kernel_sizes: [1, 1, 1, 1]
+    auto_enc:
+      _target_: mmdc_singledate.models.datatypes.MultiVAEEncDecConfig
+      s1_encoder:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 6
+        encoder_sizes: [64, 128, 256, 512]
+        kernel_size: 3
+        tail_layers: 3
+      s2_encoder:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 6
+        encoder_sizes: [64, 128, 256, 512]
+        kernel_size: 3
+        tail_layers: 3
+      s1_decoder:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 0
+        sizes: [64, 32, 16]
+        kernel_sizes: [3, 3, 3, 3]
+      s2_decoder:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 0
+        sizes: [64, 32, 16]
+        kernel_sizes: [3, 3, 3, 3]
+      ae_use:
+        _target_: mmdc_singledate.models.datatypes.MultiVAEAuxUseConfig
+        s1_enc_use:
+          _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+          s1_angles: True
+          s2_angles: True
+          dem: True
+          meteo: True
+        s2_enc_use:
+          _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+          s1_angles: True
+          s2_angles: True
+          dem: True
+          meteo: True
+        s1_dec_use:
+          _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+          s1_angles: True
+          s2_angles: True
+          dem: False
+          meteo: False
+        s2_dec_use:
+          _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+          s1_angles: True
+          s2_angles: True
+          dem: False
+          meteo: False
+    loss_weights:
+      forward: 1.0
+      latent:
+        _target_: mmdc_singledate.models.datatypes.LatentLossWeights
+        latent: 1.0
+        latent_cov: 0.0
+        latent_max_var: False
diff --git a/configs/train.yaml b/configs/train.yaml
index c33b4f50569eb0dd908a38c77476b9750b9b2e77..73ae5d7992d8fbd223716e054db70ae2d3f12c3b 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -3,7 +3,7 @@
 # specify here default training configuration
 defaults:
   - _self_
-  - datamodule: mmdc_full_datamodule.yaml
+  - datamodule: mmdc_datamodule.yaml
   - model: mmdc_full.yaml
   - callbacks: default.yaml
   - logger: default.yaml
diff --git a/configs/trainer/mmdc_full_quick.yaml b/configs/trainer/mmdc_full_quick.yaml
index 09f391ed53dfaaeeaa4050e4e3d3a790fee8bda0..497a592d8cbc77368e268cefb7f9f96da6afa033 100644
--- a/configs/trainer/mmdc_full_quick.yaml
+++ b/configs/trainer/mmdc_full_quick.yaml
@@ -4,7 +4,7 @@ accelerator: gpu
 devices: 1
 # gradient_clip_val: 0.5
 min_epochs: 0
-max_epochs: 10
+max_epochs: 200
 
 precision: bf16
 
diff --git a/jobs/mmdc_full_moe_no_mask.slurm b/jobs/mmdc_full_moe_no_mask.slurm
index c2dffd28c7c24e9af51488cd713139cb8ece2d78..92f3866b1eb287129609b47fecdcba7301b097e2 100644
--- a/jobs/mmdc_full_moe_no_mask.slurm
+++ b/jobs/mmdc_full_moe_no_mask.slurm
@@ -3,9 +3,11 @@
 #SBATCH --output=outputfile-%j.out
 #SBATCH --error=errorfile-%j.err
 #SBATCH -N 1                       # number of nodes ( or --nodes=1)
-#SBATCH --ntasks-per-node=8                       # number of tasks ( or --tesks=8)
+#SBATCH --ntasks-per-node=16       # number of tasks ( or --ntasks=16)
 #SBATCH --gres=gpu:1               # number of gpus
-#SBATCH --time=12:00:00            # Walltime
+#SBATCH --partition=gpu_a100       # partition
+#SBATCH --qos=gpu_all              # QoS
+#SBATCH --time=24:00:00            # Walltime
 #SBATCH --mem-per-cpu=12G          # memory per core
 #SBATCH --account=cesbio           # MANDATORY : account (launch myaccounts to list your accounts)
 #SBATCH --export=none              #  to start the job with a clean environnement and source of ~/.bashrc
diff --git a/jobs/mmdc_full_poe_no_mask.slurm b/jobs/mmdc_full_poe_no_mask.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..5144ef55cd5ae0bd4b3dc5c3fe8feea5273f79c1
--- /dev/null
+++ b/jobs/mmdc_full_poe_no_mask.slurm
@@ -0,0 +1,26 @@
+#!/bin/bash
+#SBATCH --job-name=PoE_no_mask
+#SBATCH --output=outputfile-%j.out
+#SBATCH --error=errorfile-%j.err
+#SBATCH -N 1                       # number of nodes ( or --nodes=1)
+#SBATCH --ntasks-per-node=16       # number of tasks ( or --ntasks=16)
+#SBATCH --gres=gpu:1               # number of gpus
+#SBATCH --partition=gpu_a100       # partition
+#SBATCH --qos=gpu_all              # QoS
+#SBATCH --time=24:00:00            # Walltime
+#SBATCH --mem-per-cpu=12G          # memory per core
+#SBATCH --account=cesbio           # MANDATORY : account (launch myaccounts to list your accounts)
+#SBATCH --export=none              # to start the job with a clean environment and source ~/.bashrc
+
+# be sure no modules loaded
+module purge
+
+export SCRATCH=/work/scratch/data/${USER}
+export SRCDIR=${SCRATCH}/src/MMDC/mmdc-singledate
+export MMDC_INIT=${SRCDIR}/mmdc_init.sh
+export WORKING_DIR=${SCRATCH}/MMDC/jobs
+
+cd ${WORKING_DIR}
+source ${MMDC_INIT}
+
+HYDRA_FULL_ERROR=1 python ${SRCDIR}/train.py experiment=mmdc_full_poe_no_mask >> output_$SLURM_JOBID.log
diff --git a/setup.cfg b/setup.cfg
index e4e529306cc9163b984a89cc8b7b2f071a6a37bb..aacbd28ab404bc5f1ebcba5ac25399d7f00aaec9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -85,7 +85,7 @@ testing =
 # CAUTION: --cov flags may prohibit setting breakpoints while debugging.
 #          Comment those flags to avoid this pytest issue.
 addopts =
-    --typeguard-packages=mtan_s1s2_classif
+#    --typeguard-packages=mtan_s1s2_classif
     --cov mmdc_singledate --cov-report term-missing
     --verbose
 norecursedirs =
diff --git a/src/mmdc_singledate/models/components/building_components.py b/src/mmdc_singledate/models/components/building_components.py
index 53ebedfc86bee44adfab63e312bf122d89b1e55f..ba4e423ed8a2d1e3aecc1f10bce2d23b6d21357c 100644
--- a/src/mmdc_singledate/models/components/building_components.py
+++ b/src/mmdc_singledate/models/components/building_components.py
@@ -391,8 +391,12 @@ class Unet(nn.Module):
         """
         enc_ftrs = self.encoder(data)
         out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
-        mu_out = torch.tanh(self.mu_tail(F.dropout(out, 0.2)))
-        logvar_out = torch.tanh(self.logvar_tail(F.dropout(out, 0.2)))
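+        # F.dropout is functional and applies dropout unconditionally by default
+        # (training=True); pass training=self.training so it is disabled in eval mode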
+        mu_out = torch.tanh(self.mu_tail(F.dropout(out, 0.2, training=self.training)))
+        logvar_out = torch.tanh(
+            self.logvar_tail(F.dropout(out, 0.2, training=self.training))
+        )
 
         return VAELatentSpace(mu_out, logvar_out)
 
diff --git a/src/mmdc_singledate/models/datatypes.py b/src/mmdc_singledate/models/datatypes.py
index 5cfb625c9c0b22eaa9abfddab72e514602bb30ab..24995e491626a629d2d8df783511393526d8b689 100644
--- a/src/mmdc_singledate/models/datatypes.py
+++ b/src/mmdc_singledate/models/datatypes.py
@@ -259,11 +259,11 @@ class AuxData:
 class S1S2VAEAuxiliaryEmbeddings:
     """Class to hold embeddings of the auxiliary data"""
 
-    s1_angles_emb: torch.Tensor
-    s2_angles_emb: torch.Tensor
-    s1_dem_emb: torch.Tensor
-    s2_dem_emb: torch.Tensor
-    meteo_emb: torch.Tensor
+    s1_angles_emb: torch.Tensor | None
+    s2_angles_emb: torch.Tensor | None
+    s1_dem_emb: torch.Tensor | None
+    s2_dem_emb: torch.Tensor | None
+    meteo_emb: torch.Tensor | None
 
 
 @dataclass
diff --git a/src/mmdc_singledate/models/lightning/full.py b/src/mmdc_singledate/models/lightning/full.py
index b3352bf69d526866851243113605f84040d89d35..a7341fbbd4b9acbe53334ede64f9f66ddf727071 100644
--- a/src/mmdc_singledate/models/lightning/full.py
+++ b/src/mmdc_singledate/models/lightning/full.py
@@ -158,7 +158,6 @@ class MMDCFullLitModule(MMDCBaseLitModule):  # pylint: disable=too-many-ancestor
         self, batch: Any
     ) -> tuple[MMDCFullDataForLoss, MMDCFullDataForLoss, S1S2VAELosses]:
         """Helper method for computing losses"""
-        batch = destructure_batch(batch)
 
         forward_batch = batch.copy()
         if self.masking is not None:
@@ -274,6 +273,8 @@
         masking_loss: bool = False,
     ) -> S1S2VAELosses | tuple[S1S2VAELosses, S1S2VAEMaskingLosses]:
         """Optimization step"""
+        batch = destructure_batch(batch)
 
         s1_data_for_loss, s2_data_for_loss, vae_losses = self.compute_losses(batch)
 
diff --git a/src/mmdc_singledate/models/torch/full.py b/src/mmdc_singledate/models/torch/full.py
index 938fdeae5db3d4ea528a38d54b17b26cbb68d4d0..da22684b126621dd4d29a0fff2a0af1beadb5d78 100644
--- a/src/mmdc_singledate/models/torch/full.py
+++ b/src/mmdc_singledate/models/torch/full.py
@@ -205,7 +205,7 @@ class MMDCFullModule(MMDCBaseModule):
         for aux_tensor, use_it in aux_data:
             if use_it:
                 res = torch.cat(
-                    [res, F.dropout(aux_tensor, 0.2)],
+                    [res, F.dropout(aux_tensor, 0.2, training=self.training)],
                     dim=1,
                 )
         return res