diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml
index a2c4d6137f83fdb1b9d57a88d66a191aecf503c5..922a6a34164749365183cbabe91d52bb2ea72047 100644
--- a/configs/callbacks/default.yaml
+++ b/configs/callbacks/default.yaml
@@ -17,12 +17,12 @@ model_checkpoint_every_10:
   dirpath: ${output_dir}/checkpoints/${run_id}
   filename: "epoch_{epoch:03d}"

-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 10 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 10 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

 model_summary:
   _target_: pytorch_lightning.callbacks.RichModelSummary

diff --git a/configs/callbacks/mmdc_full_callback.yaml b/configs/callbacks/mmdc_full_callback.yaml
index cc8921c2dd2b39ce91ad481952727af90bcef952..60019d3c1ba48e2607d3397cc06c4125262554ba 100644
--- a/configs/callbacks/mmdc_full_callback.yaml
+++ b/configs/callbacks/mmdc_full_callback.yaml
@@ -36,13 +36,13 @@ stochastic_weight_averaging:
   annealing_epochs: 5
   annealing_strategy: 'cos'

-
-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 100 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 100 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

 lr_monitor:
   _target_: pytorch_lightning.callbacks.LearningRateMonitor
@@ -54,10 +54,6 @@ model_summary:
   _target_: pytorch_lightning.callbacks.RichModelSummary

 rich_progress_bar:
   _target_: pytorch_lightning.callbacks.RichProgressBar

-timer:
-  _target_: pytorch_lightning.callbacks.Timer
-  duration: 00:72:00:00
-
 evolution_plot:
   _target_: mmdc_singledate.callbacks.eval_callbacks.PlotEvolutionCallback
   save_dir: ${hydra:run.dir}
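Across the callback configs in this patch, early stopping is disabled by commenting the block out rather than deleting it, so the settings stay documented in place. For reference, the block maps one-to-one onto PyTorch Lightning's EarlyStopping callback; a minimal sketch of the equivalent direct instantiation outside Hydra, with the values taken from the commented block in default.yaml:

    from pytorch_lightning.callbacks import EarlyStopping

    # Direct equivalent of the commented-out Hydra block: stop when
    # "val/loss" has not decreased for `patience` validation epochs.
    early_stopping = EarlyStopping(
        monitor="val/loss",  # logged metric that defines "improving"
        mode="min",          # lower val/loss is better
        patience=10,         # validation epochs without improvement before stopping
        min_delta=0.0,       # minimum decrease that counts as an improvement
    )
    # trainer = pytorch_lightning.Trainer(callbacks=[early_stopping, ...])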
diff --git a/configs/callbacks/mmdc_full_callback_freeze.yaml b/configs/callbacks/mmdc_full_callback_freeze.yaml
index f759f7f90b0bd1ed5f8fae2abe2bd905f9b301ae..2ae4fb136bf50aca9e69b6fe238b27056dbc6881 100644
--- a/configs/callbacks/mmdc_full_callback_freeze.yaml
+++ b/configs/callbacks/mmdc_full_callback_freeze.yaml
@@ -37,12 +37,12 @@ stochastic_weight_averaging:
   annealing_strategy: 'cos'

-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 100 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 100 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

 lr_monitor:
   _target_: pytorch_lightning.callbacks.LearningRateMonitor
@@ -54,10 +54,6 @@ model_summary:

 rich_progress_bar:
   _target_: pytorch_lightning.callbacks.RichProgressBar

-timer:
-  _target_: pytorch_lightning.callbacks.Timer
-  duration: 00:72:00:00
-
 evolution_plot:
   _target_: mmdc_singledate.callbacks.eval_callbacks.PlotEvolutionCallback
   save_dir: ${hydra:run.dir}

diff --git a/configs/callbacks/mmdc_full_experts_callback.yaml b/configs/callbacks/mmdc_full_experts_callback.yaml
index 4e9d61951322e703501415f48630361785a52d84..f99965e132bc3bb7fbf3526f90e628a200c7dacc 100644
--- a/configs/callbacks/mmdc_full_experts_callback.yaml
+++ b/configs/callbacks/mmdc_full_experts_callback.yaml
@@ -37,12 +37,12 @@ stochastic_weight_averaging:
   annealing_strategy: 'cos'

-early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
-  monitor: "val/loss" # name of the logged metric which determines when model is improving
-  mode: "min" # "max" means higher metric value is better, can be also "min"
-  patience: 100 # how many validation epochs of not improving until training stops
-  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+#early_stopping:
+#  _target_: pytorch_lightning.callbacks.EarlyStopping
+#  monitor: "val/loss" # name of the logged metric which determines when model is improving
+#  mode: "min" # "max" means higher metric value is better, can be also "min"
+#  patience: 100 # how many validation epochs of not improving until training stops
+#  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

 lr_monitor:
   _target_: pytorch_lightning.callbacks.LearningRateMonitor
@@ -54,10 +54,6 @@ model_summary:

 rich_progress_bar:
   _target_: pytorch_lightning.callbacks.RichProgressBar

-timer:
-  _target_: pytorch_lightning.callbacks.Timer
-  duration: 00:72:00:00
-
 evolution_plot:
   _target_: mmdc_singledate.callbacks.eval_callbacks.PlotEvolutionCallback
   save_dir: ${hydra:run.dir}

diff --git a/configs/experiment/mmdc_full_poe_no_mask.yaml b/configs/experiment/mmdc_full_poe_no_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b85964735ec7fa0ee42a89605e7ff13711eac4e2
--- /dev/null
+++ b/configs/experiment/mmdc_full_poe_no_mask.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=mmdc_full_poe_no_mask
+
+defaults:
+  - override /datamodule: mmdc_datamodule.yaml
+  - override /model: mmdc_full_poe_no_mask.yaml
+  - override /callbacks: mmdc_full_experts_callback.yaml
+  - override /trainer: mmdc_full_trainer.yaml
+
+# all parameters below will be merged with parameters from the default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["mmdc_poe_no_mask"]
+
+
+name: "mmdc_poe_no_mask"
diff --git a/configs/model/mmdc_full.yaml b/configs/model/mmdc_full.yaml
index 3da4b2b2b5421e84abba2a3e7b8da466c64eb304..54469f70cceaeae894d38e48449df1dd21337113 100644
--- a/configs/model/mmdc_full.yaml
+++ b/configs/model/mmdc_full.yaml
@@ -1,7 +1,7 @@
 _target_: mmdc_singledate.models.lightning.full.MMDCFullLitModule
 lr: 0.0001
-resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-01-25_08-25-45/last.ckpt
+resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-29_14-05-22/last.ckpt

 #load_weights_from_checkpoint:
 #  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-01-23_11-39-06/last.ckpt
@@ -53,13 +53,13 @@ model:
     _target_: mmdc_singledate.models.datatypes.MultiVAEEncDecConfig
     s1_encoder:
       _target_: mmdc_singledate.models.datatypes.UnetParams
-      out_channels: 6
+      out_channels: 12
       encoder_sizes: [64, 128, 256, 512]
       kernel_size: 3
       tail_layers: 3
     s2_encoder:
       _target_: mmdc_singledate.models.datatypes.UnetParams
-      out_channels: 6
+      out_channels: 12
       encoder_sizes: [64, 128, 256, 512]
       kernel_size: 3
       tail_layers: 3

diff --git a/configs/model/mmdc_full_mask.yaml b/configs/model/mmdc_full_mask.yaml
index d3b02e016c522ee58f00ce63df9197ef16cf9c09..6148b5bad8249a0c12296e7dde42a3ae8065c4b4 100644
--- a/configs/model/mmdc_full_mask.yaml
+++ b/configs/model/mmdc_full_mask.yaml
@@ -1,18 +1,20 @@
 _target_: mmdc_singledate.models.lightning.full.MMDCFullLitModule
 lr: 0.0001
-resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_mask/2024-01-26_16-15-20/last.ckpt
-#load_weights_from_checkpoint:
-#  path: /work/scratch/data/kalinie/MMDC/good_checkpoints/pretrained_baseline_50epochs/epoch_050.ckpt
-#  selected_layers: all
-load_weights_from_checkpoint:
-  path: null
+resume_from_checkpoint: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_mask/2024-03-08_11-09-43/last.ckpt
+# load_weights_from_checkpoint:
+#   path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-27_14-47-18/last.ckpt
+#   selected_layers: "all"
+#   freeze_epochs: null

 masking:
   _target_: mmdc_singledate.models.components.masking.DataMasking
-  epochs_thr_list: [25, 50, 75]
-  masked_perc_list: [0, 50, 100]
-  strategy_list: ["random", "one_satellite", "one_satellite"]
+  # epochs_thr_list: [25, 50, 75]
+  # masked_perc_list: [0, 50, 100]
+  epochs_thr_list: [0, 25, 50]
+  masked_perc_list: [0, 25, 50]
+
+  strategy_list: ["random", "random", "random"]

 model:
   _target_: mmdc_singledate.models.torch.full.MMDCFullModule

diff --git a/configs/model/mmdc_full_moe_no_mask.yaml b/configs/model/mmdc_full_moe_no_mask.yaml
index fc3a13242bdf8afe009e2b7190a290e3a7b9451d..2208ff1a7fb4da81831a4fd45b81bd7118eeff9b 100644
--- a/configs/model/mmdc_full_moe_no_mask.yaml
+++ b/configs/model/mmdc_full_moe_no_mask.yaml
@@ -2,6 +2,10 @@
 _target_: mmdc_singledate.models.lightning.full_experts.MMDCFullExpertsLitModule
 lr: 0.0001
 resume_from_checkpoint: null
+load_weights_from_checkpoint:
+  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-27_14-47-18/last.ckpt
+  selected_layers: "all"
+  freeze_epochs: null

 masking: null

@@ -98,17 +102,9 @@ model:
       dem: False
       meteo: False
     loss_weights:
-      _target_: mmdc_singledate.models.datatypes.MultiVAELossWeights
-      sen1:
-        _target_: mmdc_singledate.models.datatypes.TranslationLossWeights
-        nll: 1.0
-      sen2:
-        _target_: mmdc_singledate.models.datatypes.TranslationLossWeights
-        nll: 1.0
       forward: 1.0
-      cross: 1.0
       latent:
         _target_: mmdc_singledate.models.datatypes.LatentLossWeights
         latent: 1.0
-        latent_cov: 0.1
+        latent_cov: 0
         latent_max_var: False
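Several model configs in this patch now carry a load_weights_from_checkpoint block with path, selected_layers and freeze_epochs keys, used to warm-start the experts models from an earlier full-model run. The actual loader lives in the repo's Lightning modules; purely as an illustration of what selected-layer loading of this kind typically amounts to, here is a hedged sketch in which the function name and the prefix-filtering logic are assumptions — only the three config keys come from the diff:

    import torch
    from torch import nn

    def load_weights_from_checkpoint(
        model: nn.Module, path: str, selected_layers: str = "all"
    ) -> None:
        """Hypothetical sketch of partial weight loading.

        `path` and `selected_layers` mirror the new config keys; the
        prefix filtering below is an assumption, not this repo's code.
        """
        state = torch.load(path, map_location="cpu")["state_dict"]
        if selected_layers != "all":
            # keep only parameters whose name starts with the given prefix
            state = {k: v for k, v in state.items() if k.startswith(selected_layers)}
        # strict=False: tolerate heads present in one model but not the other
        model.load_state_dict(state, strict=False)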
diff --git a/configs/model/mmdc_full_poe_mask.yaml b/configs/model/mmdc_full_poe_mask.yaml
index d7fe3d0641963a39197671c4efa8d0fba65459d9..0cee823846bd2ca03b3a12184c50ce896f65c612 100644
--- a/configs/model/mmdc_full_poe_mask.yaml
+++ b/configs/model/mmdc_full_poe_mask.yaml
@@ -2,7 +2,10 @@
 _target_: mmdc_singledate.models.lightning.full_experts.MMDCFullExpertsLitModule
 lr: 0.0001
 resume_from_checkpoint: null
-#load_weights_from_checkpoint: /work/scratch/data/kalinie/MMDC/good_checkpoints/pretrained_no_latent/last.ckpt
+load_weights_from_checkpoint:
+  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_mask/2024-02-27_14-47-18/last.ckpt
+  selected_layers: "all"
+  freeze_epochs: null

 masking:

diff --git a/configs/model/mmdc_full_poe_no_mask.yaml b/configs/model/mmdc_full_poe_no_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e8894772a5416872bc95226d53f77cebdd2463b3
--- /dev/null
+++ b/configs/model/mmdc_full_poe_no_mask.yaml
@@ -0,0 +1,111 @@
+
+_target_: mmdc_singledate.models.lightning.full_experts.MMDCFullExpertsLitModule
+lr: 0.0001
+resume_from_checkpoint: null
+load_weights_from_checkpoint:
+  path: /work/scratch/data/kalinie/MMDC/results/latent/checkpoints/mmdc_full/2024-02-27_14-47-18/last.ckpt
+  selected_layers: "all"
+  freeze_epochs: null
+
+
+masking: null
+
+model:
+  _target_: mmdc_singledate.models.torch.full_experts.MMDCFullExpertsModule
+  config:
+    _target_: mmdc_singledate.models.datatypes.MultiVAEConfig
+    experts_strategy: "PoE" # Experts to choose from "average", "MoE", "PoE"
+    data_sizes:
+      _target_: mmdc_singledate.models.datatypes.MMDCDataChannels
+      sen1: 6
+      s1_angles: 2
+      sen2: 10
+      s2_angles: 6
+      dem: 4
+      meteo: 48
+    embeddings:
+      _target_: mmdc_singledate.models.datatypes.ModularEmbeddingConfig
+      s1_angles:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 3
+        sizes: [16, 8]
+        kernel_sizes: [1, 1, 1]
+      s1_dem:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 4
+        encoder_sizes: [32, 64, 128]
+        kernel_size: 3
+        tail_layers: 3
+      s2_angles:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 3
+        sizes: [16, 8]
+        kernel_sizes: [1, 1, 1]
+      s2_dem:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 4
+        encoder_sizes: [32, 64, 128]
+        kernel_size: 3
+        tail_layers: 3
+      meteo:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 3
+        sizes: [32, 32, 16]
+        kernel_sizes: [1, 1, 1, 1]
+    auto_enc:
+      _target_: mmdc_singledate.models.datatypes.MultiVAEEncDecConfig
+      s1_encoder:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 6
+        encoder_sizes: [64, 128, 256, 512]
+        kernel_size: 3
+        tail_layers: 3
+      s2_encoder:
+        _target_: mmdc_singledate.models.datatypes.UnetParams
+        out_channels: 6
+        encoder_sizes: [64, 128, 256, 512]
+        kernel_size: 3
+        tail_layers: 3
+      s1_decoder:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 0
+        sizes: [64, 32, 16]
+        kernel_sizes: [3, 3, 3, 3]
+      s2_decoder:
+        _target_: mmdc_singledate.models.datatypes.ConvnetParams
+        out_channels: 0
+        sizes: [64, 32, 16]
+        kernel_sizes: [3, 3, 3, 3]
+    ae_use:
+      _target_: mmdc_singledate.models.datatypes.MultiVAEAuxUseConfig
+      s1_enc_use:
+        _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+        s1_angles: True
+        s2_angles: True
+        dem: True
+        meteo: True
+      s2_enc_use:
+        _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+        s1_angles: True
+        s2_angles: True
+        dem: True
+        meteo: True
+      s1_dec_use:
+        _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+        s1_angles: True
+        s2_angles: True
+        dem: False
+        meteo: False
+      s2_dec_use:
+        _target_: mmdc_singledate.models.datatypes.MMDCDataUse
+        s1_angles: True
+        s2_angles: True
+        dem: False
+        meteo: False
+    loss_weights:
+      forward: 1.0
+      latent:
+        _target_: mmdc_singledate.models.datatypes.LatentLossWeights
+        latent: 1.0
+        latent_cov: 0
+        latent_max_var: False

diff --git a/configs/train.yaml b/configs/train.yaml
index c33b4f50569eb0dd908a38c77476b9750b9b2e77..73ae5d7992d8fbd223716e054db70ae2d3f12c3b 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -3,7 +3,7 @@
 # specify here default training configuration
 defaults:
   - _self_
-  - datamodule: mmdc_full_datamodule.yaml
+  - datamodule: mmdc_datamodule.yaml
   - model: mmdc_full.yaml
   - callbacks: default.yaml
   - logger: default.yaml

diff --git a/configs/trainer/mmdc_full_quick.yaml b/configs/trainer/mmdc_full_quick.yaml
index 09f391ed53dfaaeeaa4050e4e3d3a790fee8bda0..497a592d8cbc77368e268cefb7f9f96da6afa033 100644
--- a/configs/trainer/mmdc_full_quick.yaml
+++ b/configs/trainer/mmdc_full_quick.yaml
@@ -4,7 +4,7 @@ accelerator: gpu
 devices: 1
 # gradient_clip_val: 0.5
 min_epochs: 0
-max_epochs: 10
+max_epochs: 200
 precision: bf16

diff --git a/jobs/mmdc_full_moe_no_mask.slurm b/jobs/mmdc_full_moe_no_mask.slurm
index c2dffd28c7c24e9af51488cd713139cb8ece2d78..92f3866b1eb287129609b47fecdcba7301b097e2 100644
--- a/jobs/mmdc_full_moe_no_mask.slurm
+++ b/jobs/mmdc_full_moe_no_mask.slurm
@@ -3,9 +3,11 @@
 #SBATCH --output=outputfile-%j.out
 #SBATCH --error=errorfile-%j.err
 #SBATCH -N 1                  # number of nodes ( or --nodes=1)
-#SBATCH --ntasks-per-node=8   # number of tasks ( or --tesks=8)
+#SBATCH --ntasks-per-node=16  # number of tasks ( or --ntasks=16)
 #SBATCH --gres=gpu:1          # number of gpus
-#SBATCH --time=12:00:00       # Walltime
+#SBATCH --partition=gpu_a100  # partition
+#SBATCH --qos=gpu_all         # QoS
+#SBATCH --time=24:00:00       # Walltime
 #SBATCH --mem-per-cpu=12G     # memory per core
 #SBATCH --account=cesbio      # MANDATORY : account (launch myaccounts to list your accounts)
 #SBATCH --export=none         # to start the job with a clean environnement and source of ~/.bashrc

diff --git a/jobs/mmdc_full_poe_no_mask.slurm b/jobs/mmdc_full_poe_no_mask.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..5144ef55cd5ae0bd4b3dc5c3fe8feea5273f79c1
--- /dev/null
+++ b/jobs/mmdc_full_poe_no_mask.slurm
@@ -0,0 +1,26 @@
+#!/bin/bash
+#SBATCH --job-name=PoE_no_mask
+#SBATCH --output=outputfile-%j.out
+#SBATCH --error=errorfile-%j.err
+#SBATCH -N 1                  # number of nodes ( or --nodes=1)
+#SBATCH --ntasks-per-node=16  # number of tasks ( or --ntasks=16)
+#SBATCH --gres=gpu:1          # number of gpus
+#SBATCH --partition=gpu_a100  # partition
+#SBATCH --qos=gpu_all         # QoS
+#SBATCH --time=24:00:00       # Walltime
+#SBATCH --mem-per-cpu=12G     # memory per core
+#SBATCH --account=cesbio      # MANDATORY: account (run myaccounts to list your accounts)
+#SBATCH --export=none         # start the job with a clean environment and source ~/.bashrc
+
+# be sure no modules are loaded
+module purge
+
+export SCRATCH=/work/scratch/data/${USER}
+export SRCDIR=${SCRATCH}/src/MMDC/mmdc-singledate
+export MMDC_INIT=${SRCDIR}/mmdc_init.sh
+export WORKING_DIR=${SCRATCH}/MMDC/jobs
+
+cd ${WORKING_DIR}
+source ${MMDC_INIT}
+
+HYDRA_FULL_ERROR=1 python ${SRCDIR}/train.py experiment=mmdc_full_poe_no_mask >> output_$SLURM_JOBID.log
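The new model config above sets experts_strategy: "PoE", choosing product-of-experts fusion over "average" and "MoE" for combining the per-modality latent posteriors. Under the standard PoE formulation for diagonal Gaussians, multiplying the expert densities amounts to summing their precisions; a minimal sketch of that textbook computation — a sketch under this standard assumption, not the code in mmdc_singledate.models.torch.full_experts:

    import torch

    def poe_fusion(
        mus: list[torch.Tensor], logvars: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Product of diagonal-Gaussian experts: precisions add up."""
        precisions = [torch.exp(-lv) for lv in logvars]  # 1 / sigma^2 per expert
        precision = torch.stack(precisions).sum(dim=0)   # combined precision
        # precision-weighted mean of the expert means
        mu = torch.stack([p * m for p, m in zip(precisions, mus)]).sum(dim=0) / precision
        logvar = -torch.log(precision)                   # back to log-variance
        return mu, logvar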
diff --git a/setup.cfg b/setup.cfg
index e4e529306cc9163b984a89cc8b7b2f071a6a37bb..aacbd28ab404bc5f1ebcba5ac25399d7f00aaec9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -85,7 +85,7 @@ testing =
 # CAUTION: --cov flags may prohibit setting breakpoints while debugging.
 # Comment those flags to avoid this pytest issue.
 addopts =
-    --typeguard-packages=mtan_s1s2_classif
+#    --typeguard-packages=mtan_s1s2_classif
     --cov mmdc_singledate --cov-report term-missing
     --verbose
 norecursedirs =

diff --git a/src/mmdc_singledate/models/components/building_components.py b/src/mmdc_singledate/models/components/building_components.py
index 53ebedfc86bee44adfab63e312bf122d89b1e55f..ba4e423ed8a2d1e3aecc1f10bce2d23b6d21357c 100644
--- a/src/mmdc_singledate/models/components/building_components.py
+++ b/src/mmdc_singledate/models/components/building_components.py
@@ -391,8 +391,10 @@ class Unet(nn.Module):
         """
         enc_ftrs = self.encoder(data)
         out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
-        mu_out = torch.tanh(self.mu_tail(F.dropout(out, 0.2)))
-        logvar_out = torch.tanh(self.logvar_tail(F.dropout(out, 0.2)))
+        mu_out = torch.tanh(self.mu_tail(F.dropout(out, 0.2, training=self.training)))
+        logvar_out = torch.tanh(
+            self.logvar_tail(F.dropout(out, 0.2, training=self.training))
+        )
         return VAELatentSpace(mu_out, logvar_out)

diff --git a/src/mmdc_singledate/models/datatypes.py b/src/mmdc_singledate/models/datatypes.py
index 5cfb625c9c0b22eaa9abfddab72e514602bb30ab..24995e491626a629d2d8df783511393526d8b689 100644
--- a/src/mmdc_singledate/models/datatypes.py
+++ b/src/mmdc_singledate/models/datatypes.py
@@ -259,11 +259,11 @@ class AuxData:
 class S1S2VAEAuxiliaryEmbeddings:
     """Class to hold embeddings of the auxiliary data"""

-    s1_angles_emb: torch.Tensor
-    s2_angles_emb: torch.Tensor
-    s1_dem_emb: torch.Tensor
-    s2_dem_emb: torch.Tensor
-    meteo_emb: torch.Tensor
+    s1_angles_emb: torch.Tensor | None
+    s2_angles_emb: torch.Tensor | None
+    s1_dem_emb: torch.Tensor | None
+    s2_dem_emb: torch.Tensor | None
+    meteo_emb: torch.Tensor | None


 @dataclass

diff --git a/src/mmdc_singledate/models/lightning/full.py b/src/mmdc_singledate/models/lightning/full.py
index b3352bf69d526866851243113605f84040d89d35..a7341fbbd4b9acbe53334ede64f9f66ddf727071 100644
--- a/src/mmdc_singledate/models/lightning/full.py
+++ b/src/mmdc_singledate/models/lightning/full.py
@@ -158,7 +158,6 @@ class MMDCFullLitModule(MMDCBaseLitModule):  # pylint: disable=too-many-ancestors
         self, batch: Any
     ) -> tuple[MMDCFullDataForLoss, MMDCFullDataForLoss, S1S2VAELosses]:
         """Helper method for computing losses"""
-        batch = destructure_batch(batch)
         forward_batch = batch.copy()

         if self.masking is not None:
@@ -274,6 +273,7 @@ class MMDCFullLitModule(MMDCBaseLitModule):  # pylint: disable=too-many-ancestors
         masking_loss: bool = False,
     ) -> S1S2VAELosses | tuple[S1S2VAELosses, S1S2VAEMaskingLosses]:
         """Optimization step"""
+        batch = destructure_batch(batch)
         s1_data_for_loss, s2_data_for_loss, vae_losses = self.compute_losses(batch)

diff --git a/src/mmdc_singledate/models/torch/full.py b/src/mmdc_singledate/models/torch/full.py
index 938fdeae5db3d4ea528a38d54b17b26cbb68d4d0..da22684b126621dd4d29a0fff2a0af1beadb5d78 100644
--- a/src/mmdc_singledate/models/torch/full.py
+++ b/src/mmdc_singledate/models/torch/full.py
@@ -205,7 +205,7 @@ class MMDCFullModule(MMDCBaseModule):
         for aux_tensor, use_it in aux_data:
             if use_it:
                 res = torch.cat(
-                    [res, F.dropout(aux_tensor, 0.2)],
+                    [res, F.dropout(aux_tensor, 0.2, training=self.training)],
                     dim=1,
                 )
         return res
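The three dropout changes above (two in building_components.py, one in torch/full.py) fix the same pitfall: the functional F.dropout, unlike the nn.Dropout module, defaults to training=True, so without an explicit flag it keeps dropping activations during validation and inference. A minimal self-contained demonstration — the Demo module and tensor shapes are illustrative only:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class Demo(nn.Module):
        def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            # Bug: F.dropout defaults to training=True, so this still drops
            # (and rescales) values even after model.eval().
            bad = F.dropout(x, 0.2)
            # Fix applied in this diff: follow the module's train/eval mode.
            good = F.dropout(x, 0.2, training=self.training)
            return bad, good

    model = Demo().eval()
    x = torch.ones(2, 8)
    bad, good = model(x)
    print(bad)   # entries may be zeroed, the rest scaled by 1/0.8, despite eval()
    print(good)  # identical to x: dropout is a no-op in eval mode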