From 4cdc39392ebbb565b138051f1453d354c408eecc Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Feb 2023 16:14:12 +0000
Subject: [PATCH 01/81] add README.rst for entry points setup

---
 README.rst                       | 2 ++
 setup.cfg                        | 6 ++++++
 src/bin/generate_mmdc_dataset.py | 6 +++++-
 3 files changed, 13 insertions(+), 1 deletion(-)
 create mode 100644 README.rst

diff --git a/README.rst b/README.rst
new file mode 100644
index 00000000..5bc06818
--- /dev/null
+++ b/README.rst
@@ -0,0 +1,2 @@
+MMDC-SingleDate
+===============
diff --git a/setup.cfg b/setup.cfg
index 9c5469ae..08f3a061 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -68,6 +68,12 @@ testing =
     pytest-cov
 
 [options.entry_points]
+console_scripts =
+     mmdc_data = bin.generate_mmdc_dataset:main
+     mmdc_visualize = bin.visualize_mmdc_ds:main
+     mmdc_split_dataset = bin.split_mmdc_dataset:main
+
+# mmdc_inference = bin.infer_mmdc_ds:main
 # Add here console scripts like:
 # console_scripts =
 #     script_name = mmdc-singledate.module:function
diff --git a/src/bin/generate_mmdc_dataset.py b/src/bin/generate_mmdc_dataset.py
index d801c566..ada7c11b 100644
--- a/src/bin/generate_mmdc_dataset.py
+++ b/src/bin/generate_mmdc_dataset.py
@@ -81,7 +81,7 @@ def get_parser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-if __name__ == "__main__":
+def main():
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -124,3 +124,7 @@ if __name__ == "__main__":
             args.threshold,
         ),
     )
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab
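
A note on the pattern introduced above: setuptools console_scripts need a
zero-argument callable, which is why argument parsing moves from the
"if __name__ == '__main__':" block into main(). A minimal sketch of the shape
each script in this series converges to (illustrative names, not repository
code):

    import argparse


    def get_parser() -> argparse.ArgumentParser:
        """Build the argument parser (one illustrative option)."""
        parser = argparse.ArgumentParser(description="example entry point")
        parser.add_argument("--threshold", type=float, default=0.5)
        return parser


    def main() -> None:
        # console_scripts invoke main() with no arguments,
        # so argv parsing happens inside
        args = get_parser().parse_args()
        print(args.threshold)


    if __name__ == "__main__":
        main()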


From 6a38a318699ec4d871d7d9f98e38fe700401d138 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 10:08:36 +0000
Subject: [PATCH 02/81] solve conflicts

---
 src/bin/split_mmdc_dataset.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/bin/split_mmdc_dataset.py b/src/bin/split_mmdc_dataset.py
index de429bf5..b205a249 100644
--- a/src/bin/split_mmdc_dataset.py
+++ b/src/bin/split_mmdc_dataset.py
@@ -141,16 +141,10 @@ def split_tile_rois(
             roi_intersect.writelines(sorted([f"{roi}\n" for roi in roi_list]))
 
 
-def main(
-    input_dir: Path, input_file: Path, test_percentage: int, random_state: int
-) -> None:
+def main():
     """
     Split the tiles/rois between train/val/test
     """
-    split_tile_rois(input_dir, input_file, test_percentage, random_state)
-
-
-if __name__ == "__main__":
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -171,10 +165,14 @@ if __name__ == "__main__":
     logging.info("test percentage selected : %s", args.test_percentage)
     logging.info("random state value :  %s", args.random_state)
 
-    # Go to main
-    main(
+    # Go to entry point
+    split_tile_rois(
         Path(args.tensors_dir),
         Path(args.roi_intersections_file),
         args.test_percentage,
         args.random_state,
     )
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab


From 52f90b6df8091f9db3ba1021e3a906a6706b8ed2 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 10:10:15 +0000
Subject: [PATCH 03/81] ignore iota2 copy

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 2015cd5d..676c9c84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
+iota2_thirdparties/
-- 
GitLab


From f5ee64c78921dec37fe73524f1ae6e7f0d949212 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Feb 2023 16:16:00 +0000
Subject: [PATCH 04/81] update for entry point

---
 src/bin/visualize_mmdc_ds.py | 55 +++++++++++++++---------------------
 1 file changed, 22 insertions(+), 33 deletions(-)

diff --git a/src/bin/visualize_mmdc_ds.py b/src/bin/visualize_mmdc_ds.py
index 01da86c6..a69bc924 100644
--- a/src/bin/visualize_mmdc_ds.py
+++ b/src/bin/visualize_mmdc_ds.py
@@ -71,34 +71,10 @@ def get_parser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-def main(
-    input_directory: Path,
-    input_tiles: Path,
-    patch_index: list[int] | None,
-    nb_patches: int | None,
-    export_path: Path,
-) -> None:
+def main():
     """
     Entry point for the visualization
     """
-    try:
-        export_visualization_graphs(
-            input_directory=input_directory,
-            input_tiles=input_tiles,
-            patch_index=patch_index,
-            nb_patches=nb_patches,
-            export_path=export_path,
-        )
-
-    except FileNotFoundError:
-        if not input_directory.exists():
-            print(f"Folder {input_directory} does not exist!!")
-
-        if not export_path.exists():
-            print(f"Folder {export_path} does not exist!!")
-
-
-if __name__ == "__main__":
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -126,13 +102,26 @@ if __name__ == "__main__":
     logging.info(" index patches : %s", args.patch_index)
     logging.info(" output directory : %s", args.export_path)
 
-    # Go to main
-    main(
-        input_directory=Path(args.input_path),
-        input_tiles=Path(args.input_tiles),
-        patch_index=args.patch_index,
-        nb_patches=args.nb_patches,
-        export_path=Path(args.export_path),
-    )
+    # Go to entry point
+
+    try:
+        export_visualization_graphs(
+            input_directory=Path(args.input_path),
+            input_tiles=Path(args.input_tiles),
+            patch_index=args.patch_index,
+            nb_patches=args.nb_patches,
+            export_path=Path(args.export_path),
+        )
+
+    except FileNotFoundError:
+        if not Path(args.input_path).exists():
+            print(f"Folder {args.input_path} does not exist!!")
+
+        if not Path(args.export_path).exists():
+            print(f"Folder {args.export_path} does not exist!!")
 
     logging.info("Visualization export finished")
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab


From 1db085675ab92d477b7e1a9b5e567806f95bf871 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 08:23:58 +0000
Subject: [PATCH 05/81] update convention

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index 08f3a061..eb6eb8f6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -69,7 +69,7 @@ testing =
 
 [options.entry_points]
 console_scripts =
-     mmdc_data = bin.generate_mmdc_dataset:main
+     mmdc_generate_dataset = bin.generate_mmdc_dataset:main
      mmdc_visualize = bin.visualize_mmdc_ds:main
      mmdc_split_dataset = bin.split_mmdc_dataset:main
 
-- 
GitLab


From b1236f397e7b2e87717ada8b6f6bc465ce831fe4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 11:46:23 +0000
Subject: [PATCH 06/81] merge makefile

---
 Makefile | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 5692cfc8..bc64a93b 100644
--- a/Makefile
+++ b/Makefile
@@ -74,8 +74,10 @@ test_no_PIL:
 test_mask_loss:
 	$(CONDA) && pytest -vv test/test_masked_losses.py
 
-
 PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py"
+test_inference:
+	$(CONDA) && pytest -vv test/test_mmdc_inference.py
+
 #.PHONY:
 pylint:
 	$(CONDA) && pylint --ignore=$(PYLINT_IGNORED) src/
-- 
GitLab


From bee5c3218cb7d34527bf919d13fac79aec765678 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:13:48 +0000
Subject: [PATCH 07/81] WIP add inference components utils

---
 .../components/inference_components.py        | 210 ++++++++++++
 .../inference/components/inference_utils.py   | 324 ++++++++++++++++++
 2 files changed, 534 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/components/inference_components.py
 create mode 100644 src/mmdc_singledate/inference/components/inference_utils.py

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
new file mode 100644
index 00000000..f480b032
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+"""
+Function components to get the
+latent spaces for a given tile
+"""
+
+import logging
+
+# imports
+from pathlib import Path
+from typing import Literal
+
+import torch
+from torch import nn
+from torchutils import patches
+
+from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
+from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+
+# define sensors variable for typing
+SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+@torch.no_grad()
+def get_mmdc_full_model(
+    checkpoint: str,
+) -> nn.Module:
+
+    """
+    Instantiate the network
+
+    """
+
+    if not Path(checkpoint).exists():
+        logger.info("No valid checkpoint path was given: %s", checkpoint)
+        raise FileNotFoundError(checkpoint)
+    else:
+        mmdc_full_model = MMDCFullModule(
+            s2_angles_conv_in_channels=6,
+            s2_angles_conv_out_channels=3,
+            s2_angles_conv_encoder_sizes=[16, 8],
+            s2_angles_conv_kernel_size=1,
+            #
+            srtm_s2_angles_conv_in_channels=10,  # 4 srtm + 6 angles
+            srtm_s2_angles_conv_out_channels=3,
+            srtm_s2_angles_unet_encoder_sizes=[32, 64, 128],
+            srtm_s2_angles_kernel_size=3,
+            #
+            s1_angles_mlp_in_size=6,
+            s1_angles_mlp_hidden_layers=[9, 12, 9],
+            s1_angles_mlp_out_size=3,
+            #
+            srtm_s1_encoder_sizes=[32, 64, 128],
+            srtm_kernel_size=3,
+            srtm_s1_in_channels=4,
+            srtm_s1_out_channels=3,
+            #
+            wc_enc_sizes=[64, 32, 16],
+            wc_kernel_size=1,
+            wc_in_channels=103,
+            wc_out_channels=4,
+            #
+            s1_input_size=6,
+            s1_encoder_sizes=[64, 128, 256, 512, 1024],
+            s1_enc_kernel_sizes=[3],
+            s1_decoder_sizes=[32, 16, 8],
+            s1_dec_kernel_sizes=[3, 3, 3, 3],
+            code_sizes=[0, 4, 0],
+            s2_input_size=10,
+            s2_encoder_sizes=[64, 128, 256, 512, 1024],
+            s2_enc_kernel_sizes=[3],
+            s2_decoder_sizes=[64, 32, 16],
+            s2_dec_kernel_sizes=[3, 3, 3, 3],
+            s1_ang_in_decoder=True,
+            s2_ang_in_decoder=True,
+            srtm_in_decoder=True,
+            wc_in_decoder=True,
+            w_d1_e1_s1=1,
+            w_d2_e2_s2=1,
+            w_d1_e2_s2=1,
+            w_d2_e1_s1=1,
+            w_code_s1s2=1,
+        )
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(device)
+        # load state_dict
+        lightning_checkpoint = torch.load(checkpoint, map_location=device)
+
+        # delete "model" from the loaded checkpoint
+        checkpoint = {
+            key.split("model.")[-1]: item
+            for key, item in lightning_checkpoint["state_dict"].items()
+            if key.startswith("model.")
+        }
+
+        # load the state dict
+        mmdc_full_model.load_state_dict(checkpoint)
+
+        # disable randomness, dropout, etc.
+        mmdc_full_model.eval()
+
+    return mmdc_full_model
+
+
+@torch.no_grad()
+def predict_mmdc_model(
+    model: nn.Module,
+    sensors: SENSORS,
+    s2_ref,
+    s2_angles,
+    s1_asc,
+    s1_desc,
+    s1_asc_angles,
+    s1_desc_angles,
+    srtm,
+    worldclim,
+) -> VAELatentSpace:
+    """
+    This function apply the predict method to a
+    batch of mmdc data and get the latent space
+    by sensor
+
+        :args: model : mmdc_full model intance
+        :args: sensor : Sensor acquisition
+        :args: s2_ref :
+        :args: s2_angles :
+        :args: s1_back :
+        :args: s1_asc_angles :
+        :args: s1_desc_angles :
+        :args: worldclim :
+        :args: srtm :
+
+        :return: VAELatentSpace
+    """
+    s1_back = torch.cat(
+        (
+            s1_asc,
+            s1_desc,
+            (s1_asc / s1_desc),
+        ),
+        1,
+    )
+    prediction = model.predict(
+        s2_ref,
+        s2_angles,
+        s1_back,
+        s1_asc_angles,
+        s1_desc_angles,
+        worldclim,
+        srtm,
+    )
+
+    if "S2L2A" in sensors:
+        # get latent
+        latent_space = prediction[0].latent.latent_s2
+
+    if "S1FULL" in sensors:
+        latent_space = prediction[0].latent.latent_s1
+
+    if "S1ASC" in sensors:
+        latent_space = prediction[0].latent.latent_s1
+
+    if "S1DESC" in sensors:
+        latent_space = prediction[0].latent.latent_s1
+
+    return latent_space
+
+
+def patchify_batch(
+    tensor: torch.Tensor,
+    patch_size: int,
+) -> torch.Tensor:
+    """
+    reshape the GeoTIFF data read from disk into the
+    shape expected by the network
+
+    :param: tensor : image tensor (C, H, W)
+    :param: patch_size : side of the square patches
+    """
+    patch = patches.patchify(tensor, patch_size)
+    flatten_patch = patches.flatten2d(patch)
+
+    return flatten_patch
+
+
+def unpatchify_batch(
+    flatten_patch: torch.Tensor, patch_shape: torch.Size, tensor_shape: torch.Size
+) -> torch.Tensor:
+    """
+    Inverse operation of patchify_batch
+    :param: flatten_patch : flattened batch of patches
+    :param: patch_shape : shape of the patchified tensor
+    :param: tensor_shape : shape of the original tensor
+    :return: the reassembled tensor, cropped to tensor_shape
+    """
+    unflatten_patch = patches.unflatten2d(flatten_patch, patch_shape[0], patch_shape[1])
+    unpatch = patches.unpatchify(unflatten_patch)
+    unpatch_crop = unpatch[:, : tensor_shape[1], : tensor_shape[2]]
+
+    return unpatch_crop
diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
new file mode 100644
index 00000000..3ec181c5
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -0,0 +1,324 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+"""
+Inference utility functions
+
+"""
+from collections.abc import Callable, Generator
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import rasterio as rio
+import torch
+from rasterio.windows import Window as rio_window
+
+from mmdc_singledate.datamodules.components.datamodule_components import (
+    get_worldclim_filenames,
+)
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1,
+    get_s1_acquisition_angles,
+    join_even_odd_s2_angles,
+    srtm_height_aspect,
+)
+
+# dataclasses
+Chunk = tuple[int, int, int, int]
+
+# hold the data
+
+
+@dataclass
+class GeoTiffDataset:
+
+    """
+    hold the TIFF filenames
+    stored on disk and their metadata
+    """
+
+    s2_filename: str
+    s1_asc_filename: str
+    s1_desc_filename: str
+    srtm_filename: str
+    wc_filename: str
+    metadata: dict = field(init=False)
+
+    def __post_init__(self) -> None:
+        # check if the data exists
+        # if not Path(self.s2_filename).exists():
+        #     raise Exception(f"{self.s2_filename} does not exist!")
+        # if not Path(self.s1_asc_filename).exists():
+        #     raise Exception(f"{self.s1_asc_filename} does not exist!")
+        # if not Path(self.s1_desc_filename).exists():
+        #     raise Exception(f"{self.s1_desc_filename} does not exist!")
+        if not Path(self.srtm_filename).exists():
+            raise Exception(f"{self.srtm_filename} does not exist!")
+        if not Path(self.wc_filename).exists():
+            raise Exception(f"{self.wc_filename} does not exist!")
+
+        # check that the filenames
+        # provided share the same
+        # footprint and spatial resolution
+        # rio.open(self.s2_filename) as s2, rio.open(
+        #    self.s1_asc_filename
+        # ) as s1_asc, rio.open(self.s1_desc_filename) as s1_desc,
+
+        with rio.open(self.srtm_filename) as srtm, rio.open(self.wc_filename) as wc:
+            # recover the metadata
+            self.metadata = srtm.meta
+
+            # list to check
+            crss = {
+                # s2.meta["crs"],
+                # s1_asc.meta["crs"],
+                # s1_desc.meta["crs"],
+                srtm.meta["crs"],
+                wc.meta["crs"],
+            }
+
+            heights = {
+                # s2.meta["height"],
+                # s1_asc.meta["height"],
+                # s1_desc.meta["height"],
+                srtm.meta["height"],
+                wc.meta["height"],
+            }
+
+            widths = {
+                # s2.meta["width"],
+                # s1_asc.meta["width"],
+                # s1_desc.meta["width"],
+                srtm.meta["width"],
+                wc.meta["width"],
+            }
+
+            # check crs
+            if len(crss) > 1:
+                raise Exception("Data should be in the same CRS")
+
+            # check height
+            if len(heights) > 1:
+                raise Exception("Data should be have the same height")
+
+            # check width
+            if len(widths) > 1:
+                raise Exception("Data should be have the same width")
+
+
+@dataclass
+class S2Components:
+    """
+    dataclass to hold the S2-related data
+    """
+
+    s2_reflectances: torch.Tensor
+    s2_angles: torch.Tensor
+    s2_mask: torch.Tensor
+
+
+@dataclass
+class S1Components:
+    """
+    dataclass to hold the S1-related data
+    """
+
+    s1_backscatter: torch.Tensor
+    s1_valmask: torch.Tensor
+    s1_edgemask: torch.Tensor
+    s1_lia_angles: torch.Tensor
+
+
+@dataclass
+class SRTMComponents:
+    """
+    dataclass to hold the SRTM-related data
+    """
+
+    srtm: torch.Tensor
+
+
+@dataclass
+class WorldClimComponents:
+    """
+    dataclass to hold the WorldClim-related data
+    """
+
+    worldclim: torch.Tensor
+
+
+# chunks logic
+def generate_chunks(
+    width: int, height: int, nlines: int, test_area: tuple[int, int] | None = None
+) -> list[Chunk]:
+    """Given the width and height of an image and a number of lines per chunk,
+    generate the list of chunks for the image. The chunks span all the width. A
+    chunk is encoded as 4 values: (x0, y0, width, height)
+    :param width: image width
+    :param height: image height
+    :param nlines: number of lines per chunk
+    :param test_area: generate only the chunks between (ya, yb)
+    :returns: a list of chunks
+    """
+    chunks = []
+    for iter_height in range((height + nlines - 1) // nlines):  # ceil division so the last partial chunk is kept
+        y0 = iter_height * nlines
+        y1 = y0 + nlines
+        if y1 > height:
+            y1 = height
+        append_chunk = (test_area is None) or (y0 > test_area[0] and y1 < test_area[1])
+        if append_chunk:
+            chunks.append((0, y0, width, y1 - y0))
+    return chunks
+
+
+# read data
+def read_img_tile(
+    filename: str,
+    rois: list[rio_window],
+    sensor_func: Callable[[torch.Tensor, Any], Any],
+) -> Generator[Any, None, None]:
+    """
+    read a patch of abstract sensor data,
+    construct the auxiliary data and yield the
+    patch of data
+    """
+    # check if the filename exists;
+    # if it does, proceed to yield the data
+    if Path(filename).exists():
+        with rio.open(filename) as raster:
+            for roi in rois:
+                # read the patch as tensor
+                tensor = torch.tensor(raster.read(window=roi), requires_grad=False)
+                # compute some transformation to the tensor and return dataclass
+                sensor_data = sensor_func(tensor, filename)
+                yield sensor_data
+    # if it does not exist, create a zero tensor and yield it
+    else:
+        for roi in rois:
+            null_data = torch.zeros(roi.width, roi.height)
+            yield null_data
+
+
+def read_s2_img_tile(
+    s2_tensor: torch.Tensor,
+    *args: Any,
+    **kwargs: Any,
+) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of Sentinel-2 data,
+    construct the masks and yield the patch
+    of data
+    """
+    if s2_tensor.shape[0] == 20:
+        # extract masks
+        cloud_mask = s2_tensor[11, ...].to(torch.uint8)
+        cloud_mask[cloud_mask > 0] = 1
+        sat_mask = s2_tensor[10, ...].to(torch.uint8)
+        edge_mask = s2_tensor[12, ...].to(torch.uint8)
+        # create the validity mask
+        mask = torch.logical_or(
+            cloud_mask, torch.logical_or(edge_mask, sat_mask)
+        ).unsqueeze(0)
+        angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...])
+        image_s2 = s2_tensor[:10, ...]
+
+    else:
+        image_s2 = torch.zeros(10, s2_tensor.shape[0], s2_tensor.shape[1])
+        angles_s2 = torch.zeros(6, s2_tensor.shape[0], s2_tensor.shape[1])
+        mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1])
+
+    return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
+
+
+def read_s1_img_tile(
+    s1_tensor: torch.Tensor,
+    s1_filename: str,
+    *args: Any,
+    **kwargs: Any,
+) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of S1 data, construct the associated
+    datasets and yield the result
+    """
+    if s1_tensor.shape[0] == 2:
+        # get the vv,vh, vv/vh bands
+        img_s1 = torch.cat(
+            (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
+        )
+        # get the local incidence angle (LIA)
+        s1_lia = get_s1_acquisition_angles(s1_filename)
+        # compute validity mask
+        s1_valmask = torch.ones(img_s1.shape)
+        # compute edge mask
+        s1_edgemask = img_s1.to(int)
+        s1_backscatter = apply_log_to_s1(img_s1)
+    else:
+        s1_backscatter = torch.zeros(3, s1_tensor.shape[0], s1_tensor.shape[1])
+        s1_valmask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1])
+        s1_edgemask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1])
+        s1_lia = torch.zeros(3)
+
+    return S1Components(
+        s1_backscatter=s1_backscatter,
+        s1_valmask=s1_valmask,
+        s1_edgemask=s1_edgemask,
+        s1_lia_angles=s1_lia,
+    )
+
+
+def read_srtm_img_tile(
+    srtm_tensor: torch.Tensor,
+    *args: Any,
+    **kwargs: Any,
+) -> SRTMComponents:  # [torch.Tensor]:
+    """
+    read srtm patch
+    """
+    # read the patch as tensor
+    img_srtm = srtm_height_aspect(srtm_tensor)
+    return SRTMComponents(img_srtm)
+
+
+def read_worldclim_img_tile(
+    wc_tensor: torch.Tensor,
+    *args: Any,
+    **kwargs: Any,
+) -> WorldClimComponents:  # [torch.tensor]:
+    """
+    read worldclim subset
+    """
+    return WorldClimComponents(wc_tensor)
+
+
+def expand_worldclim_filenames(wc_filename: str) -> list[str]:
+    """
+    given the first worldclim filename, expand the
+    filenames to the other files
+    """
+    name_wc = [
+        get_worldclim_filenames(wc_filename, str(idx), True) for idx in range(1, 13, 1)
+    ]
+
+    name_wc_bio = get_worldclim_filenames(wc_filename, None, False)
+    name_wc.append(name_wc_bio)
+
+    return name_wc
+
+
+def concat_worldclim_components(
+    wc_filename: str, rois: list[rio_window]
+) -> Generator[Any, None, None]:
+    """
+    Composed function applying the general read_img_tile
+    function to the different worldclim files
+    """
+    # get all filenames
+    wc_filenames = expand_worldclim_filenames(wc_filename)
+    # read the tensors
+    wc_tensor = [
+        next(read_img_tile(wc_filename, rois, read_worldclim_img_tile)).worldclim
+        for wc_filename in wc_filenames
+    ]
+    yield WorldClimComponents(torch.cat(wc_tensor))
-- 
GitLab
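
A short usage sketch for the chunking helpers added above, mirroring how the
inference code consumes them (the 10980 x 10980 dimensions are a Sentinel-2
tile at 10 m, as in the tests; reading real data still requires rasters on
disk):

    import rasterio as rio

    from mmdc_singledate.inference.components.inference_utils import (
        generate_chunks,
    )

    # cut the tile into full-width chunks of 1024 lines
    chunks = generate_chunks(width=10980, height=10980, nlines=1024)
    # each chunk is (x0, y0, width, height); wrap them as rasterio windows
    rois = [rio.windows.Window(*chunk) for chunk in chunks]
    assert rois[0].width == 10980 and rois[0].height == 1024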


From 8372a1537ddbcc2cff8e82dcbf86b00bb7d41ab9 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:14:15 +0000
Subject: [PATCH 08/81] WIP add inference components utils

---
 src/mmdc_singledate/inference/components/inference_components.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index f480b032..fe3212b2 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -34,7 +34,6 @@ logger = logging.getLogger(__name__)
 def get_mmdc_full_model(
     checkpoint: str,
 ) -> nn.Module:
-
     """
     Instantiate the network
 
-- 
GitLab
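
For reference, the checkpoint remapping inside get_mmdc_full_model can be
illustrated standalone: Lightning checkpoints prefix the wrapped module's
weights with "model.", which must be stripped before load_state_dict. A
self-contained sketch with a toy dict (not a real checkpoint):

    state_dict = {
        "model.encoder.weight": 1,  # belongs to the wrapped module
        "optimizer.lr": 2,          # Lightning bookkeeping, dropped
    }
    # same comprehension as in get_mmdc_full_model
    stripped = {
        key.split("model.")[-1]: item
        for key, item in state_dict.items()
        if key.startswith("model.")
    }
    assert stripped == {"encoder.weight": 1}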


From 30e98cd1437d1a1a460d40163fdd85b3ad2f9cd3 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:15:50 +0000
Subject: [PATCH 09/81] add test inference

---
 test/test_mmdc_inference.py | 273 ++++++++++++++++++++++++++++++++++++
 1 file changed, 273 insertions(+)
 create mode 100644 test/test_mmdc_inference.py

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
new file mode 100644
index 00000000..90d654be
--- /dev/null
+++ b/test/test_mmdc_inference.py
@@ -0,0 +1,273 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+import os
+
+import pytest
+import torch
+
+from mmdc_singledate.inference.components.inference_components import (  # predict_tile,
+    get_mmdc_full_model,
+)
+from mmdc_singledate.inference.components.inference_utils import (
+    GeoTiffDataset,
+    generate_chunks,
+)
+from mmdc_singledate.inference.mmdc_tile_inference import MMDCProcess, predict_tile
+
+# dir
+dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+
+
+def test_GeoTiffDataset():
+    """ """
+    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ"
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(
+            datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+        ),
+        s1_asc_filename=os.path.join(
+            datapath,
+            "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+        ),
+        s1_desc_filename=os.path.join(
+            datapath,
+            "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+        ),
+        srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
+        wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
+    )
+    assert input_data
+    assert isinstance(input_data.metadata, dict)
+
+
+def test_generate_chunks():
+    """
+    test chunk functionality
+    """
+    chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+
+    assert isinstance(chunks, list)
+    assert chunks[0] == (0, 0, 10980, 1024)
+
+
+def test_mmdc_full_model():
+    """
+    test network instantiation
+    """
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
+    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+
+    # lightning_dict = torch.load("epoch_008.ckpt", map_location=torch.device("cpu"))
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+
+    print(mmdc_full_model)
+
+    assert not mmdc_full_model.training
+
+    s2_x = torch.rand(1, 10, 256, 256)
+    s2_m = torch.ones(1, 10, 256, 256)
+    s2_angles_x = torch.rand(1, 6, 256, 256)
+    s1_x = torch.rand(1, 6, 256, 256)
+    s1_vm = torch.ones(1, 6, 256, 256)
+    s1_asc_angles_x = torch.rand(1, 3)
+    s1_desc_angles_x = torch.rand(1, 3)
+    worldclim_x = torch.rand(1, 103, 256, 256)
+    srtm_x = torch.rand(1, 4, 256, 256)
+
+    prediction = mmdc_full_model.predict(
+        s2_x,
+        s2_m,
+        s2_angles_x,
+        s1_x,
+        s1_vm,
+        s1_asc_angles_x,
+        s1_desc_angles_x,
+        worldclim_x,
+        srtm_x,
+    )
+
+    latent_variable = prediction[0].latent.latent_s1.mu
+
+    assert isinstance(latent_variable, torch.Tensor)
+
+
+def dummy_process(
+    s2_refl,
+    s2_ang,
+    s1_asc,
+    s1_desc,
+    s1_asc_lia,
+    s1_desc_lia,
+    srtm_patch,
+    wc_patch,
+):
+    """
+    Create a dummy function
+    """
+    prediction = (
+        s2_refl[:, 0, ...]
+        * s2_ang[:, 0, ...]
+        * s1_asc[:, 0, ...]
+        * s1_desc[:, 0, ...]
+        * srtm_patch[:, 0, ...]
+        * wc_patch[:, 0, ...]
+    )
+    return prediction
+
+
+@pytest.mark.skip(reason="Have to update to the new data format")
+def test_predict_tile():
+    """ """
+    export_path = (
+        "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent.tif"
+    )
+    process = MMDCProcess(
+        count=1,
+        nb_lines=1024,
+        process=dummy_process,
+    )
+
+    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(
+            datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+        ),
+        s1_asc_filename=os.path.join(
+            datapath,
+            "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+        ),
+        s1_desc_filename=os.path.join(
+            datapath,
+            "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+        ),
+        srtm_filename=os.path.join(datapath, "srtm_T31TDJ_roi_0.tif"),
+        wc_filename=os.path.join(datapath, "wc_clim_1_T31TDJ_roi_0.tif"),
+    )
+
+    predict_tile(
+        input_data=input_data,
+        export_path=export_path,
+        process=process,
+    )
+
+
+# def test_read_s2_img_tile():
+#     """
+#     test read a window of s2
+#     """
+#     dataset_sample = os.path.join(
+#         dataset_dir, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif"
+#     )
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     s2_data = read_s2_img_tile(s2_filename=dataset_sample, rois=roi)
+#     # get the first element
+#     s2_data_window = next(s2_data)
+
+#     assert s2_data_window.s2_reflectances.shape == torch.Size([10, 1024, 10980])
+#     assert s2_data_window.s2_angles.shape == torch.Size([6, 1024, 10980])
+#     assert s2_data_window.s2_mask.shape == torch.Size([1, 1024, 10980])
+
+
+# def test_read_s1_img_tile():
+#     """
+#     test read a window of s1
+#     """
+#     dataset_sample = os.path.join(
+#         dataset_dir,
+#         "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_41.98988334384891_15.429479982324139_roi_0.tif",
+#     )
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     s1_data = read_s1_img_tile(s1_filename=dataset_sample, rois=roi)
+#     # get the first element
+#     s1_data_window = next(s1_data)
+
+#     assert s1_data_window.s1_backscatter.shape == torch.Size([3, 1024, 10980])
+#     assert s1_data_window.s1_valmask.shape == torch.Size([3, 1024, 10980])
+#     assert s1_data_window.s1_edgemask.shape == torch.Size([3, 1024, 10980])
+#     assert s1_data_window.s1_lia_angles.shape == torch.Size([3])
+
+
+# def test_read_srtm_img_tile():
+#     """
+#     test read a window of srtm
+#     """
+#     dataset_sample = os.path.join(dataset_dir, "srtm_T31TDJ_roi_0.tif")
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     srtm_data = read_srtm_img_tile(srtm_filename=dataset_sample, rois=roi)
+#     # get the firt element
+#     srtm_data_window = next(srtm_data)
+
+#     assert srtm_data_window.shape == torch.Size([4, 1024, 10980])
+
+
+# def test_read_worlclim_img_tile():
+#     """
+#     test read a window of srtm
+#     """
+#     dataset_sample = os.path.join(dataset_dir, "wc_clim_1_T31TDJ_roi_0.tif")
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     wc_data = read_worldclim_img_tile(wc_filename=dataset_sample, rois=roi)
+#     # get the firt element
+#     wc_data_window = next(wc_data)
+
+#     assert wc_data_window.shape == torch.Size([103, 1024, 10980])
+
+
+# def test_predict_tile():
+#     """ """
+
+#     input_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TDJ/"
+#     export_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/"
+
+#     mmdc_filenames = GeoTiffDataset(
+#         s2_filename=os.path.join(
+#             input_path, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif"
+#         ),
+#         s1_filename=os.path.join(
+#             input_path,
+#             "S1B_IW_GRDH_1SDV_20180304T173831_20180304T173856_009885_011E2F_CEA4_32.528381567125045_15.330754606990766_roi_0.tif",
+#         ),
+#         srtm_filename=os.path.join(input_path, "srtm_T31TDJ_roi_0.tif"),
+#         wc_filename=os.path.join(input_path, "wc_clim_1_T31TDJ_roi_0.tif"),
+#     )
+
+#     # get grid mesh
+#     rois, meta = get_rois_from_prediction_mesh(
+#         input_filename=mmdc_filenames.s2_filename, nlines=1024
+#     )
+
+#     export_filename = "latent_variable.tif"
+#     exported_file = predict_tile(
+#         input_path=input_path,
+#         export_path=export_path,
+#         mmdc_filenames=mmdc_filenames,
+#         export_filename=export_filename,
+#         rois=rois,
+#         meta=meta,
+#     )
+#     # assert existence of exported file
+#     assert Path(exported_file).exists
+
+
+# #
+
+
+# # asc = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_39.6607434587337_15.429479982324139_roi_0.tif"
+# # desc = "S1B_IW_GRDH_1SDV_20180303T055142_20180303T055207_009863_011D75_B010_45.08497158404881_165.07686884776342_roi_4.tif
+# "
-- 
GitLab


From bab2b88046f139e6e925b862fe7e6f0107569f4a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:16:48 +0000
Subject: [PATCH 10/81] add inference tile

---
 .../inference/mmdc_tile_inference.py          | 275 ++++++++++++++++++
 1 file changed, 275 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/mmdc_tile_inference.py

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
new file mode 100644
index 00000000..b4aec447
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -0,0 +1,275 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+"""
+Inference API with Rasterio
+"""
+
+# imports
+import logging
+import os
+from collections.abc import Callable
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+
+import rasterio as rio
+import torch
+
+from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df
+
+from .components.inference_components import (
+    get_mmdc_full_model,
+    patchify_batch,
+    predict_mmdc_model,
+    unpatchify_batch,
+)
+from .components.inference_utils import (
+    GeoTiffDataset,
+    concat_worldclim_components,
+    generate_chunks,
+    read_img_tile,
+    read_s1_img_tile,
+    read_s2_img_tile,
+    read_srtm_img_tile,
+)
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def inference_dataframe(
+    samples_dir: str,
+    input_tile_list: list[str],
+    days_gap: int,
+    nb_tiles: int,
+    nb_rois: int,
+    nb_files: int | None = None,
+):
+    """
+    Read the input directory and create
+    a dataframe with the occurrences
+    to be inferred
+    """
+    # read the metadata and construct the time series
+    (
+        tile_df,
+        asc_orbit_without_acquisitions,
+        desc_orbit_without_acquisitions,
+    ) = prepare_data_df(
+        samples_dir=samples_dir,
+        input_tile_list=input_tile_list,
+        days_gap=days_gap,
+        nb_files=nb_files,
+        nb_tiles=nb_tiles,
+        nb_rois=nb_rois,
+    )
+
+    # manage the absence of S1ASC or S1DESC
+    if asc_orbit_without_acquisitions:
+        tile_df["patchasc_s1"] = None
+
+    if desc_orbit_without_acquisitions:
+        tile_df["patchdesc_s1"] = None
+
+    return tile_df
+
+
+# functions and classes
+@dataclass
+class MMDCProcess:
+    """
+    Class holding the prediction configuration: output band count,
+    chunk height in lines and the prediction callable
+    """
+
+    count: int
+    nb_lines: int
+    process: Callable[
+        [
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+        ],
+        torch.Tensor,
+    ]
+
+
+def predict_tile(
+    input_data: GeoTiffDataset,
+    export_path: Path,
+    sensors: list[str],
+    process: MMDCProcess,
+):
+    """
+    Predict a tile of data
+    """
+    # retrieve the metadata
+    meta = input_data.metadata.copy()
+    # get nb outputs
+    meta.update({"count": process.count})
+    # calculate rois
+    chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+    # get the windows from the chunks
+    rois = [rio.windows.Window(*chunk) for chunk in chunks]
+    logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+    # init the dataset
+    logger.info("Reading S2 data")
+    s2_data = read_img_tile(input_data.s2_filename, rois, read_s2_img_tile)
+    logger.info("Reading S1 ASC data")
+    s1_asc_data = read_img_tile(input_data.s1_asc_filename, rois, read_s1_img_tile)
+    logger.info("Reading S1 DESC data")
+    s1_desc_data = read_img_tile(input_data.s1_desc_filename, rois, read_s1_img_tile)
+    logger.info("Reading SRTM data")
+    srtm_data = read_img_tile(input_data.srtm_filename, rois, read_srtm_img_tile)
+    logger.info("Reading WorldClim data")
+    worldclim_data = concat_worldclim_components(input_data.wc_filename, rois)
+
+    logger.info("Export Init")
+
+    # separate the latent spaces by sensor
+    sensors = [s for s in sensors if s in ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]]
+    # built export_filename
+    export_names = ["mmdc_latent_" + s.casefold() + ".tif" for s in sensors]
+    for export_name in export_names:
+        # export latent spaces
+        with rio.open(
+            os.path.join(export_path, export_name), "w", **meta
+        ) as prediction:
+            # iterate over the windows
+            for roi, s2, s1_asc, s1_desc, srtm, wc in zip(
+                rois,
+                s2_data,
+                s1_asc_data,
+                s1_desc_data,
+                srtm_data,
+                worldclim_data,
+            ):
+                print(" original size : ", s2.s2_reflectances.shape)
+                # reshape the data
+                s2_refl_patch = patchify_batch(s2.s2_reflectances, 256)
+                s2_ang_patch = patchify_batch(s2.s2_angles, 256)
+                s1_asc_patch = patchify_batch(s1_asc.s1_backscatter, 256)
+                s1_asc_lia_patch = s1_asc.s1_lia_angles
+                s1_desc_patch = patchify_batch(s1_desc.s1_backscatter, 256)
+                s1_desc_lia_patch = s1_desc.s1_lia_angles
+                srtm_patch = patchify_batch(srtm.srtm, 256)
+                wc_patch = patchify_batch(wc.worldclim, 256)
+                print(
+                    s2_refl_patch.shape,
+                    s2_ang_patch.shape,
+                    s1_asc_patch.shape,
+                    s1_desc_patch.shape,
+                    s1_asc_lia_patch.shape,
+                    s1_desc_lia_patch.shape,
+                    srtm_patch.shape,
+                    wc_patch.shape,
+                )
+
+                # apply predict function
+                pred_vaelatentspace = process.process(
+                    s2_refl_patch,
+                    s2_ang_patch,
+                    s1_asc_patch,
+                    s1_desc_patch,
+                    s1_asc_lia_patch,
+                    s1_desc_lia_patch,
+                    srtm_patch,
+                    wc_patch,
+                )
+                pred_tensor = torch.cat(
+                    (
+                        pred_vaelatentspace.mu,
+                        pred_vaelatentspace.logvar,
+                    ),
+                    1,
+                )
+                print(pred_tensor.shape)
+                prediction.write(
+                    unpatchify_batch(
+                        flatten_patch=pred_tensor,
+                        patch_shape=s2_refl_patch.shape,
+                        tensor_shape=s2.s2_reflectances.shape,
+                    ),
+                    window=roi,
+                    indexes=1,
+                )
+                logger.info(("Export tile", f"filename :{export_path}"))
+
+
+def mmdc_tile_inference(
+    input_path: Path,
+    export_path: Path,
+    model_checkpoint_path: Path,
+    sensor: list[str],
+    tile_list: list[str],
+    days_gap: int,
+    nb_tiles: int,
+    nb_files: int | None = None,
+    latent_space_size: int = 2,
+    nb_lines: int = 1024,
+) -> None:
+    """
+    Entry point
+
+        :args: input_path : full path to the input data
+        :args: export_path : full path to the export
+        :args: sensor : list of sensor data inputs
+        :args: tile_list : list of tiles
+        :args: days_gap : days between s1 and s2 acquisitions
+        :args: nb_tiles : number of tiles to process
+        :args: nb_files : number of files
+        :args: latent_space_size : latent space size per sensor
+        :args: nb_lines : number of lines to read at a time (default 1024)
+
+        :return : None
+
+    """
+    # dataframe with input data
+    tile_df = inference_dataframe(
+        samples_dir=input_path,
+        input_tile_list=tile_list,
+        days_gap=days_gap,
+        nb_tiles=nb_tiles,
+        nb_rois=1,
+        nb_files=nb_files,
+    )
+
+    # instance the model and get the pred func
+    model = get_mmdc_full_model(
+        checkpoint=model_checkpoint_path,
+    )
+    # bind the model and sensors; the remaining tensors are supplied per chunk
+    pred_func = partial(predict_mmdc_model, model, sensor)
+    mmdc_process = MMDCProcess(
+        count=latent_space_size, nb_lines=nb_lines, process=pred_func
+    )
+
+    # iterate over the dates in the time series
+    for _, df_row in tile_df.iterrows():
+        # get the input data
+        mmdc_input_data = GeoTiffDataset(
+            s2_filename=df_row.patch_s2,
+            s1_asc_filename=df_row.patchasc_s1,
+            s1_desc_filename=df_row.patchdesc_s1,
+            srtm_filename=df_row.srtm_filename,
+            wc_filename=df_row.worldclim_filename,
+        )
+        # predict tile
+        predict_tile(
+            input_data=mmdc_input_data,
+            export_path=export_path,
+            sensors=sensor,
+            process=mmdc_process,
+        )
+
+        logger.info("Export Finish !!!")
-- 
GitLab
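
MMDCProcess above is a thin container pairing the number of output bands and
the chunk height with the prediction callable. A minimal sketch of that
contract (the dummy callable returns a bare tensor; the real pred_func must
return a VAELatentSpace with mu/logvar, which is what predict_tile unpacks):

    import torch

    from mmdc_singledate.inference.mmdc_tile_inference import MMDCProcess


    def dummy_process(s2_refl, s2_ang, s1_asc, s1_desc,
                      s1_asc_lia, s1_desc_lia, srtm_patch, wc_patch):
        # any callable over the eight patched tensors fits the dataclass
        return s2_refl[:, 0, ...] * srtm_patch[:, 0, ...]


    process = MMDCProcess(count=1, nb_lines=1024, process=dummy_process)
    out = process.process(*(torch.rand(4, 3, 8, 8) for _ in range(8)))
    assert out.shape == (4, 8, 8)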


From 9243fd75ae8c83014cc96d939857db38bbeeab6a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:24:43 +0000
Subject: [PATCH 11/81] add iota2 API

---
 .../inference/mmdc_tile_inference_iota2.py    | 103 +++++++++++++++++++
 1 file changed, 103 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
new file mode 100644
index 00000000..f434b108
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+"""
+Inference API with Iota-2
+Inspired by:
+https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
+"""
+
+import os
+from typing import Any
+
+import torch
+from torchutils import patches
+
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1
+)
+from mmdc_singledate.inference.components.inference_components import (
+    get_mmdc_full_model,
+)
+
+
+def apply_mmdc_full_mode(
+    self,
+    checkpoint_path: str,
+    checkpoint_epoch: int = 100,
+    patch_size: int = 256,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+    # TODO Get the data in the same order as
+    # sentinel2.Sentinel2.GROUP_10M +
+    # sensorsio.sentinel2.GROUP_20M
+    list_bands_s2 = [
+        self.get_interpolated_Sentinel2_B2(),
+        self.get_interpolated_Sentinel2_B3(),
+        self.get_interpolated_Sentinel2_B4(),
+        self.get_interpolated_Sentinel2_B8(),
+        self.get_interpolated_Sentinel2_B5(),
+        self.get_interpolated_Sentinel2_B6(),
+        self.get_interpolated_Sentinel2_B7(),
+        self.get_interpolated_Sentinel2_B8A(),
+        self.get_interpolated_Sentinel2_B11(),
+        self.get_interpolated_Sentinel2_B12(),
+    ]
+
+    # TODO Manage S1 ASC and S1 DESC ?
+    list_bands_s1 = [
+        self.get_interpolated_Sentinel1_ASC_vh(),
+        self.get_interpolated_Sentinel1_ASC_vv(),
+        self.get_interpolated_Sentinel1_ASC_vh()
+        / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
+        self.get_interpolated_Sentinel1_DES_vh(),
+        self.get_interpolated_Sentinel1_DES_vv(),
+        self.get_interpolated_Sentinel1_DES_vh()
+        / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
+    ]
+
+    # TODO Masks construction
+
+
+    with torch.no_grad():
+        # Permute dimensions to fit patchify
+        # Shape before permutation is C,H,W,D. D being the dates
+        # Shape after permutation is D,C,H,W
+        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+        bands_s1 = apply_log_to_s1(bands_s1)
+
+
+    # TODO Apply patchify
+
+
+    # Get the model (assumes Lightning's "epoch_XXX.ckpt" checkpoint naming)
+    checkpoint_filename = f"epoch_{checkpoint_epoch:03d}.ckpt"
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+
+    # apply model
+    latent_variable = mmdc_full_model.predict(
+        s2_x,
+        s2_m,
+        s2_angles_x,
+        s1_x,
+        s1_vm,
+        s1_asc_angles_x,
+        s1_desc_angles_x,
+        worldclim_x,
+        srtm_x,
+    )
+
+    # TODO unpatchify
+
+    # TODO crop padding
+
+    # TODO Depending on the configuration, return a unique latent variable
+    # or a stack of latent variables
+    coef = None  # placeholder until the TODOs above are implemented
+    labels = None  # placeholder until the TODOs above are implemented
+    return coef, labels
-- 
GitLab
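
The patchify helpers from patch 07 are what the TODOs above will rely on: a
chunk is cut into fixed-size patches, pushed through the network as a batch,
then reassembled and cropped. As used in predict_tile (torchutils semantics
assumed from the calls in this series; shapes illustrative):

    import torch

    from mmdc_singledate.inference.components.inference_components import (
        patchify_batch,
        unpatchify_batch,
    )

    chunk = torch.rand(10, 1024, 2048)  # (C, H, W) chunk of a tile
    flat = patchify_batch(chunk, 256)   # flattened batch of 256x256 patches
    # ... run the network over `flat` here ...
    restored = unpatchify_batch(
        flatten_patch=flat,
        patch_shape=flat.shape,    # as passed in predict_tile
        tensor_shape=chunk.shape,  # crop back to the chunk size
    )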


From 496a5f0479c6ed5c26fb7b438a295394eeb30522 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:25:32 +0000
Subject: [PATCH 12/81] ignore iota2 thirdparty

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 676c9c84..281e6999 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ src/models/*.ipynb
 *.swp
 *~
 thirdparties
+iota2_thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
-- 
GitLab


From fab87779bdbd8d22a71430ad7e59920e5f0cd570 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:26:38 +0000
Subject: [PATCH 13/81] add script for configure iota2

---
 create-iota2-env.sh | 104 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 create-iota2-env.sh

diff --git a/create-iota2-env.sh b/create-iota2-env.sh
new file mode 100644
index 00000000..f7035081
--- /dev/null
+++ b/create-iota2-env.sh
@@ -0,0 +1,104 @@
+#!/usr/bin/env bash
+
+export python_version="3.9"
+export name="mmdc-iota2"
+if ! [ -z "$1" ]
+then
+    export name=$1
+fi
+
+source ~/set_proxy_iota2.sh
+if [ -z "$https_proxy" ]
+then
+    echo "Please set https_proxy environment variable before running this script"
+    exit 1
+fi
+
+export target=/work/scratch/$USER/virtualenv/$name
+
+if ! [ -z "$2" ]
+then
+    export target="$2/$name"
+fi
+
+echo "Installing $name in $target ..."
+
+if [ -d "$target" ]; then
+   echo "Cleaning previous conda env"
+   rm -rf $target
+fi
+
+# Create blank virtualenv
+module purge
+module load conda
+module load gcc
+conda activate
+conda create --yes --prefix $target python==${python_version} pip
+
+# Enter virtualenv
+conda deactivate
+conda activate $target
+
+# install mamba
+conda install mamba
+
+which python
+python --version
+
+# Install iota2
+#mamba install iota2_develop=257da617 -c iota2
+mamba install iota2 -c iota2
+
+# clone the latest version of iota2
+rm -rf thirdparties/iota2
+git clone https://framagit.org/iota2-project/iota2.git thirdparties/iota2
+cd thirdparties/iota2
+git checkout issue#600/tile_exogenous_features
+git pull origin issue#600/tile_exogenous_features
+cd ../../
+pwd
+#which Iota2.py
+which Iota2.py
+
+# create a backup
+cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
+mv iota2 iota2_backup
+
+# create a symbolic link
+ln -s ~/src/MMDC/mmdc-singledate/thirdparties/iota2/iota2/ iota2
+
+cd ~/src/MMDC/mmdc-singledate
+# install missing dependancies
+pip install -r requirements-mmdc-iota2.txt
+
+# test install
+Iota2.py -h
+
+# # install MMDC dependencies
+# conda install --yes pytorch=1.12.1 torchvision -c pytorch -c nvidia
+
+# conda deactivate
+# conda activate $target
+
+# # Requirements
+# pip install -r requirements-mmdc-sgld.txt
+
+module unload git
+# Install sensorsio
+rm -rf thirdparties/sensorsio
+git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git thirdparties/sensorsio
+pip install thirdparties/sensorsio
+
+# Install torchutils
+rm -rf thirdparties/torchutils
+git clone https://src.koda.cnrs.fr/mmdc/torchutils.git thirdparties/torchutils
+pip install thirdparties/torchutils
+
+# Install the current project in edit mode
+pip install -e .[testing]
+
+# Activate pre-commit hooks
+pre-commit install
+
+# End
+conda deactivate
-- 
GitLab


From a914b8110c4cbcc6d01d6724e61a22278fd01513 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 14:08:57 +0000
Subject: [PATCH 14/81] Add iota conf

---
 ...ternalfeatures_with_userfeatures_50x50.cfg | 81 +++++++++++++++++++
 configs/iota2/i2_grid.cfg                     | 19 +++++
 jobs/iota2_aux_test.pbs                       | 20 +++++
 jobs/iota2_external_feature_test.pbs          | 18 +++++
 4 files changed, 138 insertions(+)
 create mode 100644 configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
 create mode 100644 configs/iota2/i2_grid.cfg
 create mode 100644 jobs/iota2_aux_test.pbs
 create mode 100644 jobs/iota2_external_feature_test.pbs

diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
new file mode 100644
index 00000000..3d074b7b
--- /dev/null
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -0,0 +1,81 @@
+chain :
+{
+  output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/results/external_user_features'
+  remove_output_path : False
+  check_inputs : False
+
+  nomenclature_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/nomenclature_grosse.txt'
+  color_table : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/colorFile_grosse.txt'
+
+  list_tile : 'T31TCJ'
+  ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp'
+  data_field : 'groscode'
+  s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink'
+  # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output'
+  # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
+  # s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg'
+
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+  grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+  tile_field : 'Name'
+
+  spatial_resolution : 10
+  first_step : 'init'
+  last_step : 'validation'
+  proj : 'EPSG:2154'
+}
+userFeat:
+{
+  arbo:"/*"
+  patterns:"mnt2,slope"
+}
+python_data_managing :
+{
+    padding_size_x : 1
+    padding_size_y : 1
+    chunk_size_mode:"split_number"
+    number_of_chunks:2
+    data_mode_access: "both"
+}
+
+
+# builders:
+# {
+# builders_class_name : ["i2_features_to_grid"]
+# }
+
+external_features :
+{
+  module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
+  functions : [['test_mnt', {"arg1":1}]]
+  concat_mode : True
+  external_features_flag:True
+}
+arg_classification:
+{
+ enable_probability_map:True
+ generate_final_probability_map:True
+}
+arg_train :
+{
+
+  runs : 1
+  random_seed : 1
+  classifier:"sharkrf"
+
+  sample_selection :
+  {
+    sampler : 'random'
+    strategy : 'percent'
+    'strategy.percent.p' : 1.0
+    ram : 4000
+  }
+}
+
+task_retry_limits :
+{
+  allowed_retry : 0
+  maximum_ram : 180.0
+  maximum_cpu : 40
+}
diff --git a/configs/iota2/i2_grid.cfg b/configs/iota2/i2_grid.cfg
new file mode 100644
index 00000000..e13f8788
--- /dev/null
+++ b/configs/iota2/i2_grid.cfg
@@ -0,0 +1,19 @@
+chain :
+{
+  output_path : '/work/scratch/vinascj/MMDC/iota2/grid_test'
+  spatial_resolution : 100 # output target resolution
+  first_step : 'tiler' # do not change
+  last_step : 'tiler' # do not change
+
+  proj : 'EPSG:2154' # output target crs
+  grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+  tile_field : 'Name'
+  list_tile : '31TCK 31TCJ' # list of tiles in grid to build
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/jobs/iota2_aux_test.pbs b/jobs/iota2_aux_test.pbs
new file mode 100644
index 00000000..3b3eed36
--- /dev/null
+++ b/jobs/iota2_aux_test.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_grid.cfg
+    #-nb_parallel_tasks 1 \
+    #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_external_feature_test.pbs b/jobs/iota2_external_feature_test.pbs
new file mode 100644
index 00000000..4283144b
--- /dev/null
+++ b/jobs/iota2_external_feature_test.pbs
@@ -0,0 +1,18 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
+    -nb_parallel_tasks 1 \
+    # -only_summary
-- 
GitLab


From 4e51bd8c9d5524902715afd691c921d0c6e2b4ea Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:17:32 +0000
Subject: [PATCH 15/81] update Iota2 API

---
 .../inference/mmdc_tile_inference_iota2.py    | 34 +++++++++++++++----
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
index f434b108..1110eccf 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
@@ -14,7 +14,9 @@ import torch
 from torchutils import patches
 
 from mmdc_singledate.datamodules.components.datamodule_utils import (
-    apply_log_to_s1
+    apply_log_to_s1,
+    srtm_height_aspect,
+    join_even_odd_s2_angles
 )
 from mmdc_singledate.inference.components.inference_components import (
     get_mmdc_full_model,
@@ -26,9 +28,13 @@ def apply_mmdc_full_mode(
     """
     Apply MMDC with Iota-2.py
     """
-    # TODO Get the data in the same order as
+    # How to manage the S1 acquisition dates to construct the validity mask
+    # over time?
+
+    # DONE Get the data in the same order as
     # sentinel2.Sentinel2.GROUP_10M +
     # sensorsio.sentinel2.GROUP_20M
+    # These data correspond to (Chunk_Size, Img_Size, nb_dates)
     list_bands_s2 = [
         self.get_interpolated_Sentinel2_B2(),
         self.get_interpolated_Sentinel2_B3(),
@@ -42,6 +48,9 @@ def apply_mmdc_full_mode(
         self.get_interpolated_Sentinel2_B12(),
     ]
 
+    # TODO Mask construction for S2
+    list_s2_mask = self.get_Sentinel2_binary_masks()
+
     # TODO Manage S1 ASC and S1 DESC ?
     list_bands_s1 = [
         self.get_interpolated_Sentinel1_ASC_vh(),
@@ -54,16 +63,29 @@ def apply_mmdc_full_mode(
         / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
     ]
 
-    # TODO Masks contruction
-
-
     with torch.no_grad():
         # Permute dimensions to fit patchify
         # Shape before permutation is C,H,W,D. D being the dates
         # Shape after permutation is D,C,H,W
         bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
         bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
-        bands_s1 = apply_log_to_s1(bands_s1)
+        bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+
+        # TODO Mask construction for S1
+        # add a build_s1_image_and_masks function to the datamodule components?
+
+        # Replace nan by 0
+        bands_s1 = bands_s1.nan_to_num()
+        # These dimensions are useful for unpatchify
+        # Keep the dimensions of height and width found in chunk
+        band_h, band_w = bands_s2.shape[-2:]
+        # Keep the number of patches of patch_size
+        # in rows and cols found in chunk
+        h, w = patches.patchify(bands_s2[0, ...],
+                                patch_size=patch_size,
+                                margin=patch_margin).shape[:2]
 
 
     # TODO Apply patchify
-- 
GitLab
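
Note on the patchify bookkeeping in the hunk above: it keeps (band_h, band_w) and
the patch-grid shape (h, w) so that model outputs can later be unpatchified and
cropped back to the chunk size. A minimal sketch of that round trip, assuming the
torchutils.patches API used elsewhere in this series (tensor sizes and patch_size
are illustrative):

    import torch
    from torchutils import patches

    bands_s2 = torch.rand(10, 1024, 1024)        # one chunk, C x H x W
    band_h, band_w = bands_s2.shape[-2:]         # kept to crop after unpatchify
    patched = patches.patchify(bands_s2, patch_size=256)
    h, w = patched.shape[:2]                     # patches per row / per column
    flat = patches.flatten2d(patched)            # batch of patches for the model
    # ... run inference on `flat` ...
    restored = patches.unpatchify(patches.unflatten2d(flat, h, w))
    restored = restored[:, :band_h, :band_w]     # crop back to the chunk size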


From 6a256126f5f6cbc8dc2506d590e58fb4b75ba197 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:18:02 +0000
Subject: [PATCH 16/81] test external features

---
 ...ternalfeatures_with_userfeatures_50x50.cfg |  2 +-
 jobs/iota2_external_feature_test.pbs          | 27 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
index 3d074b7b..08f1b1b9 100644
--- a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -39,7 +39,6 @@ python_data_managing :
     data_mode_access: "both"
 }
 
-
 # builders:
 # {
 # builders_class_name : ["i2_features_to_grid"]
@@ -49,6 +48,7 @@ external_features :
 {
   module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
   functions : [['test_mnt', {"arg1":1}]]
+  # functions : [['test_s2_concat', {"arg1":1}]]
   concat_mode : True
   external_features_flag:True
 }
diff --git a/jobs/iota2_external_feature_test.pbs b/jobs/iota2_external_feature_test.pbs
index 4283144b..f8b6953c 100644
--- a/jobs/iota2_external_feature_test.pbs
+++ b/jobs/iota2_external_feature_test.pbs
@@ -10,9 +10,36 @@ module purge
 # load modules
 module load conda
 conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+# conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2-develop
 
 
 Iota2.py -scheduler_type debug \
     -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
     -nb_parallel_tasks 1 \
     # -only_summary
+
+
+
+# Available functions
+# ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
+# '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__',
+# '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__',
+# '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__',
+# '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__',
+# '__weakref__', 'all_dates', 'allow_nans', 'binary_masks', 'concat_mode',
+# 'dim_ts', 'enabled_gap', 'exogeneous_data_array', 'exogeneous_data_name',
+# 'external_functions', 'fill_missing_dates', 'func_param',
+# 'get_Sentinel2_binary_masks', 'get_filled_masks', 'get_filled_stack',
+# 'get_interpolated_Sentinel2_B11', 'get_interpolated_Sentinel2_B12',
+# 'get_interpolated_Sentinel2_B2', 'get_interpolated_Sentinel2_B3',
+# 'get_interpolated_Sentinel2_B4', 'get_interpolated_Sentinel2_B5',
+# 'get_interpolated_Sentinel2_B6', 'get_interpolated_Sentinel2_B7',
+# 'get_interpolated_Sentinel2_B8', 'get_interpolated_Sentinel2_B8A',
+# 'get_interpolated_Sentinel2_Brightness', 'get_interpolated_Sentinel2_NDVI',
+# 'get_interpolated_Sentinel2_NDWI', 'get_interpolated_dates', 'get_raw_Sentinel2_B11',
+# 'get_raw_Sentinel2_B12', 'get_raw_Sentinel2_B2', 'get_raw_Sentinel2_B3',
+# 'get_raw_Sentinel2_B4', 'get_raw_Sentinel2_B5', 'get_raw_Sentinel2_B6',
+# 'get_raw_Sentinel2_B7', 'get_raw_Sentinel2_B8', 'get_raw_Sentinel2_B8A',
+# 'get_raw_dates', 'interpolated_data', 'interpolated_dates', 'missing_masks_values',
+# 'missing_refl_values', 'out_data', 'process', 'raw_data', 'raw_dates',
+# 'test_user_feature_with_fake_data']
-- 
GitLab
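
The dir() dump above lists the accessors that an external-features module such as
soi.py can call. A hedged sketch of what such a function might look like, built
only from those accessor names; the (array, labels) return contract and the NDVI
computation are assumptions for illustration, not code from this repository:

    def test_ndvi(self, arg1=1):
        """Hypothetical feature: NDVI from the interpolated B4/B8 accessors."""
        b4 = self.get_interpolated_Sentinel2_B4()
        b8 = self.get_interpolated_Sentinel2_B8()
        # the accessors are documented above as (chunk, image, nb_dates) stacks
        ndvi = (b8 - b4) / (b8 + b4 + 1e-6)
        labels = [f"ndvi_{d}" for d in range(ndvi.shape[-1])]
        return ndvi, labels

It would then be registered in the config as functions : [['test_ndvi', {"arg1":1}]],
mirroring the existing test_mnt entry.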


From a63af60bdf4f8e8f0ca18d26741344823e4f95a4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:18:44 +0000
Subject: [PATCH 17/81] test

---
 test/test_mmdc_inference.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 90d654be..0e981f85 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -13,7 +13,10 @@ from mmdc_singledate.inference.components.inference_utils import (
     GeoTiffDataset,
     generate_chunks,
 )
-from mmdc_singledate.inference.mmdc_tile_inference import MMDCProcess, predict_tile
+from mmdc_singledate.inference.mmdc_tile_inference import (
+    MMDCProcess,
+    predict_single_date_tile,
+)
 
 # dir
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
@@ -148,7 +151,7 @@ def test_predict_tile():
         wc_filename=os.path.join(datapath, "wc_clim_1_T31TDJ_roi_0.tif"),
     )
 
-    predict_tile(
+    predict_single_date_tile(
         input_data=input_data,
         export_path=export_path,
         process=process,
-- 
GitLab


From a744a9dc3dd8a5c79850978ba0354faee71706ea Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:19:39 +0000
Subject: [PATCH 18/81] iota env install

---
 create-iota2-env.sh  => create-iota2-env.sh | 39 +++++++++------------
 1 file changed, 17 insertions(+), 22 deletions(-)
 rename create-iota2-env.sh  => create-iota2-env.sh (68%)

diff --git a/create-iota2-env.sh  b/create-iota2-env.sh
similarity index 68%
rename from create-iota2-env.sh 
rename to create-iota2-env.sh
index f7035081..0af4a8c1 100644
--- a/create-iota2-env.sh 	
+++ b/create-iota2-env.sh
@@ -50,55 +50,50 @@ python --version
 mamba install iota2 -c iota2
 
 # clone the lastest version of iota2
-rm -rf thirdparties/iota2
-git clone https://framagit.org/iota2-project/iota2.git thirdparties/iota2
-cd thirdparties/iota2
+rm -rf iota2_thirdparties/iota2
+git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2
+cd iota2_thirdparties/iota2
 git checkout issue#600/tile_exogenous_features
 git pull origin issue#600/tile_exogenous_features
 cd ../../
 pwd
+
 #which Iota2.py
 which Iota2.py
-
 # create a backup
 cd  /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
 mv iota2 iota2_backup
 
 # create a symbolic link
-ln -s ~/src/MMDC/mmdc-singledate/thirdparties/iota2/iota2/ iota2
+ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
 
 cd ~/src/MMDC/mmdc-singledate
+
+
+#which Iota2.py
+which Iota2.py
+
 # install missing dependancies
 pip install -r requirements-mmdc-iota2.txt
 
 # test install
 Iota2.py -h
 
-# # install MMDC dependencies
-# conda install --yes pytorch=1.12.1 torchvision -c pytorch -c nvidia
-
-# conda deactivate
-# conda activate $target
-
-# # Requirements
-# pip install -r requirements-mmdc-sgld.txt
-
-module unload git
 # Install sensorsio
-rm -rf thirdparties/sensorsio
-git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git thirdparties/sensorsio
-pip install thirdparties/sensorsio
+rm -rf iota2_thirdparties/sensorsio
+git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio
+pip install iota2_thirdparties/sensorsio
 
 # Install torchutils
-rm -rf thirdparties/torchutils
-git clone https://src.koda.cnrs.fr/mmdc/torchutils.git thirdparties/torchutils
-pip install thirdparties/torchutils
+rm -rf iota2_thirdparties/torchutils
+git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils
+pip install iota2_thirdparties/torchutils
 
 # Install the current project in edit mode
 pip install -e .[testing]
 
 # Activate pre-commit hooks
-pre-commit install
+# pre-commit install
 
 # End
 conda deactivate
-- 
GitLab


From 53de3468d660317820efedede3d63c0db0ab5939 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:22:47 +0000
Subject: [PATCH 19/81] add iota2 supplementary dependencies

---
 requirements-mmdc-iota2.txt | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 requirements-mmdc-iota2.txt

diff --git a/requirements-mmdc-iota2.txt b/requirements-mmdc-iota2.txt
new file mode 100644
index 00000000..61906437
--- /dev/null
+++ b/requirements-mmdc-iota2.txt
@@ -0,0 +1,3 @@
+config==0.5.1
+itk
+pydantic
-- 
GitLab


From 88bf3886111448b78fe83085d63044bc520da6f4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:23:31 +0000
Subject: [PATCH 20/81] launch iota2

---
 jobs/iota2_aux_angles_test.pbs | 20 ++++++++++++++++++++
 jobs/iota2_aux_full.pbs        | 20 ++++++++++++++++++++
 2 files changed, 40 insertions(+)
 create mode 100644 jobs/iota2_aux_angles_test.pbs
 create mode 100644 jobs/iota2_aux_full.pbs

diff --git a/jobs/iota2_aux_angles_test.pbs b/jobs/iota2_aux_angles_test.pbs
new file mode 100644
index 00000000..6eaacfe2
--- /dev/null
+++ b/jobs/iota2_aux_angles_test.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules are loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_angles.cfg
+    #-nb_parallel_tasks 1 \
+    #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_aux_full.pbs b/jobs/iota2_aux_full.pbs
new file mode 100644
index 00000000..2e438ed2
--- /dev/null
+++ b/jobs/iota2_aux_full.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules are loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_full.cfg
+    #-nb_parallel_tasks 1 \
+    #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
-- 
GitLab


From 94df5c1957df27622fe99ffdd7ab10f6217da143 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:25:40 +0000
Subject: [PATCH 21/81] rename func

---
 src/mmdc_singledate/inference/mmdc_tile_inference.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index b4aec447..2511ae3d 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -101,7 +101,7 @@ class MMDCProcess:
     ]
 
 
-def predict_tile(
+def predict_single_date_tile(
     input_data: GeoTiffDataset,
     export_path: Path,
     sensors: list[str],
@@ -263,7 +263,7 @@ def mmdc_tile_inference(
             wc_filename=df_row.worldclim_filename,
         )
         # predict tile
-        predict_tile(
+        predict_single_date_tile(
             input_data=mmdc_input_data,
             export_path=export_path,
             sensors=sensor,
-- 
GitLab


From 3592a56b200a78bf9ddcfaecc7caf48c7704f39d Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:18:05 +0000
Subject: [PATCH 22/81] update test inference

---
 test/test_mmdc_inference.py | 291 ++++++++++++++++++------------------
 1 file changed, 149 insertions(+), 142 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 0e981f85..53547ac6 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -2,12 +2,13 @@
 # Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
 
 import os
-
+from pathlib import Path
 import pytest
 import torch
 
 from mmdc_singledate.inference.components.inference_components import (  # predict_tile,
     get_mmdc_full_model,
+    predict_mmdc_model,
 )
 from mmdc_singledate.inference.components.inference_utils import (
     GeoTiffDataset,
@@ -15,9 +16,15 @@ from mmdc_singledate.inference.components.inference_utils import (
 )
 from mmdc_singledate.inference.mmdc_tile_inference import (
     MMDCProcess,
+    inference_dataframe,
+    mmdc_tile_inference,
     predict_single_date_tile,
 )
 
+from mmdc_singledate.models.components.model_dataclass import (
+    VAELatentSpace,
+)
+
 # dir
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
 
@@ -25,6 +32,7 @@ dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
 def test_GeoTiffDataset():
     """ """
     datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ"
+
     input_data = GeoTiffDataset(
         s2_filename=os.path.join(
             datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
@@ -40,8 +48,12 @@ def test_GeoTiffDataset():
         srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
         wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
     )
-    assert input_data
-    assert type(input_data.metadata) == dict
+    # print the GeoTiffDataset
+    print(input_data)
+    # Check Data Existence
+    assert input_data.s1_asc_availability == True
+    assert input_data.s1_desc_availability == True
+    assert input_data.s2_availabitity == True
 
 
 def test_generate_chunks():
@@ -58,6 +70,11 @@ def test_mmdc_full_model():
     """
     test instantiate network
     """
+    # set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(device)
+
+    # checkpoints
     checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
     checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
 
@@ -66,20 +83,22 @@ def test_mmdc_full_model():
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename)
     )
-
-    print(mmdc_full_model)
+    # move to device
+    mmdc_full_model.to(device)
 
     assert mmdc_full_model.training == False
 
-    s2_x = torch.rand(1, 10, 256, 256)
-    s2_m = torch.ones(1, 10, 256, 256)
-    s2_angles_x = torch.rand(1, 6, 256, 256)
-    s1_x = torch.rand(1, 6, 256, 256)
-    s1_vm = torch.ones(1, 6, 256, 256)
-    s1_asc_angles_x = torch.rand(1, 3)
-    s1_desc_angles_x = torch.rand(1, 3)
-    worldclim_x = torch.rand(1, 103, 256, 256)
-    srtm_x = torch.rand(1, 4, 256, 256)
+    s2_x = torch.rand(1, 10, 256, 256).to(device)
+    s2_m = torch.ones(1, 10, 256, 256).to(device)
+    s2_angles_x = torch.rand(1, 6, 256, 256).to(device)
+    s1_x = torch.rand(1, 6, 256, 256).to(device)
+    s1_vm = torch.ones(1, 1, 256, 256).to(device)
+    s1_asc_angles_x = torch.rand(1, 3).to(device)
+    s1_desc_angles_x = torch.rand(1, 3).to(device)
+    worldclim_x = torch.rand(1, 103, 256, 256).to(device)
+    srtm_x = torch.rand(1, 4, 256, 256).to(device)
+
+    print(srtm_x)
 
     prediction = mmdc_full_model.predict(
         s2_x,
@@ -94,43 +113,121 @@ def test_mmdc_full_model():
     )
 
     latent_variable = prediction[0].latent.latent_s1.mu
+    print(latent_variable.shape)
 
     assert type(latent_variable) == torch.Tensor
 
 
+sensors_test = [
+    (["S2L2A", "S1FULL"]),
+    (["S2L2A", "S1ASC"]),
+    (["S2L2A", "S1DESC"]),
+    (["S2L2A"]),
+    (["S1FULL"]),
+    (["S1ASC"]),
+    (["S1DESC"]),
+    # ([ "S1FULL","S2L2A"  ]),
+]
+
+
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_predict_mmdc_model(sensors):
+    """ """
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # data input
+    s2_x = torch.rand(1, 10, 256, 256).to(device)
+    s2_m = torch.ones(1, 10, 256, 256).to(device)
+    s2_angles_x = torch.rand(1, 6, 256, 256).to(device)
+    s1_x = torch.rand(1, 6, 256, 256).to(device)
+    s1_vm = torch.ones(1, 1, 256, 256).to(device)
+    s1_asc_angles_x = torch.rand(1, 3).to(device)
+    s1_desc_angles_x = torch.rand(1, 3).to(device)
+    worldclim_x = torch.rand(1, 103, 256, 256).to(device)
+    srtm_x = torch.rand(1, 4, 256, 256).to(device)
+
+    # model
+    # checkpoints
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
+    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+    # move to device
+    mmdc_full_model.to(device)
+
+    pred = predict_mmdc_model(
+        mmdc_full_model,
+        sensors,
+        s2_x,
+        s2_m,
+        s2_angles_x,
+        s1_x,
+        s1_vm,
+        s1_asc_angles_x,
+        s1_desc_angles_x,
+        worldclim_x,
+        srtm_x,
+    )
+    print(pred.shape)
+
+    assert type(pred) == torch.Tensor
+    # assert pred.shape[0] == 4 or 2
+
+
 def dummy_process(
     s2_refl,
+    s2_mask,
     s2_ang,
-    s1_asc,
-    s1_desc,
+    s1_back,
+    s1_vm,
     s1_asc_lia,
     s1_desc_lia,
-    srtm_patch,
     wc_patch,
+    srtm_patch,
 ):
     """
     Create a dummy function
     """
     prediction = (
         s2_refl[:, 0, ...]
+        * s2_mask[:, 0, ...]
         * s2_ang[:, 0, ...]
-        * s1_asc[:, 0, ...]
-        * s1_desc[:, 0, ...]
+        * s1_back[:, 0, ...]
+        * s1_vm[:, 0, ...]
         * srtm_patch[:, 0, ...]
         * wc_patch[:, 0, ...]
     )
-    return prediction
 
+    latent_example = VAELatentSpace(
+        mu=prediction,
+        logvar=prediction,
+    )
 
-@pytest.mark.skip(reason="Have to update to the new data format")
-def test_predict_tile():
-    """ """
-    export_path = (
-        "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent.tif"
+    latent_space_stack = torch.cat(
+        (
+            torch.unsqueeze(latent_example.mu, 1),
+            torch.unsqueeze(latent_example.logvar, 1),
+            torch.unsqueeze(latent_example.mu, 1),
+            torch.unsqueeze(latent_example.logvar, 1),
+        ),
+        1,
     )
+
+    return latent_space_stack
+
+
+# TODO add cases with no data
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_predict_single_date_tile(sensors):
+    """ """
+    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif"
     process = MMDCProcess(
-        count=1,
+        count=4,
         nb_lines=1024,
+        patch_size=256,
         process=dummy_process,
     )
 
@@ -147,130 +244,40 @@ def test_predict_tile():
             datapath,
             "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
         ),
-        srtm_filename=os.path.join(datapath, "srtm_T31TDJ_roi_0.tif"),
-        wc_filename=os.path.join(datapath, "wc_clim_1_T31TDJ_roi_0.tif"),
+        srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
+        wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
     )
 
     predict_single_date_tile(
         input_data=input_data,
         export_path=export_path,
+        sensors=sensors,  # ["S2L2A"],
         process=process,
     )
 
-
-# def test_read_s2_img_tile():
-#     """
-#     test read a window of s2
-#     """
-#     dataset_sample = os.path.join(
-#         dataset_dir, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif"
-#     )
-
-#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
-#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
-
-#     s2_data = read_s2_img_tile(s2_filename=dataset_sample, rois=roi)
-#     # get the firt element
-#     s2_data_window = next(s2_data)
-
-#     assert s2_data_window.s2_reflectances.shape == torch.Size([10, 1024, 10980])
-#     assert s2_data_window.s2_angles.shape == torch.Size([6, 1024, 10980])
-#     assert s2_data_window.s2_mask.shape == torch.Size([1, 1024, 10980])
-
-
-# def test_read_s1_img_tile():
-#     """
-#     test read a window of s1
-#     """
-#     dataset_sample = os.path.join(
-#         dataset_dir,
-#         "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_41.98988334384891_15.429479982324139_roi_0.tif",
-#     )
-
-#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
-#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
-
-#     s1_data = read_s1_img_tile(s1_filename=dataset_sample, rois=roi)
-#     # get the firt element
-#     s1_data_window = next(s1_data)
-
-#     assert s1_data_window.s1_backscatter.shape == torch.Size([3, 1024, 10980])
-#     assert s1_data_window.s1_valmask.shape == torch.Size([3, 1024, 10980])
-#     assert s1_data_window.s1_edgemask.shape == torch.Size([3, 1024, 10980])
-#     assert s1_data_window.s1_lia_angles.shape == torch.Size([3])
-
-
-# def test_read_srtm_img_tile():
-#     """
-#     test read a window of srtm
-#     """
-#     dataset_sample = os.path.join(dataset_dir, "srtm_T31TDJ_roi_0.tif")
-
-#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
-#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
-
-#     srtm_data = read_srtm_img_tile(srtm_filename=dataset_sample, rois=roi)
-#     # get the firt element
-#     srtm_data_window = next(srtm_data)
-
-#     assert srtm_data_window.shape == torch.Size([4, 1024, 10980])
+    assert Path(export_path).exists() == True
 
 
-# def test_read_worlclim_img_tile():
-#     """
-#     test read a window of srtm
-#     """
-#     dataset_sample = os.path.join(dataset_dir, "wc_clim_1_T31TDJ_roi_0.tif")
-
-#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
-#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
-
-#     wc_data = read_worldclim_img_tile(wc_filename=dataset_sample, rois=roi)
-#     # get the firt element
-#     wc_data_window = next(wc_data)
-
-#     assert wc_data_window.shape == torch.Size([103, 1024, 10980])
-
-
-# def test_predict_tile():
-#     """ """
-
-#     input_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TDJ/"
-#     export_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/"
-
-#     mmdc_filenames = GeoTiffDataset(
-#         s2_filename=os.path.join(
-#             input_path, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif"
-#         ),
-#         s1_filename=os.path.join(
-#             input_path,
-#             "S1B_IW_GRDH_1SDV_20180304T173831_20180304T173856_009885_011E2F_CEA4_32.528381567125045_15.330754606990766_roi_0.tif",
-#         ),
-#         srtm_filename=os.path.join(input_path, "srtm_T31TDJ_roi_0.tif"),
-#         wc_filename=os.path.join(input_path, "wc_clim_1_T31TDJ_roi_0.tif"),
-#     )
-
-#     # get grid mesh
-#     rois, meta = get_rois_from_prediction_mesh(
-#         input_filename=mmdc_filenames.s2_filename, nlines=1024
-#     )
-
-#     export_filename = "latent_variable.tif"
-#     exported_file = predict_tile(
-#         input_path=input_path,
-#         export_path=export_path,
-#         mmdc_filenames=mmdc_filenames,
-#         export_filename=export_filename,
-#         rois=rois,
-#         meta=meta,
-#     )
-#     # assert existence of exported file
-#     assert Path(exported_file).exists
-
-
-# #
-
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_mmdc_tile_inference(sensors):
+    """
+    Test the inference code in a tile
+    """
+    # feed parameters
+    input_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/")
+    export_path = Path(
+        "/work/CESBIO/projects/MAESTRIA/test_onetile/export/test_latent_tileinfer_{'_'.join(sensors)}.tif"
+    )
+    model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt"
+    # apply prediction
 
-# # asc = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_39.6607434587337_15.429479982324139_roi_0.tif"
-# # desc = "S1B_IW_GRDH_1SDV_20180303T055142_20180303T055207_009863_011D75_B010_45.08497158404881_165.07686884776342_roi_4.tif
-# "
+    mmdc_tile_inference(
+        input_path=input_path,
+        export_path=export_path,
+        model_checkpoint_path=model_checkpoint_path,
+        sensor=sensors,
+        tile_list="T31TCJ",
+        days_gap=15,
+        latent_space_size=4,
+        nb_lines=1024,
+    )
-- 
GitLab
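
For the tests above, dummy_process mirrors the shape contract of the real model:
a four-plane stack of (mu, logvar) pairs, which is why MMDCProcess is now built
with count=4. A minimal sketch of that stacking (tensor sizes are illustrative):

    import torch

    mu_s2, logvar_s2 = torch.rand(1, 256, 256), torch.rand(1, 256, 256)
    mu_s1, logvar_s1 = torch.rand(1, 256, 256), torch.rand(1, 256, 256)
    # band order written to the GeoTIFF: S2 mu, S2 logvar, S1 mu, S1 logvar
    stack = torch.cat((mu_s2, logvar_s2, mu_s1, logvar_s1), 0)
    assert stack.shape[0] == 4  # matches MMDCProcess(count=4)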


From 58ae3b86af6346f6fc3b72caff6d91e362f4b085 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:18:37 +0000
Subject: [PATCH 23/81] update tiler config

---
 .../externalfeatures_with_userfeatures_50x50.cfg      | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
index 08f1b1b9..cdb197f1 100644
--- a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -11,9 +11,9 @@ chain :
   ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp'
   data_field : 'groscode'
   s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink'
+  user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
   # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output'
-  # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
-  # s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg'
+  s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg'
 
   srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
   worldclim_path:'/datalake/static_aux/worldclim-2.0'
@@ -28,7 +28,7 @@ chain :
 userFeat:
 {
   arbo:"/*"
-  patterns:"mnt2,slope"
+  patterns:"slope"
 }
 python_data_managing :
 {
@@ -39,11 +39,6 @@ python_data_managing :
     data_mode_access: "both"
 }
 
-# builders:
-# {
-# builders_class_name : ["i2_features_to_grid"]
-# }
-
 external_features :
 {
   module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
-- 
GitLab


From d220f2e66453a554a65cf0b6dbf7358bbd28f3dc Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 10:17:11 +0000
Subject: [PATCH 24/81] merge

---
 create-conda-env.sh | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/create-conda-env.sh b/create-conda-env.sh
index 36d9e095..61f3fc75 100755
--- a/create-conda-env.sh
+++ b/create-conda-env.sh
@@ -51,6 +51,8 @@ mamba install --yes "pytorch>=2.0.0.dev202302=py3.10_cuda11.7_cudnn8.5.0_0" "tor
 conda deactivate
 conda activate $target
 
+# proxy
+source ~/set_proxy_iota2.sh
 # Requirements
 pip install -r requirements-mmdc-sgld-test.txt
 
-- 
GitLab


From e1c27a0aca6793f3b05f9d71ac823b61509a5467 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:19:32 +0000
Subject: [PATCH 25/81] update iota conda env

---
 create-iota2-env.sh | 104 +++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 92 insertions(+), 12 deletions(-)

diff --git a/create-iota2-env.sh b/create-iota2-env.sh
index 0af4a8c1..06e9ba1f 100644
--- a/create-iota2-env.sh
+++ b/create-iota2-env.sh
@@ -58,8 +58,6 @@ git pull origin issue#600/tile_exogenous_features
 cd ../../
 pwd
 
-#which Iota2.py
-which Iota2.py
 # create a backup
 cd  /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
 mv iota2 iota2_backup
@@ -69,19 +67,14 @@ ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
 
 cd ~/src/MMDC/mmdc-singledate
 
-
-#which Iota2.py
-which Iota2.py
+conda install -c conda-forge pydantic
 
 # install missing dependancies
 pip install -r requirements-mmdc-iota2.txt
 
-# test install
-Iota2.py -h
-
 # Install sensorsio
 rm -rf iota2_thirdparties/sensorsio
-git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio
+git clone -b handle_local_srtm https://framagit.org/iota2-project/sensorsio.git iota2_thirdparties/sensorsio
 pip install iota2_thirdparties/sensorsio
 
 # Install torchutils
@@ -92,8 +85,95 @@ pip install iota2_thirdparties/torchutils
 # Install the current project in edit mode
 pip install -e .[testing]
 
-# Activate pre-commit hooks
-# pre-commit install
-
 # End
 conda deactivate
+
+# #!/usr/bin/env bash
+
+# export python_version="3.9"
+# export name="mmdc-iota2"
+# if ! [ -z "$1" ]
+# then
+#     export name=$1
+# fi
+
+# source ~/set_proxy_iota2.sh
+# if [ -z "$https_proxy" ]
+# then
+#     echo "Please set https_proxy environment variable before running this script"
+#     exit 1
+# fi
+
+# export target=/work/scratch/$USER/virtualenv/$name
+
+# if ! [ -z "$2" ]
+# then
+#     export target="$2/$name"
+# fi
+
+# echo "Installing $name in $target ..."
+
+# if [ -d "$target" ]; then
+#    echo "Cleaning previous conda env"
+#    rm -rf $target
+# fi
+
+# # Create blank virtualenv
+# module purge
+# module load conda
+# module load gcc
+# conda activate
+# conda create --yes --prefix $target python==${python_version} pip
+
+# # Enter virtualenv
+# conda deactivate
+# conda activate $target
+
+# # install mamba
+# conda install mamba
+
+# which python
+# python --version
+
+# # Install iota2
+# #mamba install iota2_develop=257da617 -c iota2
+# mamba install iota2 -c iota2
+
+# clone the latest version of iota2
+# rm -rf iota2_thirdparties/iota2
+# git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2
+# cd iota2_thirdparties/iota2
+# git checkout issue#600/tile_exogenous_features
+# git pull origin issue#600/tile_exogenous_features
+# cd ../../
+# pwd
+
+# # create a backup
+# cd  /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
+# mv iota2 iota2_backup
+
+# # create a symbolic link
+# ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
+
+# cd ~/src/MMDC/mmdc-singledate
+
+# conda install -c conda-forge pydantic
+
+# # install missing dependencies
+# pip install -r requirements-mmdc-iota2.txt #--upgrade --no-cache-dir
+
+# # Install sensorsio
+# rm -rf iota2_thirdparties/sensorsio
+# git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio
+# pip install iota2_thirdparties/sensorsio
+
+# # Install torchutils
+# rm -rf iota2_thirdparties/torchutils
+# git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils
+# pip install iota2_thirdparties/torchutils
+
+# # Install the current project in edit mode
+# pip install -e .[testing]
+
+# # End
+# conda deactivate
-- 
GitLab


From 29ec037ee7f861d2b500b63ff434cb8efe3ee46b Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:20:34 +0000
Subject: [PATCH 26/81] linting

---
 test/test_mmdc_inference.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 53547ac6..b7b869b1 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -3,6 +3,7 @@
 
 import os
 from pathlib import Path
+
 import pytest
 import torch
 
@@ -20,10 +21,7 @@ from mmdc_singledate.inference.mmdc_tile_inference import (
     mmdc_tile_inference,
     predict_single_date_tile,
 )
-
-from mmdc_singledate.models.components.model_dataclass import (
-    VAELatentSpace,
-)
+from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
 
 # dir
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
-- 
GitLab


From 8c72b2f8e613d6a806e0bcf3c9a27e1c5d413a8e Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:25:17 +0000
Subject: [PATCH 27/81] update utils

---
 .../inference/components/inference_utils.py   | 135 +++++++++++++-----
 1 file changed, 100 insertions(+), 35 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
index 3ec181c5..fe66b640 100644
--- a/src/mmdc_singledate/inference/components/inference_utils.py
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -39,59 +39,67 @@ class GeoTiffDataset:
     """
 
     s2_filename: str
+    s2_availabitity: bool = field(init=False)
     s1_asc_filename: str
+    s1_asc_availability: bool = field(init=False)
     s1_desc_filename: str
+    s1_desc_availability: bool = field(init=False)
     srtm_filename: str
     wc_filename: str
     metadata: dict = field(init=False)
 
     def __post_init__(self) -> None:
         # check if the data exists
-        # if not Path(self.s2_filename).exists():
-        #     raise Exception(f"{self.s2_filename} do not exists!")
-        # if not Path(self.s1_asc_filename).exists():
-        #     raise Exception(f"{self.s1_asc_filename} do not exists!")
-        # if not Path(self.s1_desc_filename).exists():
-        #     raise Exception(f"{self.s1_desc_filename} do not exists!")
         if not Path(self.srtm_filename).exists():
             raise Exception(f"{self.srtm_filename} do not exists!")
         if not Path(self.wc_filename).exists():
             raise Exception(f"{self.wc_filename} do not exists!")
-
+        if Path(self.s1_asc_filename).exists():
+            self.s1_asc_availability = True
+        else:
+            self.s1_asc_availability = False
+        if Path(self.s1_desc_filename).exists():
+            self.s1_desc_availability = True
+        else:
+            self.s1_desc_availability = False
+        if Path(self.s2_filename).exists():
+            self.s2_availabitity = True
+        else:
+            self.s2_availabitity = False
         # check if the filenames
         # provide are in the same
         # footprint and spatial resolution
-        # rio.open(self.s2_filename) as s2, rio.open(
+        # rio.open(
         #    self.s1_asc_filename
         # ) as s1_asc, rio.open(self.s1_desc_filename) as s1_desc,
-
+        # rio.open(self.s2_filename) as s2,
         with rio.open(self.srtm_filename) as srtm, rio.open(self.wc_filename) as wc:
             # recover the metadata
             self.metadata = srtm.meta
 
-            # list to check
+            # create a python set to check
             crss = {
+                srtm.meta["crs"],
+                wc.meta["crs"],
                 # s2.meta["crs"],
                 # s1_asc.meta["crs"],
                 # s1_desc.meta["crs"],
-                srtm.meta["crs"],
-                wc.meta["crs"],
             }
 
             heights = {
+                srtm.meta["height"],
+                wc.meta["height"],
                 # s2.meta["height"],
                 # s1_asc.meta["height"],
                 # s1_desc.meta["height"],
-                srtm.meta["height"],
-                wc.meta["height"],
             }
 
             widths = {
+                srtm.meta["width"],
+                wc.meta["width"],
                 # s2.meta["width"],
                 # s1_asc.meta["width"],
                 # s1_desc.meta["width"],
-                srtm.meta["width"],
-                wc.meta["width"],
             }
 
             # check crs
@@ -126,7 +134,6 @@ class S1Components:
 
     s1_backscatter: torch.Tensor
     s1_valmask: torch.Tensor
-    s1_edgemask: torch.Tensor
     s1_lia_angles: torch.Tensor
 
 
@@ -177,6 +184,7 @@ def generate_chunks(
 def read_img_tile(
     filename: str,
     rois: list[rio_window],
+    availabitity: bool,
     sensor_func: Callable[[torch.Tensor, Any], Any],
 ) -> Generator[Any, None, None]:
     """
@@ -186,23 +194,24 @@ def read_img_tile(
     """
     # check if the filename exists
     # if exist proceed to yield the data
-    if Path(filename).exists:
+    if availabitity:
         with rio.open(filename) as raster:
             for roi in rois:
                 # read the patch as tensor
                 tensor = torch.tensor(raster.read(window=roi), requires_grad=False)
                 # compute some transformation to the tensor and return dataclass
-                sensor_data = sensor_func(tensor, filename)
+                sensor_data = sensor_func(tensor, availabitity, filename)
                 yield sensor_data
     # if not exists create a zeros tensor and yield
     else:
         for roi in rois:
-            null_data = torch.zeros(roi.width, roi.height)
+            null_data = torch.ones(roi.width, roi.height)
             yield null_data
 
 
 def read_s2_img_tile(
     s2_tensor: torch.Tensor,
+    availabitity: bool,
     *args: Any,
     **kwargs: Any,
 ) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -211,7 +220,7 @@ def read_s2_img_tile(
     contruct the masks and yield the patch
     of data
     """
-    if s2_tensor.shape[0] == 20:
+    if availabitity:
         # extract masks
         cloud_mask = s2_tensor[11, ...].to(torch.uint8)
         cloud_mask[cloud_mask > 0] = 1
@@ -223,17 +232,39 @@ def read_s2_img_tile(
         ).unsqueeze(0)
         angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...])
         image_s2 = s2_tensor[:10, ...]
-
     else:
-        image_s2 = torch.zeros(10, s2_tensor.shape[0], s2_tensor.shape[1])
-        angles_s2 = torch.zeros(6, s2_tensor.shape[0], s2_tensor.shape[1])
-        mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1])
+        image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1])
+        angles_s2 = torch.ones(4, s2_tensor.shape[0], s2_tensor.shape[1])
+        mask = torch.zeros(2, s2_tensor.shape[0], s2_tensor.shape[1])
 
     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
 
 
+# def read_s2_img_iota2(
+#     s2_tensor: torch.Tensor,
+#     *args: Any,
+#     **kwargs: Any,
+# ) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
+#     """
+#     read a patch of sentinel 2 data
+#     contruct the masks and yield the patch
+#     of data
+#     """
+#     # copy the masks
+#     # cloud_mask =
+#     # sat_mask =
+#     # edge_mask =
+#     mask = torch.concat((cloud_mask, sat_mask, edge_mask), axis=0)
+#     angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...])
+#     image_s2 = s2_tensor[:10, ...]
+
+#     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
+
+
+# TODO move the s1_lia computation out of this function
 def read_s1_img_tile(
     s1_tensor: torch.Tensor,
+    availability: bool,
     s1_filename: str,
     *args: Any,
     **kwargs: Any,
@@ -241,8 +272,11 @@ def read_s1_img_tile(
     """
     read a patch of s1 construct the datasets associated
     and yield the result
+
+    inputs: s1_tensor read from read_img_tile
     """
-    if s1_tensor.shape[0] == 2:
+    # if s1_tensor.shape[0] == 2:
+    if availability:
         # get the vv,vh, vv/vh bands
         img_s1 = torch.cat(
             (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
@@ -250,24 +284,47 @@ def read_s1_img_tile(
         # get the local incidencce angle (lia)
         s1_lia = get_s1_acquisition_angles(s1_filename)
         # compute validity mask
-        s1_valmask = torch.ones(img_s1.shape)
+        s1_valmask = torch.ones(1, s1_tensor.shape[1], s1_tensor.shape[2])
         # compute edge mask
-        s1_edgemask = img_s1.to(int)
         s1_backscatter = apply_log_to_s1(img_s1)
     else:
-        s1_backscatter = torch.zeros(3, s1_tensor.shape[0], s1_tensor.shape[1])
+        s1_backscatter = torch.ones(3, s1_tensor.shape[0], s1_tensor.shape[1])
         s1_valmask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1])
-        s1_edgemask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1])
-        s1_lia = torch.zeros(3)
+        s1_lia = torch.ones(3)
 
     return S1Components(
         s1_backscatter=s1_backscatter,
         s1_valmask=s1_valmask,
-        s1_edgemask=s1_edgemask,
         s1_lia_angles=s1_lia,
     )
 
 
+# def read_s1_img_iota2(
+#     s1_tensor: torch.Tensor,
+#     *args: Any,
+#     **kwargs: Any,
+# ) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+#     """
+#     read a patch of s1 construct the datasets associated
+#     and yield the result
+#     """
+#     # get the vv,vh, vv/vh bands
+#     img_s1 = torch.cat(
+#         (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
+#     )
+#     s1_lia = get_s1_acquisition_angles(s1_filename)
+#     # compute validity mask
+#     s1_valmask = torch.ones(img_s1.shape)
+#     # compute edge mask
+#     s1_backscatter = apply_log_to_s1(img_s1)
+
+#     return S1Components(
+#         s1_backscatter=s1_backscatter,
+#         s1_valmask=s1_valmask,
+#         s1_lia_angles=s1_lia,
+#     )
+
+
 def read_srtm_img_tile(
     srtm_tensor: torch.Tensor,
     *args: Any,
@@ -292,7 +349,9 @@ def read_worldclim_img_tile(
     return WorldClimComponents(wc_tensor)
 
 
-def expand_worldclim_filenames(wc_filename: str) -> list[str]:
+def expand_worldclim_filenames(
+    wc_filename: str,
+) -> list[str]:
     """
     given the firts worldclim filename expand the filenames
     to the others files
@@ -308,7 +367,11 @@ def expand_worldclim_filenames(wc_filename: str) -> list[str]:
 
 
 def concat_worldclim_components(
-    wc_filename: str, rois: list[rio_window]
+    wc_filename: str,
+    rois: list[rio_window],
+    availabitity,
+    *args: Any,
+    **kwargs: Any,
 ) -> Generator[Any, None, None]:
     """
     Compose function for apply the read_img_tile general function
@@ -318,7 +381,9 @@ def concat_worldclim_components(
     wc_filenames = expand_worldclim_filenames(wc_filename)
     # read the tensors
     wc_tensor = [
-        next(read_img_tile(wc_filename, rois, read_worldclim_img_tile)).worldclim
+        next(
+            read_img_tile(wc_filename, rois, availabitity, read_worldclim_img_tile)
+        ).worldclim
         for idx, wc_filename in enumerate(wc_filenames)
     ]
     yield WorldClimComponents(torch.cat(wc_tensor))
-- 
GitLab
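
A condensed sketch of the input-checking policy introduced above: the static
layers (SRTM, WorldClim) stay mandatory, while missing S1/S2 rasters only flip
availability flags so read_img_tile can yield placeholder tensors instead of
raising. The helper name check_inputs is hypothetical; the behaviour summarised
is that of GeoTiffDataset.__post_init__:

    from pathlib import Path

    def check_inputs(s2, s1_asc, s1_desc, srtm, wc):
        # static layers are mandatory
        for fname in (srtm, wc):
            if not Path(fname).exists():
                raise Exception(f"{fname} do not exists!")
        # optional sensors only produce flags; when a flag is False,
        # read_img_tile skips rasterio and yields torch.ones placeholders
        return Path(s2).exists(), Path(s1_asc).exists(), Path(s1_desc).exists()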


From 4d68a05dbafcf55a1d3f4be0311269bf05e6a984 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:28:14 +0000
Subject: [PATCH 28/81] model related stuff

---
 .../components/inference_components.py        | 196 ++++++++++++------
 1 file changed, 131 insertions(+), 65 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index fe3212b2..8841e203 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -6,17 +6,19 @@ Functions components for get the
 latent spaces for a given tile
 """
 
-import logging
 
 # imports
+import logging
 from pathlib import Path
 from typing import Literal
 
 import torch
 from torch import nn
-from torchutils import patches
 
-from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
+from mmdc_singledate.models.components.model_dataclass import (
+    S1S2VAELatentSpace,
+    VAELatentSpace,
+)
 from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
 
 # define sensors variable for typing
@@ -90,8 +92,11 @@ def get_mmdc_full_model(
             w_code_s1s2=1,
         )
 
+        # mmdc_full_lightning = MMDCFullLitModule(mmdc_full_model)
+
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(device)
+        print(f"device detected : {device}")
+        print(f"nb_gpu's detected : {torch.cuda.device_count()}")
         # load state_dict
         lightning_checkpoint = torch.load(checkpoint, map_location=device)
 
@@ -104,11 +109,14 @@ def get_mmdc_full_model(
 
         # load the state dict
         mmdc_full_model.load_state_dict(checkpoint)
+        # mmdc_full_model.load_state_dict(lightning_checkpoint)
+        # mmdc_full_lightning.load_state_dict(lightning_checkpoint)
 
         # disble randomness, dropout, etc...
         mmdc_full_model.eval()
+        # mmdc_full_lightning.eval()
 
-    return mmdc_full_model
+    return mmdc_full_model  # mmdc_full_lightning
 
 
 @torch.no_grad()
@@ -116,13 +124,14 @@ def predict_mmdc_model(
     model: nn.Module,
     sensors: SENSORS,
     s2_ref,
+    s2_mask,
     s2_angles,
-    s1_asc,
-    s1_desc,
+    s1_back,
+    s1_vm,
     s1_asc_angles,
     s1_desc_angles,
-    srtm,
     worldclim,
+    srtm,
 ) -> VAELatentSpace:
     """
     This function apply the predict method to a
@@ -130,7 +139,7 @@ def predict_mmdc_model(
     by sensor
 
         :args: model : mmdc_full model intance
-        :args: sensor : Sensor acquisition
+        :args: sensors : Sensor acquisition
         :args: s2_ref :
         :args: s2_angles :
         :args: s1_back :
@@ -141,69 +150,126 @@ def predict_mmdc_model(
 
         :return: VAELatentSpace
     """
-    s1_back = torch.cat(
-        (
-            s1_asc,
-            s1_desc,
-            (s1_asc / s1_desc),
-        ),
-        1,
-    )
+    # TODO check if the sensors are in the current supported list
+    # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
+
     prediction = model.predict(
         s2_ref,
+        s2_mask,
         s2_angles,
         s1_back,
+        s1_vm,
         s1_asc_angles,
         s1_desc_angles,
         worldclim,
         srtm,
     )
 
-    if "S2L2A" in sensors:
-        # get latent
-        latent_space = prediction[0].latent.latent_s2
-
-    if "S1FULL" in sensors:
-        latent_space = prediction[0].latent.latent_s1
-
-    if "S1ASC" in sensors:
-        latent_space = prediction[0].latent.latent_s1
-
-    if "S1DESC" in sensors:
-        latent_space = prediction[0].latent.latent_s1
-
-    return latent_space
-
-
-def patchify_batch(
-    tensor: torch.Tensor,
-    patch_size: int,
-) -> torch.tensor:
-    """
-    reshape the geotiff data readed to the
-    shape expected from the network
-
-    :param: tensor
-    :param: patch_size
-    """
-    patch = patches.patchify(tensor, patch_size)
-    flatten_patch = patches.flatten2d(patch)
-
-    return flatten_patch
-
-
-def unpatchify_batch(
-    flatten_patch: torch.tensor, patch_shape: torch.Size, tensor_shape: torch.Size
-) -> torch.tensor:
-    """
-    Inverse operation of patchify batch
-    :param:   #
-    :param:   #
-    :param:   #
-    :return:  #
-    """
-    unflatten_patch = patches.unflatten2d(flatten_patch, patch_shape[0], patch_shape[1])
-    unpatch = patches.unpatchify(unflatten_patch)
-    unpatch_crop = unpatch[:, : tensor_shape[1], : tensor_shape[2]]
-
-    return unpatch_crop
+    print("prediction.shape :", prediction[0].latent.latent_s2.mu.shape)
+    print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape)
+    print("prediction.shape :", prediction[0].latent.latent_s1.mu.shape)
+    print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape)
+
+    # init the output latent spaces as
+    # empty dataclass
+    latent_space = S1S2VAELatentSpace(latent_s1=None, latent_s2=None)
+
+    # fill it in with a match/case on the sensor list
+    match sensors:
+        # Cases
+        case ["S2L2A", "S1FULL"]:
+            logger.info("S1 full captured")
+            logger.info("S2 captured")
+
+            latent_space.latent_s2 = prediction[0].latent.latent_s2
+            latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+            latent_space_stack = torch.cat(
+                (
+                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.logvar,
+                    latent_space.latent_s1.mu,
+                    latent_space.latent_s1.logvar,
+                ),
+                0,
+            )
+
+        case ["S2L2A", "S1ASC"]:
+            logger.info("S1 asc captured")
+            logger.info("S2 captured")
+
+            latent_space.latent_s2 = prediction[0].latent.latent_s2
+            latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+            latent_space_stack = torch.cat(
+                (
+                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.logvar,
+                    latent_space.latent_s1.mu,
+                    latent_space.latent_s1.logvar,
+                ),
+                0,
+            )
+
+        case ["S2L2A", "S1DESC"]:
+            logger.info("S1 desc captured")
+            logger.info("S2 captured")
+
+            latent_space.latent_s2 = prediction[0].latent.latent_s2
+            latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+            latent_space_stack = torch.cat(
+                (
+                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.logvar,
+                    latent_space.latent_s1.mu,
+                    latent_space.latent_s1.logvar,
+                ),
+                0,
+            )
+
+        case ["S2L2A"]:
+            logger.info("Only S2 captured")
+
+            latent_space.latent_s2 = prediction[0].latent.latent_s2
+
+            latent_space_stack = torch.cat(
+                (
+                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.logvar,
+                ),
+                0,
+            )
+
+        case ["S1FULL"]:
+            logger.info("Only S1 full captured")
+
+            latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+            latent_space_stack = torch.cat(
+                (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
+                0,
+            )
+
+        case ["S1ASC"]:
+            logger.info("Only S1ASC captured")
+
+            latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+            latent_space_stack = torch.cat(
+                (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
+                0,
+            )
+
+        case ["S1DESC"]:
+            logger.info("Only S1DESC captured")
+
+            latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+            latent_space_stack = torch.cat(
+                (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
+                0,
+            )
+
+    return latent_space_stack  # latent_space
-- 
GitLab
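
The match/case dispatch above determines how many latent planes end up in the
output stack: four when S2 and an S1 mode are both requested, two otherwise.
A minimal sketch of that effect (the helper name stack_latents is hypothetical,
and the tensors stand in for prediction[0].latent.*.mu / .logvar):

    import torch

    def stack_latents(sensors, s2_mu, s2_logvar, s1_mu, s1_logvar):
        match sensors:
            case ["S2L2A", "S1FULL" | "S1ASC" | "S1DESC"]:
                return torch.cat((s2_mu, s2_logvar, s1_mu, s1_logvar), 0)
            case ["S2L2A"]:
                return torch.cat((s2_mu, s2_logvar), 0)
            case ["S1FULL"] | ["S1ASC"] | ["S1DESC"]:
                return torch.cat((s1_mu, s1_logvar), 0)

    t = torch.rand(1, 256, 256)
    assert stack_latents(["S2L2A"], t, t, t, t).shape[0] == 2
    assert stack_latents(["S2L2A", "S1ASC"], t, t, t, t).shape[0] == 4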


From 4dc55df9d189065899a23e4d5dfd6737e8a22858 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:32:02 +0000
Subject: [PATCH 29/81] update aux datasets

---
 .../inference/mmdc_tile_inference.py          | 263 +++++++++++-------
 1 file changed, 160 insertions(+), 103 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index 2511ae3d..17939afc 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -7,21 +7,20 @@ Infereces API with Rasterio
 
 # imports
 import logging
-import os
 from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 
+import numpy as np
 import rasterio as rio
 import torch
+from torchutils import patches
 
 from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df
 
-from .components.inference_components import (
+from .components.inference_components import (  # patchify_batch,; unpatchify_batch,
     get_mmdc_full_model,
-    patchify_batch,
     predict_mmdc_model,
-    unpatchify_batch,
 )
 from .components.inference_utils import (
     GeoTiffDataset,
@@ -43,39 +42,34 @@ logger = logging.getLogger(__name__)
 
 def inference_dataframe(
     samples_dir: str,
-    input_tile_list: list[str],
+    input_tile_list: list[str] | None,
     days_gap: int,
-    nb_tiles: int,
-    nb_rois: int,
-    nb_files: int = None,
 ):
     """
     Read the input directory and create
     a dataframe with the occurrences
     for be infered
     """
-    # read the metadata and contruct time serie
-    (
-        tile_df,
-        asc_orbit_without_acquisitions,
-        desc_orbit_without_acquisitions,
-    ) = prepare_data_df(
-        samples_dir=samples_dir,
-        input_tile_list=input_tile_list,  # or pass,
-        days_gap=days_gap,
-        nb_files=nb_files,
-        nb_tiles=nb_tiles,
-        nb_rois=1,
+
+    # Create the tile time series
+
+    tile_time_serie_df = prepare_data_df(
+        samples_dir=samples_dir, input_tile_list=input_tile_list, days_gap=days_gap
+    )
+
+    tile_time_serie_df["patch_s2_availability"] = tile_time_serie_df["patch_s2"].apply(
+        lambda x: Path(x).exists()
     )
 
-    # manage the abscense of S1ASC or S1DESC
-    if asc_orbit_without_acquisitions:
-        tile_df["patchasc_s1"] = None
+    tile_time_serie_df["patchasc_s1_availability"] = tile_time_serie_df[
+        "patchasc_s1"
+    ].apply(lambda x: Path(x).exists())
 
-    if desc_orbit_without_acquisitions:
-        tile_df["patchdesc_s1"] = None
+    tile_time_serie_df["patchdesc_s1_availability"] = tile_time_serie_df[
+        "patchasc_s1"
+    ].apply(lambda x: Path(x).exists())
 
-    return tile_df
+    return tile_time_serie_df
 
 
 # functions and classes
@@ -87,6 +81,7 @@ class MMDCProcess:
 
     count: int
     nb_lines: int
+    patch_size: int
     process: Callable[
         [
             torch.tensor,
@@ -110,39 +105,59 @@ def predict_single_date_tile(
     """
     Predict a tile of data
     """
-    # retrieve the metadata
-    meta = input_data.metadata.copy()
-    # get nb outputs
-    meta.update({"count": process.count})
-    # calculate rois
-    chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
-    # get the windows from the chunks
-    rois = [rio.windows.Window(*chunk) for chunk in chunks]
-    logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
-
-    # init the dataset
-    logger.info("Reading S2 data")
-    s2_data = read_img_tile(input_data.s2_filename, rois, read_s2_img_tile)
-    logger.info("Reading S1 ASC data")
-    s1_asc_data = read_img_tile(input_data.s1_asc_filename, rois, read_s1_img_tile)
-    logger.info("Reading S1 DESC data")
-    s1_desc_data = read_img_tile(input_data.s1_desc_filename, rois, read_s1_img_tile)
-    logger.info("Reading SRTM data")
-    srtm_data = read_img_tile(input_data.srtm_filename, rois, read_srtm_img_tile)
-    logger.info("Reading WorldClim data")
-    worldclim_data = concat_worldclim_components(input_data.wc_filename, rois)
-
-    logger.info("Export Init")
-
-    # separate the latent spaces by sensor
-    sensors = [s for s in sensors if s in ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]]
-    # built export_filename
-    export_names = ["mmdc_latent_" + s.casefold() + ".tif" for s in sensors]
-    for idx, export_name in enumerate(export_names):
-        # export latent spaces
-        with rio.open(
-            os.path.join(export_path, export_name), "w", **meta
-        ) as prediction:
+    # check that at least one dataset exists
+    if (
+        Path(input_data.s2_filename).exists()
+        or Path(input_data.s1_asc_filename).exists()
+        or Path(input_data.s1_desc_filename).exists()
+    ):
+        # retrieve the metadata
+        meta = input_data.metadata.copy()
+        # get nb outputs
+        meta.update({"count": process.count})
+        # calculate rois
+        chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+        # get the windows from the chunks
+        rois = [rio.windows.Window(*chunk) for chunk in chunks]
+        logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+        # init the dataset
+        logger.info("Reading S2 data")
+        s2_data = read_img_tile(
+            filename=input_data.s2_filename,
+            rois=rois,
+            availabitity=input_data.s2_availabitity,
+            sensor_func=read_s2_img_tile,
+        )
+        logger.info("Reading S1 ASC data")
+        s1_asc_data = read_img_tile(
+            filename=input_data.s1_asc_filename,
+            rois=rois,
+            availabitity=input_data.s1_asc_availability,
+            sensor_func=read_s1_img_tile,
+        )
+        logger.info("Reading S1 DESC data")
+        s1_desc_data = read_img_tile(
+            filename=input_data.s1_desc_filename,
+            rois=rois,
+            availabitity=input_data.s1_desc_availability,
+            sensor_func=read_s1_img_tile,
+        )
+        logger.info("Reading SRTM data")
+        srtm_data = read_img_tile(
+            filename=input_data.srtm_filename,
+            rois=rois,
+            availabitity=True,
+            sensor_func=read_srtm_img_tile,
+        )
+        logger.info("Reading WorldClim data")
+        worldclim_data = concat_worldclim_components(
+            wc_filename=input_data.wc_filename, rois=rois, availabitity=True
+        )
+
+        logger.info("Export Init")
+
+        with rio.open(export_path, "w", **meta) as prediction:
             # iterate over the windows
             for roi, s2, s1_asc, s1_desc, srtm, wc in zip(
                 rois,
@@ -153,53 +168,100 @@ def predict_single_date_tile(
                 worldclim_data,
             ):
                 print(" original size : ", s2.s2_reflectances.shape)
+                print(" original size : ", s2.s2_angles.shape)
+                print(" original size : ", s2.s2_mask.shape)
+                print(" original size : ", s1_asc.s1_backscatter.shape)
+                print(" original size : ", s1_asc.s1_valmask.shape)
+                print(" original size : ", s1_asc.s1_lia_angles.shape)
+                print(" original size : ", s1_desc.s1_backscatter.shape)
+                print(" original size : ", s1_desc.s1_valmask.shape)
+                print(" original size : ", s1_desc.s1_lia_angles.shape)
+                print(" original size : ", srtm.srtm.shape)
+                print(" original size : ", wc.worldclim.shape)
+
+                # Concat S1 Data
+                s1_backscatter = torch.cat(
+                    (s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0
+                )
+                # The validity mask of S1 is unique
+                # Multiply the masks
+                s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask)
+
+                # keep the sizes to recover the original size later
+                s2_mask_patch_size = patches.patchify(
+                    s2.s2_mask, process.patch_size
+                ).shape
+                s2_s2_mask_shape = s2.s2_mask.shape  # [1, 1024, 10980]
+
+                print("s2_mask_patch_size :", s2_mask_patch_size)
+                print("s2_s2_mask_shape:", s2_s2_mask_shape)
                 # reshape the data
-                s2_refl_patch = patchify_batch(s2.s2_reflectances, 256)
-                s2_ang_patch = patchify_batch(s2.s2_angles, 256)
-                s1_asc_patch = patchify_batch(s1_asc.s1_backscatter, 256)
-                s1_asc_lia_patch = s1_asc.s1_lia_angles
-                s1_desc_patch = patchify_batch(s1_desc.s1_backscatter, 256)
-                s1_desc_lia_patch = s1_desc.s1_lia_angles
-                srtm_patch = patchify_batch(srtm.srtm, 256)
-                wc_patch = patchify_batch(wc.worldclim, 256)
-                print(
-                    s2_refl_patch.shape,
-                    s2_ang_patch.shape,
-                    s1_asc_patch.shape,
-                    s1_desc_patch.shape,
-                    s1_asc_lia_patch.shape,
-                    s1_desc_lia_patch.shape,
-                    srtm_patch.shape,
-                    wc_patch.shape,
+                s2_refl_patch = patches.flatten2d(
+                    patches.patchify(s2.s2_reflectances, process.patch_size)
+                )
+                s2_ang_patch = patches.flatten2d(
+                    patches.patchify(s2.s2_angles, process.patch_size)
+                )
+                s2_mask_patch = patches.flatten2d(
+                    patches.patchify(s2.s2_mask, process.patch_size)
+                )
+                s1_patch = patches.flatten2d(
+                    patches.patchify(s1_backscatter, process.patch_size)
                 )
+                s1_valmask_patch = patches.flatten2d(
+                    patches.patchify(s1_valmask, process.patch_size)
+                )
+                srtm_patch = patches.flatten2d(
+                    patches.patchify(srtm.srtm, process.patch_size)
+                )
+                wc_patch = patches.flatten2d(
+                    patches.patchify(wc.worldclim, process.patch_size)
+                )
+                # patchify not necessary
+
+                print("s2_patches", s2_refl_patch.shape)
+                print("s2_angles_patches", s2_ang_patch.shape)
+                print("s2_mask_patch", s2_mask_patch.shape)
+                print("s1_patch", s1_patch.shape)
+                print("s1_valmask_patch", s1_valmask_patch.shape)
+                print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape)
+                print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape)
+                print("srtm patches", srtm_patch.shape)
+                print("wc patches", wc_patch.shape)
 
                 # apply predict function
+                # should return an S1S2VAELatentSpace
+                # object
                 pred_vaelatentspace = process.process(
                     s2_refl_patch,
+                    s2_mask_patch,
                     s2_ang_patch,
-                    s1_asc_patch,
-                    s1_desc_patch,
-                    s1_asc_lia_patch,
-                    s1_desc_lia_patch,
-                    srtm_patch,
+                    s1_patch,
+                    s1_valmask_patch,
+                    s1_asc.s1_lia_angles,
+                    s1_desc.s1_lia_angles,
                     wc_patch,
+                    srtm_patch,
                 )
-                pred_tensor = torch.cat(
-                    (
-                        pred_vaelatentspace.mu,
-                        pred_vaelatentspace.logvar,
-                    ),
-                    1,
-                )
-                print(pred_tensor.shape)
+                print(type(pred_vaelatentspace))
+                print("latent space sizes : ", pred_vaelatentspace.shape)
+
+                # unpatchify
+
+                pred_vaelatentspace_unpatchify = patches.unpatchify(
+                    patches.unflatten2d(
+                        pred_vaelatentspace,
+                        s2_mask_patch_size[0],
+                        s2_mask_patch_size[1],
+                    )
+                )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
+
+                print("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
+
                 prediction.write(
-                    unpatchify_batch(
-                        flatten_patch=pred_tensor,
-                        patch_shape=s2_refl_patch.shape,
-                        tensor_shape=s2.s2_reflectances.shape,
-                    ),
+                    np.array(pred_vaelatentspace_unpatchify[0, ...]),
                     window=roi,
-                    indexes=1,
+                    indexes=4,
                 )
                 logger.info(("Export tile", f"filename :{export_path}"))
 
@@ -211,9 +273,7 @@ def mmdc_tile_inference(
     sensor: list[str],
     tile_list: list[str],
     days_gap: int,
-    nb_tiles: int,
-    nb_files: int = None,
-    latent_space_size: int = 2,
+    latent_space_size: int = 4,
     nb_lines: int = 1024,
 ) -> None:
     """
@@ -224,8 +284,6 @@ def mmdc_tile_inference(
         :args: sensor: list of sensors data input
         :args: tile_list: list of tiles,
         :args: days_gap : days between s1 and s2 acquisitions,
-        :args: nb_tiles : number of tiles to process,
-        :args: nb_files : number files,
         :args: latent_space_size : latent space output per sensor,
         :args: nb_lines : number of lines to read at a time (default: 1024),
 
@@ -237,9 +295,6 @@ def mmdc_tile_inference(
         samples_dir=input_path,
         input_tile_list=tile_list,
         days_gap=days_gap,
-        nb_tiles=nb_tiles,
-        nb_rois=1,
-        nb_files=nb_files,
     )
 
     # instance the model and get the pred func
@@ -258,7 +313,9 @@ def mmdc_tile_inference(
         mmdc_input_data = GeoTiffDataset(
             s2_filename=df_row.patch_s2,
             s1_asc_filename=df_row.patchasc_s1,
+            s1_asc_availibility=df_row.s1_asc_availibility,
             s1_desc_filename=df_row.patchdesc_s1,
+            s1_desc_availibitity=df_row.s1_desc_availibitity,
             srtm_filename=df_row.srtm_filename,
             wc_filename=df_row.worldclim_filename,
         )
-- 
GitLab
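
A minimal sketch of the availability-flag pattern this patch introduces: one
boolean column per sensor, derived with pathlib. The column names follow the
patch; the dataframe and file paths are toy placeholders.

    from pathlib import Path

    import pandas as pd

    df = pd.DataFrame(
        {
            "patch_s2": ["/tmp/s2_roi_0.tif"],
            "patchasc_s1": ["/tmp/s1_asc_roi_0.tif"],
            "patchdesc_s1": ["/tmp/s1_desc_roi_0.tif"],
        }
    )

    # one boolean column per sensor: True when the raster exists on disk
    for col in ("patch_s2", "patchasc_s1", "patchdesc_s1"):
        df[f"{col}_availability"] = df[col].apply(lambda x: Path(x).exists())

    print(df.filter(like="_availability"))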


From ec972af09fd3cab23b6d747960231de11f3a3993 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:33:35 +0000
Subject: [PATCH 30/81] add tiler test config

---
 configs/iota2/iota2_grid_full.cfg | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 configs/iota2/iota2_grid_full.cfg

diff --git a/configs/iota2/iota2_grid_full.cfg b/configs/iota2/iota2_grid_full.cfg
new file mode 100644
index 00000000..0a529fbb
--- /dev/null
+++ b/configs/iota2/iota2_grid_full.cfg
@@ -0,0 +1,21 @@
+chain :
+{
+  output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
+  spatial_resolution : 50
+  first_step : 'tiler'
+  last_step : 'tiler'
+
+  proj : 'EPSG:2154'
+  rasters_grid_path : '/work/OT/theia/oso/arthur/TMP/raster_grid'
+  tile_field : 'Name'
+  list_tile : 'T31TCJ'
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+  s2_path:'/work/OT/theia/oso/arthur/TMP/test_s2_angles'
+  s1_dir:'/work/OT/theia/oso/arthur/s1_emma/all_tiles'
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
-- 
GitLab


From 61172ee549b58edcf9b9b247f5a98f2aa173c32e Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:40:30 +0000
Subject: [PATCH 31/81] add Iota2 API

---
 .../inference/mmdc_iota2_inference.py         | 118 ++++++++++++++++++
 1 file changed, 118 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/mmdc_iota2_inference.py

diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
new file mode 100644
index 00000000..4a40ff74
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference API with Iota-2
+Inspired by:
+https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
+"""
+
+import os
+
+import torch
+from torchutils import patches
+
+from mmdc_singledate.datamodules.components.datamodule_utils import apply_log_to_s1
+from mmdc_singledate.inference.components.inference_components import (
+    get_mmdc_full_model,
+)
+
+# from mmdc_singledate.datamodules.components.datamodule_utils import (
+#     join_even_odd_s2_angles,
+#     srtm_height_aspect,
+# )
+
+
+def apply_mmdc_full_mode(
+    self,
+    checkpoint_path: str,
+    checkpoint_epoch: int = 100,
+    patch_size: int = 256,
+    patch_margin: int = 0,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+    # How to manage the S1 acquisition dates to construct the validity mask
+    # in time?
+
+    # DONE Get the data in the same order as
+    # sentinel2.Sentinel2.GROUP_10M +
+    # sensorsio.sentinel2.GROUP_20M
+    # This data corresponds to (Chunk_Size, Img_Size, nb_dates)
+    list_bands_s2 = [
+        self.get_interpolated_Sentinel2_B2(),
+        self.get_interpolated_Sentinel2_B3(),
+        self.get_interpolated_Sentinel2_B4(),
+        self.get_interpolated_Sentinel2_B8(),
+        self.get_interpolated_Sentinel2_B5(),
+        self.get_interpolated_Sentinel2_B6(),
+        self.get_interpolated_Sentinel2_B7(),
+        self.get_interpolated_Sentinel2_B8A(),
+        self.get_interpolated_Sentinel2_B11(),
+        self.get_interpolated_Sentinel2_B12(),
+    ]
+
+    # TODO Masks construction for S2
+    list_s2_mask = self.get_Sentinel2_binary_masks()
+
+    # TODO Manage S1 ASC and S1 DESC ?
+    list_bands_s1 = [
+        self.get_interpolated_Sentinel1_ASC_vh(),
+        self.get_interpolated_Sentinel1_ASC_vv(),
+        self.get_interpolated_Sentinel1_ASC_vh()
+        / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
+        self.get_interpolated_Sentinel1_DES_vh(),
+        self.get_interpolated_Sentinel1_DES_vv(),
+        self.get_interpolated_Sentinel1_DES_vh()
+        / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
+    ]
+
+    # TODO Read SRTM data
+    # TODO Read Worldclim Data
+
+    with torch.no_grad():
+        # Permute dimensions to fit patchify
+        # Shape before permutation is C,H,W,D. D being the dates
+        # Shape after permutation is D,C,H,W
+        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
+        bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+        # bands_s1 is already (D, C, H, W); only the log scaling is needed
+        bands_s1 = apply_log_to_s1(bands_s1)
+
+        # TODO Masks construction for S1
+        # reuse the build_s1_image_and_masks function from the datamodule components?
+
+        # Replace nan by 0
+        bands_s1 = bands_s1.nan_to_num()
+        # These dimensions are useful for unpatchify
+        # Keep the dimensions of height and width found in chunk
+        band_h, band_w = bands_s2.shape[-2:]
+        # Keep the number of patches of patch_size
+        # in rows and cols found in chunk
+        h, w = patches.patchify(
+            bands_s2[0, ...], patch_size=patch_size, margin=patch_margin
+        ).shape[:2]
+
+    # TODO Apply patchify
+
+    # Get the model
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+
+    # apply model
+    # latent_variable = mmdc_full_model.predict(
+    #     s2_x,
+    #     s2_m,
+    #     s2_angles_x,
+    #     s1_x,
+    #     s1_vm,
+    #     s1_asc_angles_x,
+    #     s1_desc_angles_x,
+    #     worldclim_x,
+    #     srtm_x,
+    # )
+
+    # TODO unpatchify
+
+    # TODO crop padding
+
+    # TODO Depending on the configuration, return a single latent variable
+    # or a stack of latent variables
+    # coef = pass
+    # labels = pass
+    # return coef, labels
-- 
GitLab
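
The comments in apply_mmdc_full_mode describe the dimension shuffle Iota-2
requires: band lists arrive as (C, H, W, D) with D the dates, while the model
expects (D, C, H, W). A self-contained sketch of that permutation, with toy
sizes:

    import torch

    c, h, w, d = 3, 8, 8, 5
    bands = torch.rand(c, h, w, d)  # (C, H, W, D), D being the dates

    # new dim order: old dim -1 (D) first, then C, H, W
    bands_dchw = bands.permute(-1, 0, 1, 2)
    assert bands_dchw.shape == (d, c, h, w)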


From 40025c415226024a44c27aa56a4289a87e9809e7 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 10 Mar 2023 17:12:50 +0000
Subject: [PATCH 32/81] testing linting

---
 src/bin/inference_mmdc_singledate.py          | 155 ++++++++++++++++++
 src/bin/inference_mmdc_timeserie.py           | 133 +++++++++++++++
 .../components/inference_components.py        |  85 +++++++---
 .../inference/mmdc_tile_inference.py          |  25 ++-
 4 files changed, 372 insertions(+), 26 deletions(-)
 create mode 100644 src/bin/inference_mmdc_singledate.py
 create mode 100644 src/bin/inference_mmdc_timeserie.py

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
new file mode 100644
index 00000000..37e19ad6
--- /dev/null
+++ b/src/bin/inference_mmdc_singledate.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference code for MMDC
+"""
+
+
+# imports
+import argparse
+import logging
+import os
+
+from mmdc_singledate.inference.components.inference_components import (
+    get_mmdc_full_model,
+    predict_mmdc_model,
+)
+from mmdc_singledate.inference.components.inference_utils import GeoTiffDataset
+from mmdc_singledate.inference.mmdc_tile_inference import (
+    MMDCProcess,
+    predict_single_date_tile,
+)
+
+
+def get_parser() -> argparse.ArgumentParser:
+    """
+    Generate argument parser for CLI
+    """
+    arg_parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        description="Compute the inference a selected number of latent spaces "
+        " and exported as GTiff for a given tile",
+    )
+
+    arg_parser.add_argument(
+        "--loglevel",
+        default="INFO",
+        choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
+        help="Logger level (default: INFO. Should be one of "
+        "(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+    )
+
+    arg_parser.add_argument(
+        "--s2_filename", type=str, help="s2 input file", required=False
+    )
+
+    arg_parser.add_argument(
+        "--s1_asc_filename", type=str, help="s1 asc input file", required=False
+    )
+
+    arg_parser.add_argument(
+        "--s1_desc_filename", type=str, help="s1 desc input file", required=False
+    )
+
+    arg_parser.add_argument(
+        "--srtm_filename", type=str, help="srtm input file", required=True
+    )
+
+    arg_parser.add_argument(
+        "--worldclim_filename", type=str, help="worldclim input file", required=True
+    )
+
+    arg_parser.add_argument(
+        "--export_path", type=str, help="export path ", required=True
+    )
+
+    arg_parser.add_argument(
+        "--model_checkpoint_path",
+        type=str,
+        help="model checkpoint path ",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--latent_space_size",
+        type=int,
+        help="Number of channels in the output file",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--nb_lines",
+        type=int,
+        help="Number of lines to read in each window",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--patch_size",
+        type=int,
+        help="Size of the patches used for the inference",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--sensors",
+        dest="sensors",
+        nargs="+",
+        help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (sensitive to the order, S2L2A always first)",
+        required=True,
+    )
+
+    return arg_parser
+
+
+def main():
+    """
+    Entry Point
+    """
+    # Parser arguments
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # Configure logging
+    numeric_level = getattr(logging, args.loglevel.upper(), None)
+    if not isinstance(numeric_level, int):
+        raise ValueError("Invalid log level: %s" % args.loglevel)
+
+    logging.basicConfig(
+        level=numeric_level,
+        datefmt="%y-%m-%d %H:%M:%S",
+        format="%(asctime)s :: %(levelname)s :: %(message)s",
+    )
+
+    # model
+    model = get_mmdc_full_model(
+        checkpoint=args.model_checkpoint_path,
+    )
+
+    # process class; predict_mmdc_model itself is the process and is called
+    # later by predict_single_date_tile with the model and the sensor list
+    mmdc_process = MMDCProcess(
+        count=args.latent_space_size,
+        nb_lines=args.nb_lines,
+        patch_size=args.patch_size,
+        model=model,
+        process=predict_mmdc_model,
+    )
+
+    # create export filename inside the requested export directory
+    export_path = os.path.join(
+        args.export_path, f"latent_singledate_{'_'.join(args.sensors)}.tif"
+    )
+
+    # GeoTiff
+    geotiff_dataclass = GeoTiffDataset(
+        s2_filename=args.s2_filename,
+        s1_asc_filename=args.s1_asc_filename,
+        s1_desc_filename=args.s1_desc_filename,
+        srtm_filename=args.srtm_filename,
+        wc_filename=args.worldclim_filename,
+    )
+
+    predict_single_date_tile(
+        input_data=geotiff_dataclass,
+        export_path=export_path,
+        sensors=args.sensors,
+        process=mmdc_process,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py
new file mode 100644
index 00000000..62bd41f9
--- /dev/null
+++ b/src/bin/inference_mmdc_timeserie.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference code for MMDC
+"""
+
+
+# imports
+import argparse
+import logging
+import os
+from pathlib import Path
+
+from mmdc_singledate.inference.mmdc_tile_inference import mmdc_tile_inference
+
+
+def get_parser() -> argparse.ArgumentParser:
+    """
+    Generate argument parser for CLI
+    """
+    arg_parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        description="Compute the inference a selected number of latent spaces "
+        " and exported as GTiff for a given tile",
+    )
+
+    arg_parser.add_argument(
+        "--loglevel",
+        default="INFO",
+        choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
+        help="Logger level (default: INFO. Should be one of "
+        "(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+    )
+
+    arg_parser.add_argument(
+        "--input_path", type=str, help="data input folder", required=True
+    )
+
+    arg_parser.add_argument(
+        "--export_path", type=str, help="export path ", required=True
+    )
+
+    arg_parser.add_argument(
+        "--days_gap",
+        type=int,
+        help="difference between S1 and S2 acquisitions in days",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--model_checkpoint_path",
+        type=str,
+        help="model checkpoint path ",
+        required=True,
+    )
+
+    arg_parser.add_argument(
+        "--tile_list", type=str, help="List of tiles to produce", required=False
+    )
+
+    arg_parser.add_argument(
+        "--patch_size",
+        type=int,
+        help="Size of the patches used for the inference",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--latent_space_size",
+        type=int,
+        help="Number of channels in the output file",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "--nb_lines",
+        type=int,
+        help="Number of lines to read in each window",
+        required=False,
+    )
+
+    arg_parser.add_argument(
+        "-sensors",
+        dest="sensors",
+        nargs="+",
+        help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)",
+        required=True,
+    )
+
+    return arg_parser
+
+
+def main():
+    """
+    Entry Point
+    """
+    # Parser arguments
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # Configure logging
+    numeric_level = getattr(logging, args.loglevel.upper(), None)
+    if not isinstance(numeric_level, int):
+        raise ValueError("Invalid log level: %s" % args.loglevel)
+
+    logging.basicConfig(
+        level=numeric_level,
+        datefmt="%y-%m-%d %H:%M:%S",
+        format="%(asctime)s :: %(levelname)s :: %(message)s",
+    )
+
+    if not os.path.isdir(args.export_path):
+        logging.info(
+            f"Export directory {args.export_path} does not exist. Creating it."
+        )
+        Path(args.export_path).mkdir(parents=True)
+
+    # entry point
+    mmdc_tile_inference(
+        input_path=Path(args.input_path),
+        export_path=Path(args.export_path),
+        model_checkpoint_path=Path(args.model_checkpoint_path),
+        sensor=args.sensors,
+        tile_list=args.tile_list,
+        days_gap=args.days_gap,
+        latent_space_size=args.latent_space_size,
+        nb_lines=args.nb_lines,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index 8841e203..934aa525 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -15,9 +15,8 @@ from typing import Literal
 import torch
 from torch import nn
 
-from mmdc_singledate.models.components.model_dataclass import (
+from mmdc_singledate.models.components.model_dataclass import (  # VAELatentSpace,
     S1S2VAELatentSpace,
-    VAELatentSpace,
 )
 from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
 
@@ -132,7 +131,7 @@ def predict_mmdc_model(
     s1_desc_angles,
     worldclim,
     srtm,
-) -> VAELatentSpace:
+) -> torch.Tensor:
     """
     This function applies the predict method to a
     batch of mmdc data and gets the latent space
@@ -153,16 +152,19 @@ def predict_mmdc_model(
     # TODO check if the sensors are in the current supported list
     # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(device)
+
     prediction = model.predict(
-        s2_ref,
-        s2_mask,
-        s2_angles,
-        s1_back,
-        s1_vm,
-        s1_asc_angles,
-        s1_desc_angles,
-        worldclim,
-        srtm,
+        s2_ref.to(device),
+        s2_mask.to(device),
+        s2_angles.to(device),
+        s1_back.to(device),
+        s1_vm.to(device),
+        s1_asc_angles.to(device),
+        s1_desc_angles.to(device),
+        worldclim.to(device),
+        srtm.to(device),
     )
 
     print("prediction.shape :", prediction[0].latent.latent_s2.mu.shape)
@@ -192,9 +194,20 @@ def predict_mmdc_model(
                     latent_space.latent_s1.mu,
                     latent_space.latent_s1.logvar,
                 ),
-                0,
+                1,
             )
 
+            print(
+                "latent_space.latent_s2.mu :",
+                latent_space.latent_s2.mu,
+            )
+            print(
+                "latent_space.latent_s2.logvar :",
+                latent_space.latent_s2.logvar,
+            )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
+            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
+
         case ["S2L2A", "S1ASC"]:
             logger.info("S1 asc captured")
             logger.info("S2 captured")
@@ -209,8 +222,18 @@ def predict_mmdc_model(
                     latent_space.latent_s1.mu,
                     latent_space.latent_s1.logvar,
                 ),
-                0,
+                1,
+            )
+            print(
+                "latent_space.latent_s2.mu :",
+                latent_space.latent_s2.mu,
             )
+            print(
+                "latent_space.latent_s2.logvar :",
+                latent_space.latent_s2.logvar,
+            )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
+            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S2L2A", "S1DESC"]:
             logger.info("S1 desc captured")
@@ -226,8 +249,18 @@ def predict_mmdc_model(
                     latent_space.latent_s1.mu,
                     latent_space.latent_s1.logvar,
                 ),
-                0,
+                1,
+            )
+            print(
+                "latent_space.latent_s2.mu :",
+                latent_space.latent_s2.mu,
+            )
+            print(
+                "latent_space.latent_s2.logvar :",
+                latent_space.latent_s2.logvar,
             )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
+            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S2L2A"]:
             logger.info("Only S2 captured")
@@ -239,7 +272,15 @@ def predict_mmdc_model(
                     latent_space.latent_s2.mu,
                     latent_space.latent_s2.logvar,
                 ),
-                0,
+                1,
+            )
+            print(
+                "latent_space.latent_s2.mu :",
+                latent_space.latent_s2.mu,
+            )
+            print(
+                "latent_space.latent_s2.logvar :",
+                latent_space.latent_s2.logvar,
             )
 
         case ["S1FULL"]:
@@ -249,8 +290,10 @@ def predict_mmdc_model(
 
             latent_space_stack = torch.cat(
                 (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
-                0,
+                1,
             )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
+            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S1ASC"]:
             logger.info("Only S1ASC captured")
@@ -259,8 +302,10 @@ def predict_mmdc_model(
 
             latent_space_stack = torch.cat(
                 (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
-                0,
+                1,
             )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
+            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S1DESC"]:
             logger.info("Only S1DESC captured")
@@ -269,7 +314,9 @@ def predict_mmdc_model(
 
             latent_space_stack = torch.cat(
                 (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
-                0,
+                1,
             )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
+            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
     return latent_space_stack  # latent_space
diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index 17939afc..ad7408be 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -82,6 +82,7 @@ class MMDCProcess:
     count: int
     nb_lines: int
     patch_size: int
+    model : torch.nn
     process: Callable[
         [
             torch.tensor,
@@ -217,15 +218,21 @@ def predict_single_date_tile(
                 wc_patch = patches.flatten2d(
                     patches.patchify(wc.worldclim, process.patch_size)
                 )
-                # patchify not necessary
+                # Expand the angles to fit the sizes
+                s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat(
+                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1)
+                # torch.flatten(start_dim=0, end_dim=1)
+                s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat(
+                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1)
+                # torch.flatten(start_dim=0, end_dim=1)
 
                 print("s2_patches", s2_refl_patch.shape)
                 print("s2_angles_patches", s2_ang_patch.shape)
                 print("s2_mask_patch", s2_mask_patch.shape)
                 print("s1_patch", s1_patch.shape)
                 print("s1_valmask_patch", s1_valmask_patch.shape)
-                print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape)
-                print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape)
+                print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape, s1asc_lia_patch.shape)
+                print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape, s1desc_lia_patch.shape)
                 print("srtm patches", srtm_patch.shape)
                 print("wc patches", wc_patch.shape)
 
@@ -233,13 +240,15 @@ def predict_single_date_tile(
                 # should return a s1s2vaelatentspace
                 # object
                 pred_vaelatentspace = process.process(
+                    process.model,
+                    sensors,
                     s2_refl_patch,
                     s2_mask_patch,
                     s2_ang_patch,
                     s1_patch,
                     s1_valmask_patch,
-                    s1_asc.s1_lia_angles,
-                    s1_desc.s1_lia_angles,
+                    s1asc_lia_patch,
+                    s1desc_lia_patch,
                     wc_patch,
                     srtm_patch,
                 )
@@ -247,7 +256,6 @@ def predict_single_date_tile(
                 print("latent space sizes : ", pred_vaelatentspace.shape)
 
                 # unpatchify
-
                 pred_vaelatentspace_unpatchify = patches.unpatchify(
                     patches.unflatten2d(
                         pred_vaelatentspace,
@@ -257,11 +265,14 @@ def predict_single_date_tile(
                 )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
 
                 print("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
+                print("process.count : ", process.count)
+                # check the pred and the ask are the same
+                assert process.count == pred_vaelatentspace_unpatchify.shape[0]
 
                 prediction.write(
                     np.array(pred_vaelatentspace_unpatchify[0, ...]),
                     window=roi,
-                    indexes=4,
+                    indexes=process.count,
                 )
                 logger.info(("Export tile", f"filename :{export_path}"))
 
-- 
GitLab
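
Sketch of the patchify/unpatchify round trip that predict_single_date_tile now
relies on, assuming the project's torchutils.patches helpers behave as they are
used in this patch: patchify yields a (rows, cols, ...) grid of patches,
flatten2d turns it into a batch, and unflatten2d/unpatchify invert the two
steps, after which the padding is cropped away.

    import torch
    from torchutils import patches

    patch_size = 128
    img = torch.rand(4, 300, 500)  # (C, H, W); H and W need not be multiples

    grid = patches.patchify(img, patch_size)  # grid of patches, (rows, cols, ...)
    rows, cols = grid.shape[:2]
    batch = patches.flatten2d(grid)  # one batch of patches for the model

    # ... the model would run on `batch` here ...

    back = patches.unpatchify(patches.unflatten2d(batch, rows, cols))
    restored = back[:, : img.shape[1], : img.shape[2]]  # crop the padding
    assert restored.shape == img.shape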


From 4e8dba33a78ba4f1fac0e7a2979721b5906215e1 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Sat, 11 Mar 2023 20:29:21 +0000
Subject: [PATCH 33/81] update for manage no data cases

---
 .../inference/components/inference_utils.py   | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
index fe66b640..9d54bfe0 100644
--- a/src/mmdc_singledate/inference/components/inference_utils.py
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -50,19 +50,19 @@ class GeoTiffDataset:
 
     def __post_init__(self) -> None:
         # check if the data exists
-        if not Path(self.srtm_filename).exists():
+        if not Path(self.srtm_filename).is_file():
             raise Exception(f"{self.srtm_filename} do not exists!")
-        if not Path(self.wc_filename).exists():
+        if not Path(self.wc_filename).is_file():
             raise Exception(f"{self.wc_filename} do not exists!")
-        if Path(self.s1_asc_filename).exists():
+        if Path(self.s1_asc_filename).is_file():
             self.s1_asc_availability = True
         else:
             self.s1_asc_availability = False
-        if Path(self.s1_desc_filename).exists():
+        if Path(self.s1_desc_filename).is_file():
             self.s1_desc_availability = True
         else:
             self.s1_desc_availability = False
-        if Path(self.s2_filename).exists():
+        if Path(self.s2_filename).is_file():
             self.s2_availabitity = True
         else:
             self.s2_availabitity = False
@@ -204,9 +204,11 @@ def read_img_tile(
                 yield sensor_data
     # if not exists create a zeros tensor and yield
     else:
+        print("Creating fake data")
         for roi in rois:
             null_data = torch.ones(roi.width, roi.height)
-            yield null_data
+            sensor_data = sensor_func(null_data, availabitity, filename)
+            yield sensor_data
 
 
 def read_s2_img_tile(
@@ -234,8 +236,9 @@ def read_s2_img_tile(
         image_s2 = s2_tensor[:10, ...]
     else:
         image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1])
-        angles_s2 = torch.ones(4, s2_tensor.shape[0], s2_tensor.shape[1])
-        mask = torch.zeros(2, s2_tensor.shape[0], s2_tensor.shape[1])
+        angles_s2 = torch.ones(6, s2_tensor.shape[0], s2_tensor.shape[1])
+        mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1])
+        print("passing zero data")
 
     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
 
@@ -288,8 +291,8 @@ def read_s1_img_tile(
         # compute edge mask
         s1_backscatter = apply_log_to_s1(img_s1)
     else:
-        s1_backscatter = torch.ones(3, s1_tensor.shape[0], s1_tensor.shape[1])
-        s1_valmask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1])
+        s1_backscatter = torch.ones(3, s1_tensor.shape[1], s1_tensor.shape[0])
+        s1_valmask = torch.zeros(1, s1_tensor.shape[1], s1_tensor.shape[0])
         s1_lia = torch.ones(3)
 
     return S1Components(
-- 
GitLab
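
The reader change above keeps the per-window generators aligned even when a
raster is missing: instead of breaking the zip() over sensors, it yields
placeholder tensors. A simplified sketch of the pattern (the real
read_img_tile delegates to per-sensor functions and the project's window
readers):

    import rasterio as rio
    import torch

    def read_windows(filename: str, rois: list, available: bool):
        """Yield one tensor per window, real or fake."""
        if available:
            with rio.open(filename) as src:
                for roi in rois:
                    yield torch.from_numpy(src.read(window=roi))
        else:
            # fake data keeps shapes flowing so downstream zip() stays aligned
            for roi in rois:
                yield torch.ones(int(roi.height), int(roi.width))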


From 6969fe2933c14bd7fda993c6c39633ed4a4162ee Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Sat, 11 Mar 2023 20:29:55 +0000
Subject: [PATCH 34/81] update for manage no data cases

---
 .../inference/mmdc_tile_inference.py          | 27 ++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index ad7408be..88cb3bfb 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -82,7 +82,7 @@ class MMDCProcess:
     count: int
     nb_lines: int
     patch_size: int
-    model : torch.nn
+    model: torch.nn
     process: Callable[
         [
             torch.tensor,
@@ -155,7 +155,12 @@ def predict_single_date_tile(
         worldclim_data = concat_worldclim_components(
             wc_filename=input_data.wc_filename, rois=rois, availabitity=True
         )
-
+        print("input_data.s2_availabitity=", input_data.s2_availabitity)
+        print("s2_data=", s2_data)
+        print("s1_asc_data=", s1_asc_data)
+        print("s1_desc_data=", s1_desc_data)
+        print("srtm_data=", srtm_data)
+        print("worldclim_data=", worldclim_data)
         logger.info("Export Init")
 
         with rio.open(export_path, "w", **meta) as prediction:
@@ -220,10 +225,12 @@ def predict_single_date_tile(
                 )
                 # Expand the angles to fit the sizes
                 s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat(
-                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1)
+                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+                )
                 # torch.flatten(start_dim=0, end_dim=1)
                 s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat(
-                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1)
+                    s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+                )
                 # torch.flatten(start_dim=0, end_dim=1)
 
                 print("s2_patches", s2_refl_patch.shape)
@@ -231,8 +238,16 @@ def predict_single_date_tile(
                 print("s2_mask_patch", s2_mask_patch.shape)
                 print("s1_patch", s1_patch.shape)
                 print("s1_valmask_patch", s1_valmask_patch.shape)
-                print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape, s1asc_lia_patch.shape)
-                print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape, s1desc_lia_patch.shape)
+                print(
+                    "s1_lia_asc patches",
+                    s1_asc.s1_lia_angles.shape,
+                    s1asc_lia_patch.shape,
+                )
+                print(
+                    "s1_lia_desc patches",
+                    s1_desc.s1_lia_angles.shape,
+                    s1desc_lia_patch.shape,
+                )
                 print("srtm patches", srtm_patch.shape)
                 print("wc patches", wc_patch.shape)
 
-- 
GitLab
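
After this patch, MMDCProcess.process is invoked as
process.process(process.model, sensors, <nine patch tensors>). A hypothetical
stand-in process with that calling convention, in the spirit of the tests'
dummy_process, useful for exercising the tile writer without a checkpoint (the
band count mirrors the len(sensors) * 8 used by the tests):

    import torch

    def dummy_process(model, sensors, s2_refl, s2_mask, s2_ang,
                      s1, s1_valmask, s1_asc_lia, s1_desc_lia, wc, srtm):
        # fake latent stack: same spatial size as the input patches,
        # len(sensors) * 8 channels, batch dimension preserved
        nb_bands = len(sensors) * 8
        return torch.zeros(s2_refl.shape[0], nb_bands, *s2_refl.shape[-2:])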


From 1a42b395c34be876a32454c5cd4228e01e5ba9fe Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Sat, 11 Mar 2023 20:30:43 +0000
Subject: [PATCH 35/81] update for manage no data cases

---
 test/test_mmdc_inference.py | 201 ++++++++++++++++++++++++++++++++++--
 1 file changed, 190 insertions(+), 11 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index b7b869b1..3f4265ba 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -5,6 +5,7 @@ import os
 from pathlib import Path
 
 import pytest
+import rasterio as rio
 import torch
 
 from mmdc_singledate.inference.components.inference_components import (  # predict_tile,
@@ -76,8 +77,6 @@ def test_mmdc_full_model():
     checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
     checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
 
-    # lightning_dict = torch.load("epoch_008.ckpt", map_location=torch.device("cpu"))
-
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename)
     )
@@ -218,17 +217,61 @@ def dummy_process(
 
 
 # TODO add cases with no data
+# @pytest.mark.skip(reason="Test with dummy process ")
+# @pytest.mark.parametrize("sensors", sensors_test)
+# def test_predict_single_date_tile_with_dummy_process(sensors):
+#     """ """
+#     export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif"
+#     nb_bands = 4 #  len(sensors) * 2
+#     print("nb_bands  :", nb_bands)
+#     process = MMDCProcess(
+#         count=nb_bands, #4,
+#         nb_lines=1024,
+#         patch_size=256,
+#         process=dummy_process,
+#     )
+
+#     datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+#     input_data = GeoTiffDataset(
+#         s2_filename=os.path.join(
+#             datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+#         ),
+#         s1_asc_filename=os.path.join(
+#             datapath,
+#             "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+#         ),
+#         s1_desc_filename=os.path.join(
+#             datapath,
+#             "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+#         ),
+#         srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
+#         wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
+#     )
+
+#     predict_single_date_tile(
+#         input_data=input_data,
+#         export_path=export_path,
+#         sensors=sensors,  # ["S2L2A"],
+#         process=process,
+#     )
+
+#     assert Path(export_path).exists() == True
+
+#     with rio.open(export_path) as predited_raster:
+#         predicted_metadata = predited_raster.meta.copy()
+
+#         assert predicted_metadata["count"] == nb_bands
+
+
+# TODO add cases with no data
+# @pytest.mark.skip(reason="Dissable for faster testing")
 @pytest.mark.parametrize("sensors", sensors_test)
 def test_predict_single_date_tile(sensors):
     """ """
-    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif"
-    process = MMDCProcess(
-        count=4,
-        nb_lines=1024,
-        patch_size=256,
-        process=dummy_process,
-    )
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif"
     datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
     input_data = GeoTiffDataset(
         s2_filename=os.path.join(
@@ -246,16 +289,152 @@ def test_predict_single_date_tile(sensors):
         wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
     )
 
+    nb_bands = len(sensors) * 8
+    print("nb_bands  :", nb_bands)
+
+    # checkpoints
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
+    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+    # move to device
+    mmdc_full_model.to(device)
+    assert mmdc_full_model.training == False
+
+    # init the process
+    process = MMDCProcess(
+        count=nb_bands,  # 4,
+        nb_lines=256,
+        patch_size=128,
+        model=mmdc_full_model,
+        process=predict_mmdc_model,
+    )
+
+    # entry point pred function
     predict_single_date_tile(
         input_data=input_data,
         export_path=export_path,
-        sensors=sensors,  # ["S2L2A"],
+        sensors=sensors,
         process=process,
     )
 
     assert Path(export_path).exists() == True
 
+    with rio.open(export_path) as predited_raster:
+        predicted_metadata = predited_raster.meta.copy()
+        assert predicted_metadata["count"] == nb_bands
+
+
+s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif"
+s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif"
+srtm_filename_variable = "srtm_T31TCJ_roi_0.tif"
+wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif"
+
+datasets = [
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        "",
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        "",
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        "",
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable",
+    datasets,
+)
+def test_predict_single_date_tile_no_data(
+    s2_filename_variable,
+    s1_asc_filename_variable,
+    s1_desc_filename_variable,
+    srtm_filename_variable,
+    wc_filename_variable,
+):
+    """ """
+    sensors = ["S2L2A", "S1ASC"]
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif"
+    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(datapath, s2_filename_variable),
+        s1_asc_filename=os.path.join(
+            datapath,
+            s1_asc_filename_variable,
+        ),
+        s1_desc_filename=os.path.join(
+            datapath,
+            s1_desc_filename_variable,
+        ),
+        srtm_filename=os.path.join(datapath, srtm_filename_variable),
+        wc_filename=os.path.join(datapath, wc_filename_variable),
+    )
+
+    nb_bands = len(sensors) * 8
+    print("nb_bands  :", nb_bands)
+
+    # checkpoints
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
+    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+
+    mmdc_full_model = get_mmdc_full_model(
+        os.path.join(checkpoint_path, checkpoint_filename)
+    )
+    # move to device
+    mmdc_full_model.to(device)
+    assert mmdc_full_model.training == False
+
+    # init the process
+    process = MMDCProcess(
+        count=nb_bands,  # 4,
+        nb_lines=256,
+        patch_size=128,
+        model=mmdc_full_model,
+        process=predict_mmdc_model,
+    )
+
+    # entry point pred function
+    predict_single_date_tile(
+        input_data=input_data,
+        export_path=export_path,
+        sensors=sensors,
+        process=process,
+    )
+
+    assert Path(export_path).exists() == True
+
+    with rio.open(export_path) as predited_raster:
+        predicted_metadata = predited_raster.meta.copy()
+        assert predicted_metadata["count"] == nb_bands
+
 
+@pytest.mark.skip(reason="Check Time Serie generation")
 @pytest.mark.parametrize("sensors", sensors_test)
 def test_mmdc_tile_inference(sensors):
     """
@@ -277,5 +456,5 @@ def test_mmdc_tile_inference(sensors):
         tile_list="T31TCJ",
         days_gap=15,
         latent_space_size=4,
-        nb_lines=1024,
+        nb_lines=256,
     )
-- 
GitLab
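
A note on the parametrized no-data cases above: replacing a filename with an
empty string works because os.path.join(datapath, "") returns the directory
itself, which is not a file, so GeoTiffDataset flags that sensor as
unavailable.

    import os
    from pathlib import Path

    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
    missing = os.path.join(datapath, "")  # joining "" keeps the directory path
    assert missing == datapath
    assert not Path(missing).is_file()  # -> availability flag set to False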


From a9a493db915be57d56735434550cb80a6323f75a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 15 Mar 2023 16:06:46 +0000
Subject: [PATCH 36/81] test time serie

---
 src/mmdc_singledate/inference/components/inference_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
index 9d54bfe0..7963eb1a 100644
--- a/src/mmdc_singledate/inference/components/inference_utils.py
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -39,14 +39,14 @@ class GeoTiffDataset:
     """
 
     s2_filename: str
-    s2_availabitity: bool = field(init=False)
     s1_asc_filename: str
-    s1_asc_availability: bool = field(init=False)
     s1_desc_filename: str
-    s1_desc_availability: bool = field(init=False)
     srtm_filename: str
     wc_filename: str
     metadata: dict = field(init=False)
+    s2_availabitity: bool | None = field(default=None, init=True)
+    s1_asc_availability: bool | None = field(default=None, init=True)
+    s1_desc_availability: bool | None = field(default=None, init=True)
 
     def __post_init__(self) -> None:
         # check if the data exists
-- 
GitLab
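
The field change above lets callers inject precomputed availability flags
instead of always probing the filesystem. A reduced sketch of the intended
pattern, keeping the patch's spelling of `s2_availabitity`, and assuming
__post_init__ only falls back to the filesystem when no flag was supplied (the
patch itself still recomputes the flags unconditionally):

    from dataclasses import dataclass, field
    from pathlib import Path

    @dataclass
    class TinyDataset:
        s2_filename: str
        s2_availabitity: bool | None = field(default=None, init=True)

        def __post_init__(self) -> None:
            if self.s2_availabitity is None:
                self.s2_availabitity = Path(self.s2_filename).is_file()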


From 5d47d8437d9c45cdc81b65dfd5b15e9aea2a5247 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 15 Mar 2023 16:08:23 +0000
Subject: [PATCH 37/81] test time serie

---
 .../inference/mmdc_tile_inference.py          | 129 ++++++++++++++----
 1 file changed, 104 insertions(+), 25 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index 88cb3bfb..6780d5d8 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -12,6 +12,7 @@ from dataclasses import dataclass
 from pathlib import Path
 
 import numpy as np
+import pandas as pd
 import rasterio as rio
 import torch
 from torchutils import patches
@@ -292,14 +293,64 @@ def predict_single_date_tile(
                 logger.info(("Export tile", f"filename :{export_path}"))
 
 
+def estimate_nb_latent_spaces(
+    patch_s2_availability: bool,
+    patchasc_s1_availability: bool,
+    patchdesc_s1_availability: bool,
+) -> int:
+    """
+    Given the availability of the input data
+    estimate the number of latent variables
+    """
+
+    if (patch_s2_availability and patchasc_s1_availability) or (
+        patch_s2_availability and patchdesc_s1_availability
+    ):
+        nb_latent_spaces = 8 * 2
+
+    else:
+        nb_latent_spaces = 4 * 2
+
+    return nb_latent_spaces
+
+
+def determinate_sensor_list(
+    patch_s2_availability: bool,
+    patchasc_s1_availability: bool,
+    patchdesc_s1_availability: bool,
+) -> list:
+    """
+    Convert availability flags to a list of sensors
+    """
+    sensors_availability = [
+        patch_s2_availability,
+        patchasc_s1_availability,
+        patchdesc_s1_availability,
+    ]
+    match sensors_availability:
+        case [True, True, True]:
+            sensors = ["S2L2A", "S1FULL"]
+        case [True, True, False]:
+            sensors = ["S2L2A", "S1ASC"]
+        case [True, False, True]:
+            sensors = ["S2L2A", "S1DESC"]
+        case [False, True, True]:
+            sensors = ["S1FULL"]
+        case [True, False, False]:
+            sensors = ["S2L2A"]
+        case [False, True, False]:
+            sensors = ["S1ASC"]
+        case [False, False, True]:
+            sensors = ["S1DESC"]
+        case _:
+            # no data available at all: empty sensor list
+            sensors = []
+
+    return sensors
+
+
 def mmdc_tile_inference(
-    input_path: Path,
+    inference_dataframe: pd.DataFrame,
     export_path: Path,
     model_checkpoint_path: Path,
-    sensor: list[str],
-    tile_list: list[str],
-    days_gap: int,
-    latent_space_size: int = 4,
+    patch_size: int = 256,
     nb_lines: int = 1024,
 ) -> None:
     """
@@ -316,39 +367,67 @@ def mmdc_tile_inference(
         :return : None
 
     """
-    # dataframe with input data
-    tile_df = inference_dataframe(
-        samples_dir=input_path,
-        input_tile_list=tile_list,
-        days_gap=days_gap,
-    )
+    # device
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # # dataframe with input data
+    # tile_df = inference_dataframe(
+    #     samples_dir=input_path,
+    #     input_tile_list=tile_list,
+    #     days_gap=days_gap,
+    # )
+    # # just test the df
+    # tile_df = tile_df.head()
 
     # instance the model and get the pred func
     model = get_mmdc_full_model(
         checkpoint=model_checkpoint_path,
     )
-    pred_func = predict_mmdc_model(model=model, sensor=sensor)
-    #
-    mmdc_process = MMDCProcess(
-        count=latent_space_size, nb_lines=nb_lines, process=pred_func
-    )
+    model.to(device)
 
     # iterate over the dates in the time serie
-    for tuile, df_row in tile_df.iterrows():
+    for tuile, df_row in inference_dataframe.iterrows():
+        # estimate nb latent spaces
+        latent_space_size = estimate_nb_latent_spaces(
+            df_row["patch_s2_availability"],
+            df_row["patchasc_s1_availability"],
+            df_row["patchdesc_s1_availability"],
+        )
+        sensor = determinate_sensor_list(
+            df_row["patch_s2_availability"],
+            df_row["patchasc_s1_availability"],
+            df_row["patchdesc_s1_availability"],
+        )
+        # export path
+        date_ = df_row["date"].strftime("%Y-%m-%d")
+        export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif"
+        #
+        print(tuile, df_row)
+        print(latent_space_size)
+        # define process
+        mmdc_process = MMDCProcess(
+            count=latent_space_size,
+            nb_lines=nb_lines,
+            patch_size=patch_size,
+            model=model,
+            process=predict_mmdc_model,
+        )
         # get the input data
         mmdc_input_data = GeoTiffDataset(
-            s2_filename=df_row.patch_s2,
-            s1_asc_filename=df_row.patchasc_s1,
-            s1_asc_availibility=df_row.s1_asc_availibility,
-            s1_desc_filename=df_row.patchdesc_s1,
-            s1_desc_availibitity=df_row.s1_desc_availibitity,
-            srtm_filename=df_row.srtm_filename,
-            wc_filename=df_row.worldclim_filename,
+            s2_filename=df_row["patch_s2"],
+            s1_asc_filename=df_row["patchasc_s1"],
+            s1_desc_filename=df_row["patchdesc_s1"],
+            srtm_filename=df_row["srtm_filename"],
+            wc_filename=df_row["worldclim_filename"],
+            s2_availabitity=df_row["patch_s2_availability"],
+            s1_asc_availability=df_row["patchasc_s1_availability"],
+            s1_desc_availability=df_row["patchdesc_s1_availability"],
         )
         # predict tile
         predict_single_date_tile(
             input_data=mmdc_input_data,
-            export_path=export_path,
+            export_path=export_file,
             sensors=sensor,
             process=mmdc_process,
         )
-- 
GitLab
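
Usage sketch for the two helpers this patch adds; the expected values follow
directly from the match statement and the 8-channels-per-latent-space
convention:

    from mmdc_singledate.inference.mmdc_tile_inference import (
        determinate_sensor_list,
        estimate_nb_latent_spaces,
    )

    assert determinate_sensor_list(True, True, True) == ["S2L2A", "S1FULL"]
    assert determinate_sensor_list(True, False, False) == ["S2L2A"]
    assert estimate_nb_latent_spaces(True, True, False) == 16  # S2 + one S1
    assert estimate_nb_latent_spaces(False, True, True) == 8   # S1 only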


From d2ad4c82ed209e762861c2a746ca187827166067 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 15 Mar 2023 16:12:38 +0000
Subject: [PATCH 38/81] update test time serie

---
 test/test_mmdc_inference.py | 154 +++++++++++++++++++++++++++++++++---
 1 file changed, 141 insertions(+), 13 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 3f4265ba..4f7b1f0a 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -2,8 +2,10 @@
 # Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
 
 import os
+from datetime import datetime
 from pathlib import Path
 
+import pandas as pd
 import pytest
 import rasterio as rio
 import torch
@@ -28,6 +30,7 @@ from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
 
 
+# @pytest.mark.skip(reason="Check Time Serie generation")
 def test_GeoTiffDataset():
     """ """
     datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ"
@@ -55,6 +58,7 @@ def test_GeoTiffDataset():
     assert input_data.s2_availabitity == True
 
 
+# @pytest.mark.skip(reason="Check Time Serie generation")
 def test_generate_chunks():
     """
     test chunk functionality
@@ -65,6 +69,7 @@ def test_generate_chunks():
     assert chunks[0] == (0, 0, 10980, 1024)
 
 
+# @pytest.mark.skip(reason="Check Time Serie generation")
 def test_mmdc_full_model():
     """
     test instantiate network
@@ -127,6 +132,7 @@ sensors_test = [
 ]
 
 
+# @pytest.mark.skip(reason="Check Time Serie generation")
 @pytest.mark.parametrize("sensors", sensors_test)
 def test_predict_mmdc_model(sensors):
     """ """
@@ -369,6 +375,7 @@ datasets = [
     "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable",
     datasets,
 )
+# @pytest.mark.skip(reason="Check Time Serie generation")
 def test_predict_single_date_tile_no_data(
     s2_filename_variable,
     s1_asc_filename_variable,
@@ -434,27 +441,148 @@ def test_predict_single_date_tile_no_data(
         assert predicted_metadata["count"] == nb_bands
 
 
-@pytest.mark.skip(reason="Check Time Serie generation")
-@pytest.mark.parametrize("sensors", sensors_test)
-def test_mmdc_tile_inference(sensors):
+# @pytest.mark.skip(reason="Check Time Serie generation")
+def test_mmdc_tile_inference():
     """
     Test the inference code in a tile
     """
     # feed parameters
-    input_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/")
-    export_path = Path(
-        "/work/CESBIO/projects/MAESTRIA/test_onetile/export/test_latent_tileinfer_{'_'.join(sensors)}.tif"
-    )
+    input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/")
+    export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/")
     model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt"
-    # apply prediction
 
+    input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"]
+
+    # dataframe with input data
+    tile_df = pd.DataFrame(
+        {
+            "date": [
+                datetime.strptime("2018-01-11", "%Y-%m-%d"),
+                datetime.strptime("2018-01-14", "%Y-%m-%d"),
+                datetime.strptime("2018-01-19", "%Y-%m-%d"),
+                datetime.strptime("2018-01-21", "%Y-%m-%d"),
+                datetime.strptime("2018-01-24", "%Y-%m-%d"),
+            ],
+            "patch_s2": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180111-100351-463_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180114-101347-463_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2A_20180119-101331-457_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180121-100542-770_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180124-101352-753_L2A_T33TUL_D_V1-4_roi_0.tif",
+            ],
+            "patchasc_s1": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180108T165831_20180108T165856_020066_022322_3008_36.29396013100586_14.593286735562288_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180114T165747_20180114T165812_009170_0106A2_AFD8_36.27898252884488_14.533603849909781_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180120T165830_20180120T165855_020241_0228B0_95C7_36.29306272148375_14.591766979064303_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+            ],
+            "patchdesc_s1": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180113T051001_20180113T051026_009148_0105F4_218B_45.70533932879138_164.47774901324672_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180124T051848_20180124T051913_020292_022A63_177C_35.81756575888757_165.03806323648553_roi_0.tif",
+            ],
+            "Tuiles": [
+                "T33TUL",
+                "T33TUL",
+                "T33TUL",
+                "T33TUL",
+                "T33TUL",
+            ],
+            "patch": [
+                0,
+                0,
+                0,
+                0,
+                0,
+            ],
+            "roi": [
+                0,
+                0,
+                0,
+                0,
+                0,
+            ],
+            "worldclim_filename": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+            ],
+            "srtm_filename": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+            ],
+            "patch_s2_availability": [
+                True,
+                True,
+                True,
+                True,
+                True,
+            ],
+            "patchasc_s1_availability": [
+                True,
+                True,
+                False,
+                True,
+                False,
+            ],
+            "patchdesc_s1_availability": [
+                True,
+                True,
+                False,
+                True,
+                False,
+            ],
+        }
+    )
+
+    # inference_dataframe(
+    #     samples_dir=input_path,
+    #     input_tile_list=input_tile_list,
+    #     days_gap=5,
+    # )
+    # # just test the df
+    print(
+        tile_df.loc[
+            :,
+            [
+                "patch_s2_availability",
+                "patchasc_s1_availability",
+                "patchdesc_s1_availability",
+            ],
+        ]
+    )
+
+    # apply prediction
     mmdc_tile_inference(
-        input_path=input_path,
+        inference_dataframe=tile_df,
         export_path=export_path,
         model_checkpoint_path=model_checkpoint_path,
-        sensor=sensors,
-        tile_list="T31TCJ",
-        days_gap=15,
-        latent_space_size=4,
+        patch_size=128,
         nb_lines=256,
     )
+
+    assert Path(export_path / "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-14_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-19_S2L2A.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-21_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-24_S2L2A.tif").exists()
+
+    with rio.open(
+        os.path.join(export_path, "latent_tile_infer_2018-01-19_S2L2A.tif")
+    ) as raster:
+        metadata = raster.meta.copy()
+    assert metadata["count"] == 8
+
+    with rio.open(
+        os.path.join(export_path, "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif")
+    ) as raster:
+        metadata = raster.meta.copy()
+    assert metadata["count"] == 16
-- 
GitLab
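
Note: the file names asserted above follow the
latent_tile_infer_{date}_{sensors}.tif pattern built in the inference loop. The
derivation for the first fixture row, in isolation (the export directory is
illustrative):

    from datetime import datetime
    from pathlib import Path

    export_path = Path("/tmp/export")  # illustrative
    date_ = datetime(2018, 1, 11).strftime("%Y-%m-%d")
    sensors = ["S2L2A", "S1FULL"]      # S2 plus both S1 orbits available
    export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensors)}.tif"
    assert export_file.name == "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif"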


From abcb4cfff5e51c05ea85057f711b159f349e01ec Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 15 Mar 2023 16:27:00 +0000
Subject: [PATCH 39/81] add job examples

---
 jobs/inference_single_date.pbs | 28 ++++++++++++++++++++++++++++
 jobs/inference_time_serie.pbs  | 25 +++++++++++++++++++++++++
 2 files changed, 53 insertions(+)
 create mode 100644 jobs/inference_single_date.pbs
 create mode 100644 jobs/inference_time_serie.pbs

diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs
new file mode 100644
index 00000000..5a9ad983
--- /dev/null
+++ b/jobs/inference_single_date.pbs
@@ -0,0 +1,28 @@
+#!/bin/bash
+#PBS -N infer_sgdt
+#PBS -q qgpgpu
+#PBS -l select=1:ncpus=8:mem=92G:ngpus=1
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+module purge
+
+export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
+export MMDC_INIT=${SRCDIR}/mmdc_init.sh
+export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
+export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/
+
+cd ${WORKING_DIR}
+source ${MMDC_INIT}
+
+python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
+    --s2_filename ${DATA_DIR}/T31TCJ/SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif \
+    --s1_asc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif \
+    --s1_desc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif \
+    --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \
+    --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \
+    --export_path ${DATA_DIR}/export/ \
+    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
+    --latent_space_size \
+    --nb_lines 256 \
+    --sensors S2L2A S1FULL \
diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs
new file mode 100644
index 00000000..c366fec5
--- /dev/null
+++ b/jobs/inference_time_serie.pbs
@@ -0,0 +1,25 @@
+#!/bin/bash
+#PBS -N infer_sgdt
+#PBS -q qgpgpu
+#PBS -l select=1:ncpus=8:mem=92G:ngpus=1
+#PBS -l walltime=2:00:00
+
+# be sure no modules loaded
+module purge
+
+export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
+export MMDC_INIT=${SRCDIR}/mmdc_init.sh
+export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
+export DATA_DIR=/work/CESBIO/projects/MAESTRIA/training_dataset2/
+
+cd ${WORKING_DIR}
+source ${MMDC_INIT}
+
+python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \
+    --input_path ${DATA_DIR} \
+    --export_path ${DATA_DIR}/inference/ \
+    --tile_list T33TUL \
+    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
+    --days_gap 1 \
+    --patch_size 128 \
+    --nb_lines 256 \
-- 
GitLab
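
Note: both job scripts pass multi-valued flags such as --sensors S2L2A S1FULL.
A standalone argparse illustration of how nargs="+" consumes them (only two of
the flags are shown, with values taken from the job above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--sensors", nargs="+", required=True)
    parser.add_argument("--nb_lines", type=int)
    args = parser.parse_args(["--sensors", "S2L2A", "S1FULL", "--nb_lines", "256"])
    assert args.sensors == ["S2L2A", "S1FULL"]
    assert args.nb_lines == 256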


From 26a0cc6a568ac245e3a739f4bd5c3df6599e2c8c Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:23:15 +0000
Subject: [PATCH 40/81] minor changes

---
 src/bin/inference_mmdc_singledate.py | 30 ++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
index 37e19ad6..7b8db7d5 100644
--- a/src/bin/inference_mmdc_singledate.py
+++ b/src/bin/inference_mmdc_singledate.py
@@ -11,11 +11,13 @@ import argparse
 import logging
 import os
 
+import torch
+
 from mmdc_singledate.inference.components.inference_components import (
-    GeoTiffDataset,
     get_mmdc_full_model,
     predict_mmdc_model,
 )
+from mmdc_singledate.inference.components.inference_utils import GeoTiffDataset
 from mmdc_singledate.inference.mmdc_tile_inference import (
     MMDCProcess,
     predict_single_date_tile,
@@ -80,11 +82,18 @@ def get_parser() -> argparse.ArgumentParser:
 
     arg_parser.add_argument(
         "--nb_lines",
-        type=str,
+        type=int,
         help="Number of Line to read in each window reading",
         required=False,
     )
 
+    arg_parser.add_argument(
+        "--patch_size",
+        type=int,
+        help="Number of Line to read in each window reading",
+        required=True,
+    )
+
     arg_parser.add_argument(
         "--sensors",
         dest="sensors",
@@ -115,24 +124,25 @@ def main():
         format="%(asctime)s :: %(levelname)s :: %(message)s",
     )
 
+    # device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     # model
-    model = get_mmdc_full_model(
+    mmdc_full_model = get_mmdc_full_model(
         checkpoint=args.model_checkpoint_path,
     )
-
-    # prediction function
-    pred_mmdc = predict_mmdc_model(model=model, sensor=args.sensor)
-
+    mmdc_full_model.to(device)
     # process class
     mmdc_process = MMDCProcess(
         count=args.latent_space_size,
         nb_lines=args.nb_lines,
         patch_size=args.patch_size,
-        process=pred_mmdc,
+        process=predict_mmdc_model,
+        model=mmdc_full_model,
     )
 
     # create export filename
-    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/latent_singledate_{'_'.join(args.sensors)}.tif"
+    export_path = f"{args.export_path}/latent_singledate_{'_'.join(args.sensors)}.tif"
 
     # GeoTiff
     geotiff_dataclass = GeoTiffDataset(
@@ -140,7 +150,7 @@ def main():
         s1_asc_filename=args.s1_asc_filename,
         s1_desc_filename=args.s1_desc_filename,
         srtm_filename=args.srtm_filename,
-        wc_filename=args.wc_filename,
+        wc_filename=args.worldclim_filename,
     )
 
     predict_single_date_tile(
-- 
GitLab
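
Note: the keyword arguments passed to MMDCProcess above imply a container
roughly shaped as follows. This is a sketch inferred from the call sites, not
the committed definition in mmdc_tile_inference:

    from dataclasses import dataclass
    from typing import Callable

    import torch

    @dataclass
    class MMDCProcess:
        count: int              # number of output bands (latent space size)
        nb_lines: int           # lines read per window
        patch_size: int         # patch size fed to the model
        model: torch.nn.Module  # the loaded MMDC full model
        process: Callable       # prediction function, e.g. predict_mmdc_model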


From 18b6dbf0e17b2fc9c323223f9caeb41d88342942 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:24:09 +0000
Subject: [PATCH 41/81] minor changes

---
 src/bin/inference_mmdc_timeserie.py | 46 ++++++++++++++---------------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py
index 62bd41f9..56626f63 100644
--- a/src/bin/inference_mmdc_timeserie.py
+++ b/src/bin/inference_mmdc_timeserie.py
@@ -12,7 +12,10 @@ import logging
 import os
 from pathlib import Path
 
-from mmdc_singledate.inference.mmdc_tile_inference import mmdc_tile_inference
+from mmdc_singledate.inference.mmdc_tile_inference import (
+    inference_dataframe,
+    mmdc_tile_inference,
+)
 
 
 def get_parser() -> argparse.ArgumentParser:
@@ -43,7 +46,7 @@ def get_parser() -> argparse.ArgumentParser:
 
     arg_parser.add_argument(
         "--days_gap",
-        type=str,
+        type=int,
         help="difference between S1 and S2 acquisitions in days",
         required=True,
     )
@@ -56,38 +59,27 @@ def get_parser() -> argparse.ArgumentParser:
     )
 
     arg_parser.add_argument(
-        "--tile_list", type=str, help="List of tiles to produce", required=False
-    )
-
-    arg_parser.add_argument(
-        "--patch_size",
+        "--tile_list",
+        nargs="+",
         type=str,
-        help="Sizes of the patches to make the inference",
+        help="List of tiles to produce",
         required=False,
     )
 
     arg_parser.add_argument(
-        "--latent_space_size",
+        "--patch_size",
         type=int,
-        help="Number of channels in the output file",
+        help="Sizes of the patches to make the inference",
         required=False,
     )
 
     arg_parser.add_argument(
         "--nb_lines",
-        type=str,
+        type=int,
         help="Number of Line to read in each window reading",
         required=False,
     )
 
-    arg_parser.add_argument(
-        "-sensors",
-        dest="sensors",
-        nargs="+",
-        help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)",
-        required=True,
-    )
-
     return arg_parser
 
 
@@ -116,18 +108,24 @@ def main():
         )
         Path(args.export_path).mkdir()
 
+    # calculate time serie
+    tile_df = inference_dataframe(
+        samples_dir=args.input_path,
+        input_tile_list=args.tile_list,
+        days_gap=args.days_gap,
+    )
+    logging.info(f"nb of dates to infer : {tile_df.shape[0]}")
     # entry point
     mmdc_tile_inference(
-        input_path=Path(args.input_path),
+        inference_dataframe=tile_df,
         export_path=Path(args.export_path),
         model_checkpoint_path=Path(args.model_checkpoint_path),
-        sensor=args.sensors,
-        tile_list=args.tile_list,
-        days_gap=args.days_gap,
-        latent_space_size=args.laten_space_size,
+        patch_size=args.patch_size,
         nb_lines=args.nb_lines,
     )
 
+    logging.info("All the data has exported succesfully")
+
 
 if __name__ == "__main__":
     main()
-- 
GitLab
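
Note: main() above creates the export directory with a bare Path.mkdir(), which
raises if a parent directory is missing. A more robust variant (a suggestion,
not the committed code):

    from pathlib import Path

    def ensure_export_dir(export_path: str) -> Path:
        # parents=True creates intermediate directories; exist_ok=True makes
        # the call idempotent across job re-submissions.
        path = Path(export_path)
        path.mkdir(parents=True, exist_ok=True)
        return path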


From 57e8378247ce6e1bd223cb0f8b0fb205806c5d06 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:25:02 +0000
Subject: [PATCH 42/81] minor changes

---
 .../components/inference_components.py        | 61 +++++++++++--------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index 934aa525..c7130a33 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -181,8 +181,7 @@ def predict_mmdc_model(
     match sensors:
         # Cases
         case ["S2L2A", "S1FULL"]:
-            logger.info("S1 full captured")
-            logger.info("S2 captured")
+            logger.info("S2 captured & S1 full captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
             latent_space.latent_s1 = prediction[0].latent.latent_s1
@@ -199,18 +198,19 @@ def predict_mmdc_model(
 
             print(
                 "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu,
+                latent_space.latent_s2.mu.shape,
             )
             print(
                 "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar,
+                latent_space.latent_s2.logvar.shape,
+            )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
+            print(
+                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
-            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S2L2A", "S1ASC"]:
-            logger.info("S1 asc captured")
-            logger.info("S2 captured")
+            logger.info("S2 captured & S1 asc captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
             latent_space.latent_s1 = prediction[0].latent.latent_s1
@@ -226,18 +226,19 @@ def predict_mmdc_model(
             )
             print(
                 "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu,
+                latent_space.latent_s2.mu.shape,
             )
             print(
                 "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar,
+                latent_space.latent_s2.logvar.shape,
+            )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
+            print(
+                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
-            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S2L2A", "S1DESC"]:
-            logger.info("S1 desc captured")
-            logger.info("S2 captured")
+            logger.info("S2 captured & S1 desc captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
             latent_space.latent_s1 = prediction[0].latent.latent_s1
@@ -253,14 +254,16 @@ def predict_mmdc_model(
             )
             print(
                 "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu,
+                latent_space.latent_s2.mu.shape,
             )
             print(
                 "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar,
+                latent_space.latent_s2.logvar.shape,
+            )
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
+            print(
+                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
-            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
 
         case ["S2L2A"]:
             logger.info("Only S2 captured")
@@ -276,11 +279,11 @@ def predict_mmdc_model(
             )
             print(
                 "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu,
+                latent_space.latent_s2.mu.shape,
             )
             print(
                 "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar,
+                latent_space.latent_s2.logvar.shape,
             )
 
         case ["S1FULL"]:
@@ -292,8 +295,10 @@ def predict_mmdc_model(
                 (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
                 1,
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
-            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
+            print(
+                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            )
 
         case ["S1ASC"]:
             logger.info("Only S1ASC captured")
@@ -304,8 +309,10 @@ def predict_mmdc_model(
                 (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
                 1,
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
-            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
+            print(
+                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            )
 
         case ["S1DESC"]:
             logger.info("Only S1DESC captured")
@@ -316,7 +323,9 @@ def predict_mmdc_model(
                 (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
                 1,
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu)
-            print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar)
+            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
+            print(
+                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            )
 
     return latent_space_stack  # latent_space
-- 
GitLab
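
Note: every branch of the match statement above concatenates the latent mu and
logvar along the channel axis before writing. The operation in isolation
(shapes are illustrative only):

    import torch

    mu = torch.zeros(1, 4, 256, 256)      # (batch, latent channels, H, W)
    logvar = torch.zeros(1, 4, 256, 256)
    stack = torch.cat((mu, logvar), 1)    # -> (1, 8, 256, 256)
    assert stack.shape == (1, 8, 256, 256)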


From 430297c8cb42835a25be2ab083ebbbf0cb05eed4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:25:28 +0000
Subject: [PATCH 43/81] minor changes

---
 src/mmdc_singledate/inference/mmdc_tile_inference.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index 6780d5d8..123b4318 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -371,14 +371,9 @@ def mmdc_tile_inference(
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # # dataframe with input data
-    # tile_df = inference_dataframe(
-    #     samples_dir=input_path,
-    #     input_tile_list=tile_list,
-    #     days_gap=days_gap,
-    # )
-    # # just test the df
-    # tile_df = tile_df.head()
+    # TODO: uncomment for test purposes
+    # inference_dataframe = inference_dataframe.head()
+    # print(inference_dataframe.shape)
 
     # instance the model and get the pred func
     model = get_mmdc_full_model(
-- 
GitLab


From 477d7a83c55ef2aa5c0b2376b37dbddb782188c7 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:26:02 +0000
Subject: [PATCH 44/81] update

---
 jobs/inference_single_date.pbs | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs
index 5a9ad983..989f6ec6 100644
--- a/jobs/inference_single_date.pbs
+++ b/jobs/inference_single_date.pbs
@@ -11,8 +11,13 @@ export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
 export MMDC_INIT=${SRCDIR}/mmdc_init.sh
 export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
 export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/
+export SCRATCH_DIR=/work/scratch/${USER}/
 
 cd ${WORKING_DIR}
+
+mkdir -p ${SCRATCH_DIR}/MMDC/inference/
+mkdir -p ${SCRATCH_DIR}/MMDC/inference/singledate
+
 source ${MMDC_INIT}
 
 python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
@@ -21,8 +26,9 @@ python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
     --s1_desc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif \
     --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \
     --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \
-    --export_path ${DATA_DIR}/export/ \
+    --export_path ${SCRATCH_DIR}/MMDC/inference/singledate \
     --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
-    --latent_space_size \
+    --latent_space_size 16 \
+    --patch_size 128 \
     --nb_lines 256 \
     --sensors S2L2A S1FULL \
-- 
GitLab


From d9925c5facdd2b772423072996709c01a083adaf Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:26:13 +0000
Subject: [PATCH 45/81] update

---
 jobs/inference_time_serie.pbs | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs
index c366fec5..582bf198 100644
--- a/jobs/inference_time_serie.pbs
+++ b/jobs/inference_time_serie.pbs
@@ -11,15 +11,21 @@ export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
 export MMDC_INIT=${SRCDIR}/mmdc_init.sh
 export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
 export DATA_DIR=/work/CESBIO/projects/MAESTRIA/training_dataset2/
+export SCRATCH_DIR=/work/scratch/${USER}/
+
+
+mkdir -p ${SCRATCH_DIR}/MMDC/inference/
+mkdir -p ${SCRATCH_DIR}/MMDC/inference/time_serie
+
 
 cd ${WORKING_DIR}
 source ${MMDC_INIT}
 
 python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \
     --input_path ${DATA_DIR} \
-    --export_path ${DATA_DIR}/inference/ \
-    --tile_list T33TUL \
+    --export_path ${SCRATCH_DIR}/MMDC/inference/time_serie \
+    --tile_list ${DATA_DIR}/T33TUL \
     --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
-    --days_gap 1 \
+    --days_gap 0 \
     --patch_size 128 \
     --nb_lines 256 \
-- 
GitLab


From 5c5e9e3c3bde0961050a3bc2bd07a96f6f369089 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:39:13 +0000
Subject: [PATCH 46/81] add tiler config

---
 .../externalfeatures_with_userfeatures.cfg    | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)
 create mode 100644 configs/iota2/externalfeatures_with_userfeatures.cfg

diff --git a/configs/iota2/externalfeatures_with_userfeatures.cfg b/configs/iota2/externalfeatures_with_userfeatures.cfg
new file mode 100644
index 00000000..b6eb0fab
--- /dev/null
+++ b/configs/iota2/externalfeatures_with_userfeatures.cfg
@@ -0,0 +1,24 @@
+# iota2 configuration file to launch iota2 as a tiler of features.
+# paths with 'XXX' must be replaced by local user paths.
+
+
+chain :
+{
+  output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
+
+  first_step : 'tiler'
+  last_step : 'tiler'
+
+  proj : 'EPSG:2154'
+  # rasters_grid_path : '/XXX/grid_raster'
+  grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+  list_tile : 'T31TCJ' # T31TCK T31TDJ'
+  features_path : '/XXX/non_tiled_features'
+
+  spatial_resolution : 10 # mandatory but not used in this case. The output spatial resolution will be the one found in 'rasters_grid_path' directory.
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
-- 
GitLab


From eaa797a0c35fe62102ce9324e651e78e5d1f7a7d Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:39:24 +0000
Subject: [PATCH 47/81] add tiler config

---
 configs/iota2/i2_classification.cfg | 61 +++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 configs/iota2/i2_classification.cfg

diff --git a/configs/iota2/i2_classification.cfg b/configs/iota2/i2_classification.cfg
new file mode 100644
index 00000000..71c1b728
--- /dev/null
+++ b/configs/iota2/i2_classification.cfg
@@ -0,0 +1,61 @@
+chain :
+{
+  output_path : '/work/scratch/sowm/IOTA2_TEST_S2/IOTA2_Outputs/Results_classif'
+  remove_output_path : True
+  nomenclature_path : '/work/scratch/sowm/IOTA2_TEST_S2/nomenclature23.txt'
+  list_tile : 'T31TCJ T31TDJ'
+  s2_path : '/work/scratch/sowm/IOTA2_TEST_S2/sensor_data'
+  s1_path : '/work/scratch/sowm/test-data-cube/iota2/s1_full.cfg'
+  ground_truth : '/work/scratch/sowm/IOTA2_TEST_S2/vector_data/reference_data.shp'
+  data_field : 'code'
+  spatial_resolution : 10
+  color_table : '/work/scratch/sowm/IOTA2_TEST_S2/colorFile.txt'
+  proj : 'EPSG:2154'
+  first_step: "init"
+  last_step: "validation"
+}
+sensors_data_interpolation:
+{
+  write_outputs: False
+}
+
+external_features:
+{
+
+  module: "/home/qt/sowm/scratch/test-data-cube/iota2/external_iota2_code.py"
+  functions: [["apply_convae",
+              { "checkpoint_path": "/home/qt/sowm/scratch/results/Convae_split_roi_psz_128_ep_150_bs_16_lr_1e-05_s1_[64, 128, 256, 512, 1024]_[8]_s2_[64, 128, 256, 512, 1024]_[8]_21916652.admin01",
+                "checkpoint_epoch": 100,
+                "patch_size": 74
+              }]]
+  concat_mode: True
+}
+
+python_data_managing:
+{
+ chunk_size_mode: "split_number"
+ number_of_chunks: 20
+ padding_size_y: 27
+
+}
+
+arg_train :
+{
+  classifier : 'sharkrf'
+  otb_classifier_options : {"classifier.sharkrf.ntrees" : 100 }
+  sample_selection: {"sampler": "random",
+  "strategy": "percent",
+  "strategy.percent.p": 0.1}
+  random_seed: 42
+
+}
+arg_classification :
+{
+  classif_mode : 'separate'
+}
+task_retry_limits :
+{
+  allowed_retry : 0
+  maximum_ram : 180.0
+  maximum_cpu : 40
+}
-- 
GitLab


From 20a0bc26aedd5b284bba058f7709d22d0c77ff38 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:39:55 +0000
Subject: [PATCH 48/81] add tiler config

---
 configs/iota2/externalfeatures_with_userfeatures_50x50.cfg | 5 +++--
 configs/iota2/iota2_grid_full.cfg                          | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
index cdb197f1..71bc61d1 100644
--- a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -11,7 +11,8 @@ chain :
   ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp'
   data_field : 'groscode'
   s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink'
-  user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
+  # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
+  user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
   # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output'
   s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg'
 
@@ -28,7 +29,7 @@ chain :
 userFeat:
 {
   arbo:"/*"
-  patterns:"slope"
+  patterns:"aspect,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending,slope"
 }
 python_data_managing :
 {
diff --git a/configs/iota2/iota2_grid_full.cfg b/configs/iota2/iota2_grid_full.cfg
index 0a529fbb..d5379d49 100644
--- a/configs/iota2/iota2_grid_full.cfg
+++ b/configs/iota2/iota2_grid_full.cfg
@@ -1,7 +1,7 @@
 chain :
 {
   output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full'
-  spatial_resolution : 50
+  spatial_resolution : 10
   first_step : 'tiler'
   last_step : 'tiler'
 
-- 
GitLab


From 1f3b81a67deba19e10fed32c7f0398297f8eb56a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 14:32:01 +0000
Subject: [PATCH 49/81] add new full module init strategy

---
 .../components/inference_components.py        | 192 +++++++++++++-----
 1 file changed, 136 insertions(+), 56 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index c7130a33..ae4a7498 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -5,20 +5,28 @@
 Functions components for get the
 latent spaces for a given tile
 """
-
-
 # imports
 import logging
 from pathlib import Path
 from typing import Literal
 
 import torch
+from hydra import compose, initialize
 from torch import nn
 
-from mmdc_singledate.models.components.model_dataclass import (  # VAELatentSpace,
+from mmdc_singledate.datamodules.types import MMDCDataChannels
+from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+from mmdc_singledate.models.types import (
+    ConvnetParams,
+    MLPConfig,
+    MMDCDataUse,
+    ModularEmbeddingConfig,
+    MultiVAEConfig,
+    MultiVAELossWeights,
     S1S2VAELatentSpace,
+    TranslationLossWeights,
+    UnetParams,
 )
-from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
 
 # define sensors variable for typing
 SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
@@ -44,55 +52,130 @@ def get_mmdc_full_model(
         logger.info("No checkpoint path was give. ")
         return 1
     else:
+        # parse config
+        with initialize(config_path="../../../../configs/model/"):
+            cfg = compose(
+                config_name="mmdc_full.yaml", return_hydra_config=True, overrides=[]
+            )
+
+        mmdc_full_config = MultiVAEConfig(
+            data_sizes=MMDCDataChannels(
+                sen1=cfg.model.config.data_sizes.sen1,
+                s1_angles=cfg.model.config.data_sizes.s1_angles,
+                sen2=cfg.model.config.data_sizes.sen2,
+                s2_angles=cfg.model.config.data_sizes.s2_angles,
+                srtm=cfg.model.config.data_sizes.srtm,
+                w_c=cfg.model.config.data_sizes.w_c,
+            ),
+            embeddings=ModularEmbeddingConfig(
+                s1_angles=MLPConfig(
+                    hidden_layers=cfg.model.config.embeddings.s1_angles.hidden_layers,
+                    out_channels=cfg.model.config.embeddings.s1_angles.out_channels,
+                ),
+                s1_srtm=UnetParams(
+                    out_channels=cfg.model.config.embeddings.s1_srtm.out_channels,
+                    encoder_sizes=cfg.model.config.embeddings.s1_srtm.encoder_sizes,
+                    kernel_size=cfg.model.config.embeddings.s1_srtm.kernel_size,
+                    tail_layers=cfg.model.config.embeddings.s1_srtm.tail_layers,
+                ),
+                s2_angles=ConvnetParams(
+                    out_channels=cfg.model.config.embeddings.s2_angles.out_channels,
+                    sizes=cfg.model.config.embeddings.s2_angles.sizes,
+                    kernel_sizes=cfg.model.config.embeddings.s2_angles.kernel_sizes,
+                ),
+                s2_srtm=UnetParams(
+                    out_channels=cfg.model.config.embeddings.s2_srtm.out_channels,
+                    encoder_sizes=cfg.model.config.embeddings.s2_srtm.encoder_sizes,
+                    kernel_size=cfg.model.config.embeddings.s2_srtm.kernel_size,
+                    tail_layers=cfg.model.config.embeddings.s2_srtm.tail_layers,
+                ),
+                w_c=ConvnetParams(
+                    out_channels=cfg.model.config.embeddings.w_c.out_channels,
+                    sizes=cfg.model.config.embeddings.w_c.sizes,
+                    kernel_sizes=cfg.model.config.embeddings.w_c.kernel_sizes,
+                ),
+            ),
+            s1_encoder=UnetParams(
+                out_channels=cfg.model.config.s1_encoder.out_channels,
+                encoder_sizes=cfg.model.config.s1_encoder.encoder_sizes,
+                kernel_size=cfg.model.config.s1_encoder.kernel_size,
+                tail_layers=cfg.model.config.s1_encoder.tail_layers,
+            ),
+            s2_encoder=UnetParams(
+                out_channels=cfg.model.config.s2_encoder.out_channels,
+                encoder_sizes=cfg.model.config.s2_encoder.encoder_sizes,
+                kernel_size=cfg.model.config.s2_encoder.kernel_size,
+                tail_layers=cfg.model.config.s2_encoder.tail_layers,
+            ),
+            s1_decoder=ConvnetParams(
+                out_channels=cfg.model.config.s1_decoder.out_channels,
+                sizes=cfg.model.config.s1_decoder.sizes,
+                kernel_sizes=cfg.model.config.s1_decoder.kernel_sizes,
+            ),
+            s2_decoder=ConvnetParams(
+                out_channels=cfg.model.config.s2_decoder.out_channels,
+                sizes=cfg.model.config.s2_decoder.sizes,
+                kernel_sizes=cfg.model.config.s2_decoder.kernel_sizes,
+            ),
+            s1_enc_use=MMDCDataUse(
+                s1_angles=cfg.model.config.s1_enc_use.s1_angles,
+                s2_angles=cfg.model.config.s1_enc_use.s2_angles,
+                srtm=cfg.model.config.s1_enc_use.srtm,
+                w_c=cfg.model.config.s1_enc_use.w_c,
+            ),
+            s2_enc_use=MMDCDataUse(
+                s1_angles=cfg.model.config.s2_enc_use.s1_angles,
+                s2_angles=cfg.model.config.s2_enc_use.s2_angles,
+                srtm=cfg.model.config.s2_enc_use.srtm,
+                w_c=cfg.model.config.s2_enc_use.w_c,
+            ),
+            s1_dec_use=MMDCDataUse(
+                s1_angles=cfg.model.config.s1_dec_use.s1_angles,
+                s2_angles=cfg.model.config.s1_dec_use.s2_angles,
+                srtm=cfg.model.config.s1_dec_use.srtm,
+                w_c=cfg.model.config.s1_dec_use.w_c,
+            ),
+            s2_dec_use=MMDCDataUse(
+                s1_angles=cfg.model.config.s2_dec_use.s1_angles,
+                s2_angles=cfg.model.config.s2_dec_use.s2_angles,
+                srtm=cfg.model.config.s2_dec_use.srtm,
+                w_c=cfg.model.config.s2_dec_use.w_c,
+            ),
+            loss_weights=MultiVAELossWeights(
+                sen1=TranslationLossWeights(
+                    nll=cfg.model.config.loss_weights.sen1.nll,
+                    lpips=cfg.model.config.loss_weights.sen1.lpips,
+                    sam=cfg.model.config.loss_weights.sen1.sam,
+                    gradients=cfg.model.config.loss_weights.sen1.gradients,
+                ),
+                sen2=TranslationLossWeights(
+                    nll=cfg.model.config.loss_weights.sen2.nll,
+                    lpips=cfg.model.config.loss_weights.sen2.lpips,
+                    sam=cfg.model.config.loss_weights.sen2.sam,
+                    gradients=cfg.model.config.loss_weights.sen2.gradients,
+                ),
+                s1_s2=TranslationLossWeights(
+                    nll=cfg.model.config.loss_weights.s1_s2.nll,
+                    lpips=cfg.model.config.loss_weights.s1_s2.lpips,
+                    sam=cfg.model.config.loss_weights.s1_s2.sam,
+                    gradients=cfg.model.config.loss_weights.s1_s2.gradients,
+                ),
+                s2_s1=TranslationLossWeights(
+                    nll=cfg.model.config.loss_weights.s2_s1.nll,
+                    lpips=cfg.model.config.loss_weights.s2_s1.lpips,
+                    sam=cfg.model.config.loss_weights.s2_s1.sam,
+                    gradients=cfg.model.config.loss_weights.s2_s1.gradients,
+                ),
+                forward=cfg.model.config.loss_weights.forward,
+                cross=cfg.model.config.loss_weights.cross,
+                latent=cfg.model.config.loss_weights.latent,
+            ),
+        )
+        # Init the model
         mmdc_full_model = MMDCFullModule(
-            s2_angles_conv_in_channels=6,
-            s2_angles_conv_out_channels=3,
-            s2_angles_conv_encoder_sizes=[16, 8],
-            s2_angles_conv_kernel_size=1,
-            #
-            srtm_s2_angles_conv_in_channels=10,  # 4 srtm + 6 angles
-            srtm_s2_angles_conv_out_channels=3,
-            srtm_s2_angles_unet_encoder_sizes=[32, 64, 128],
-            srtm_s2_angles_kernel_size=3,
-            #
-            s1_angles_mlp_in_size=6,
-            s1_angles_mlp_hidden_layers=[9, 12, 9],
-            s1_angles_mlp_out_size=3,
-            #
-            srtm_s1_encoder_sizes=[32, 64, 128],
-            srtm_kernel_size=3,
-            srtm_s1_in_channels=4,
-            srtm_s1_out_channels=3,
-            #
-            wc_enc_sizes=[64, 32, 16],
-            wc_kernel_size=1,
-            wc_in_channels=103,
-            wc_out_channels=4,
-            #
-            s1_input_size=6,
-            s1_encoder_sizes=[64, 128, 256, 512, 1024],
-            s1_enc_kernel_sizes=[3],
-            s1_decoder_sizes=[32, 16, 8],
-            s1_dec_kernel_sizes=[3, 3, 3, 3],
-            code_sizes=[0, 4, 0],
-            s2_input_size=10,
-            s2_encoder_sizes=[64, 128, 256, 512, 1024],
-            s2_enc_kernel_sizes=[3],
-            s2_decoder_sizes=[64, 32, 16],
-            s2_dec_kernel_sizes=[3, 3, 3, 3],
-            s1_ang_in_decoder=True,
-            s2_ang_in_decoder=True,
-            srtm_in_decoder=True,
-            wc_in_decoder=True,
-            w_d1_e1_s1=1,
-            w_d2_e2_s2=1,
-            w_d1_e2_s2=1,
-            w_d2_e1_s1=1,
-            w_code_s1s2=1,
+            mmdc_full_config,
         )
 
-        # mmdc_full_lightning = MMDCFullLitModule(mmdc_full_model)
-
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"device detected : {device}")
         print(f"nb_gpu's detected : {torch.cuda.device_count()}")
@@ -100,22 +183,19 @@ def get_mmdc_full_model(
         lightning_checkpoint = torch.load(checkpoint, map_location=device)
 
         # delete "model" from the loaded checkpoint
-        checkpoint = {
+        lightning_checkpoint_cleaning = {
             key.split("model.")[-1]: item
             for key, item in lightning_checkpoint["state_dict"].items()
             if key.startswith("model.")
         }
 
         # load the state dict
-        mmdc_full_model.load_state_dict(checkpoint)
-        # mmdc_full_model.load_state_dict(lightning_checkpoint)
-        # mmdc_full_lightning.load_state_dict(lightning_checkpoint)
+        mmdc_full_model.load_state_dict(lightning_checkpoint_cleaning)
 
         # disble randomness, dropout, etc...
         mmdc_full_model.eval()
-        # mmdc_full_lightning.eval()
 
-    return mmdc_full_model  # mmdc_full_lightning
+    return mmdc_full_model
 
 
 @torch.no_grad()
-- 
GitLab
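
Note: the checkpoint handling above strips the "model." prefix that Lightning
adds to the wrapped module's weights before loading them into the bare
nn.Module. The same step as a standalone helper (a sketch; the checkpoint path
passed in is illustrative):

    import torch
    from torch import nn

    def load_from_lightning_checkpoint(model: nn.Module, checkpoint: str) -> nn.Module:
        # Keep only "model.*" entries and drop the prefix so the state dict
        # matches the bare module's parameter names.
        ckpt = torch.load(checkpoint, map_location="cpu")
        state_dict = {
            key.split("model.")[-1]: value
            for key, value in ckpt["state_dict"].items()
            if key.startswith("model.")
        }
        model.load_state_dict(state_dict)
        model.eval()  # disable dropout and other training-time randomness
        return model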


From 7f54f3ae6091e9b44aec02ceebfabf7a97659ce9 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:29:08 +0000
Subject: [PATCH 50/81] update to main

---
 src/bin/inference_mmdc_singledate.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
index 7b8db7d5..6d8e179f 100644
--- a/src/bin/inference_mmdc_singledate.py
+++ b/src/bin/inference_mmdc_singledate.py
@@ -22,6 +22,7 @@ from mmdc_singledate.inference.mmdc_tile_inference import (
     MMDCProcess,
     predict_single_date_tile,
 )
+from mmdc_singledate.inference.utils import get_scales
 
 
 def get_parser() -> argparse.ArgumentParser:
@@ -98,7 +99,10 @@ def get_parser() -> argparse.ArgumentParser:
         "--sensors",
         dest="sensors",
         nargs="+",
-        help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)",
+        help=(
+            "List of sensors S2L2A | S1FULL | S1ASC | S1DESC",
+            " (Is sensible to the order S2L2A always firts)",
+        ),
         required=True,
     )
 
@@ -131,6 +135,8 @@ def main():
     mmdc_full_model = get_mmdc_full_model(
         checkpoint=args.model_checkpoint_path,
     )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
     mmdc_full_model.to(device)
     # process class
     mmdc_process = MMDCProcess(
-- 
GitLab


From 8e7aa3466060a647a68ee2e3b796643e11541eb3 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:30:32 +0000
Subject: [PATCH 51/81] update model config

---
 .../components/inference_components.py        | 357 ++++++++++--------
 1 file changed, 200 insertions(+), 157 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index ae4a7498..e3a9f52b 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -11,7 +11,7 @@ from pathlib import Path
 from typing import Literal
 
 import torch
-from hydra import compose, initialize
+import yaml
 from torch import nn
 
 from mmdc_singledate.datamodules.types import MMDCDataChannels
@@ -39,6 +39,17 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
+def parse_from_yaml(path_to_yaml):
+    """
+    Read a YAML file and return its contents as a dict.
+    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
+    """
+    with open(path_to_yaml) as f:
+        config = yaml.safe_load(f)
+
+    return config
+
+
 @torch.no_grad()
 def get_mmdc_full_model(
     checkpoint: str,
@@ -52,123 +63,155 @@ def get_mmdc_full_model(
         logger.info("No checkpoint path was give. ")
         return 1
     else:
-        # parse config
-        with initialize(config_path="../../../../configs/model/"):
-            cfg = compose(
-                config_name="mmdc_full.yaml", return_hydra_config=True, overrides=[]
-            )
+        cfg = parse_from_yaml("configs/model/mmdc_full.yaml")
 
         mmdc_full_config = MultiVAEConfig(
             data_sizes=MMDCDataChannels(
-                sen1=cfg.model.config.data_sizes.sen1,
-                s1_angles=cfg.model.config.data_sizes.s1_angles,
-                sen2=cfg.model.config.data_sizes.sen2,
-                s2_angles=cfg.model.config.data_sizes.s2_angles,
-                srtm=cfg.model.config.data_sizes.srtm,
-                w_c=cfg.model.config.data_sizes.w_c,
+                sen1=cfg["model"]["config"]["data_sizes"]["sen1"],
+                s1_angles=cfg["model"]["config"]["data_sizes"]["s1_angles"],
+                sen2=cfg["model"]["config"]["data_sizes"]["sen2"],
+                s2_angles=cfg["model"]["config"]["data_sizes"]["s2_angles"],
+                srtm=cfg["model"]["config"]["data_sizes"]["srtm"],
+                w_c=cfg["model"]["config"]["data_sizes"]["w_c"],
             ),
             embeddings=ModularEmbeddingConfig(
                 s1_angles=MLPConfig(
-                    hidden_layers=cfg.model.config.embeddings.s1_angles.hidden_layers,
-                    out_channels=cfg.model.config.embeddings.s1_angles.out_channels,
+                    hidden_layers=cfg["model"]["config"]["embeddings"]["s1_angles"][
+                        "hidden_layers"
+                    ],
+                    out_channels=cfg["model"]["config"]["embeddings"]["s1_angles"][
+                        "out_channels"
+                    ],
                 ),
                 s1_srtm=UnetParams(
-                    out_channels=cfg.model.config.embeddings.s1_srtm.out_channels,
-                    encoder_sizes=cfg.model.config.embeddings.s1_srtm.encoder_sizes,
-                    kernel_size=cfg.model.config.embeddings.s1_srtm.kernel_size,
-                    tail_layers=cfg.model.config.embeddings.s1_srtm.tail_layers,
+                    out_channels=cfg["model"]["config"]["embeddings"]["s1_srtm"][
+                        "out_channels"
+                    ],
+                    encoder_sizes=cfg["model"]["config"]["embeddings"]["s1_srtm"][
+                        "encoder_sizes"
+                    ],
+                    kernel_size=cfg["model"]["config"]["embeddings"]["s1_srtm"][
+                        "kernel_size"
+                    ],
+                    tail_layers=cfg["model"]["config"]["embeddings"]["s1_srtm"][
+                        "tail_layers"
+                    ],
                 ),
                 s2_angles=ConvnetParams(
-                    out_channels=cfg.model.config.embeddings.s2_angles.out_channels,
-                    sizes=cfg.model.config.embeddings.s2_angles.sizes,
-                    kernel_sizes=cfg.model.config.embeddings.s2_angles.kernel_sizes,
+                    out_channels=cfg["model"]["config"]["embeddings"]["s2_angles"][
+                        "out_channels"
+                    ],
+                    sizes=cfg["model"]["config"]["embeddings"]["s2_angles"]["sizes"],
+                    kernel_sizes=cfg["model"]["config"]["embeddings"]["s2_angles"][
+                        "kernel_sizes"
+                    ],
                 ),
                 s2_srtm=UnetParams(
-                    out_channels=cfg.model.config.embeddings.s2_srtm.out_channels,
-                    encoder_sizes=cfg.model.config.embeddings.s2_srtm.encoder_sizes,
-                    kernel_size=cfg.model.config.embeddings.s2_srtm.kernel_size,
-                    tail_layers=cfg.model.config.embeddings.s2_srtm.tail_layers,
+                    out_channels=cfg["model"]["config"]["embeddings"]["s2_srtm"][
+                        "out_channels"
+                    ],
+                    encoder_sizes=cfg["model"]["config"]["embeddings"]["s2_srtm"][
+                        "encoder_sizes"
+                    ],
+                    kernel_size=cfg["model"]["config"]["embeddings"]["s2_srtm"][
+                        "kernel_size"
+                    ],
+                    tail_layers=cfg["model"]["config"]["embeddings"]["s2_srtm"][
+                        "tail_layers"
+                    ],
                 ),
                 w_c=ConvnetParams(
-                    out_channels=cfg.model.config.embeddings.w_c.out_channels,
-                    sizes=cfg.model.config.embeddings.w_c.sizes,
-                    kernel_sizes=cfg.model.config.embeddings.w_c.kernel_sizes,
+                    out_channels=cfg["model"]["config"]["embeddings"]["w_c"][
+                        "out_channels"
+                    ],
+                    sizes=cfg["model"]["config"]["embeddings"]["w_c"]["sizes"],
+                    kernel_sizes=cfg["model"]["config"]["embeddings"]["w_c"][
+                        "kernel_sizes"
+                    ],
                 ),
             ),
             s1_encoder=UnetParams(
-                out_channels=cfg.model.config.s1_encoder.out_channels,
-                encoder_sizes=cfg.model.config.s1_encoder.encoder_sizes,
-                kernel_size=cfg.model.config.s1_encoder.kernel_size,
-                tail_layers=cfg.model.config.s1_encoder.tail_layers,
+                out_channels=cfg["model"]["config"]["s1_encoder"]["out_channels"],
+                encoder_sizes=cfg["model"]["config"]["s1_encoder"]["encoder_sizes"],
+                kernel_size=cfg["model"]["config"]["s1_encoder"]["kernel_size"],
+                tail_layers=cfg["model"]["config"]["s1_encoder"]["tail_layers"],
             ),
             s2_encoder=UnetParams(
-                out_channels=cfg.model.config.s2_encoder.out_channels,
-                encoder_sizes=cfg.model.config.s2_encoder.encoder_sizes,
-                kernel_size=cfg.model.config.s2_encoder.kernel_size,
-                tail_layers=cfg.model.config.s2_encoder.tail_layers,
+                out_channels=cfg["model"]["config"]["s2_encoder"]["out_channels"],
+                encoder_sizes=cfg["model"]["config"]["s2_encoder"]["encoder_sizes"],
+                kernel_size=cfg["model"]["config"]["s2_encoder"]["kernel_size"],
+                tail_layers=cfg["model"]["config"]["s2_encoder"]["tail_layers"],
             ),
             s1_decoder=ConvnetParams(
-                out_channels=cfg.model.config.s1_decoder.out_channels,
-                sizes=cfg.model.config.s1_decoder.sizes,
-                kernel_sizes=cfg.model.config.s1_decoder.kernel_sizes,
+                out_channels=cfg["model"]["config"]["s1_decoder"]["out_channels"],
+                sizes=cfg["model"]["config"]["s1_decoder"]["sizes"],
+                kernel_sizes=cfg["model"]["config"]["s1_decoder"]["kernel_sizes"],
             ),
             s2_decoder=ConvnetParams(
-                out_channels=cfg.model.config.s2_decoder.out_channels,
-                sizes=cfg.model.config.s2_decoder.sizes,
-                kernel_sizes=cfg.model.config.s2_decoder.kernel_sizes,
+                out_channels=cfg["model"]["config"]["s2_decoder"]["out_channels"],
+                sizes=cfg["model"]["config"]["s2_decoder"]["sizes"],
+                kernel_sizes=cfg["model"]["config"]["s2_decoder"]["kernel_sizes"],
             ),
             s1_enc_use=MMDCDataUse(
-                s1_angles=cfg.model.config.s1_enc_use.s1_angles,
-                s2_angles=cfg.model.config.s1_enc_use.s2_angles,
-                srtm=cfg.model.config.s1_enc_use.srtm,
-                w_c=cfg.model.config.s1_enc_use.w_c,
+                s1_angles=cfg["model"]["config"]["s1_enc_use"]["s1_angles"],
+                s2_angles=cfg["model"]["config"]["s1_enc_use"]["s2_angles"],
+                srtm=cfg["model"]["config"]["s1_enc_use"]["srtm"],
+                w_c=cfg["model"]["config"]["s1_enc_use"]["w_c"],
             ),
             s2_enc_use=MMDCDataUse(
-                s1_angles=cfg.model.config.s2_enc_use.s1_angles,
-                s2_angles=cfg.model.config.s2_enc_use.s2_angles,
-                srtm=cfg.model.config.s2_enc_use.srtm,
-                w_c=cfg.model.config.s2_enc_use.w_c,
+                s1_angles=cfg["model"]["config"]["s2_enc_use"]["s1_angles"],
+                s2_angles=cfg["model"]["config"]["s2_enc_use"]["s2_angles"],
+                srtm=cfg["model"]["config"]["s2_enc_use"]["srtm"],
+                w_c=cfg["model"]["config"]["s2_enc_use"]["w_c"],
             ),
             s1_dec_use=MMDCDataUse(
-                s1_angles=cfg.model.config.s1_dec_use.s1_angles,
-                s2_angles=cfg.model.config.s1_dec_use.s2_angles,
-                srtm=cfg.model.config.s1_dec_use.srtm,
-                w_c=cfg.model.config.s1_dec_use.w_c,
+                s1_angles=cfg["model"]["config"]["s1_dec_use"]["s1_angles"],
+                s2_angles=cfg["model"]["config"]["s1_dec_use"]["s2_angles"],
+                srtm=cfg["model"]["config"]["s1_dec_use"]["srtm"],
+                w_c=cfg["model"]["config"]["s1_dec_use"]["w_c"],
             ),
             s2_dec_use=MMDCDataUse(
-                s1_angles=cfg.model.config.s2_dec_use.s1_angles,
-                s2_angles=cfg.model.config.s2_dec_use.s2_angles,
-                srtm=cfg.model.config.s2_dec_use.srtm,
-                w_c=cfg.model.config.s2_dec_use.w_c,
+                s1_angles=cfg["model"]["config"]["s2_dec_use"]["s1_angles"],
+                s2_angles=cfg["model"]["config"]["s2_dec_use"]["s2_angles"],
+                srtm=cfg["model"]["config"]["s2_dec_use"]["srtm"],
+                w_c=cfg["model"]["config"]["s2_dec_use"]["w_c"],
             ),
             loss_weights=MultiVAELossWeights(
                 sen1=TranslationLossWeights(
-                    nll=cfg.model.config.loss_weights.sen1.nll,
-                    lpips=cfg.model.config.loss_weights.sen1.lpips,
-                    sam=cfg.model.config.loss_weights.sen1.sam,
-                    gradients=cfg.model.config.loss_weights.sen1.gradients,
+                    nll=cfg["model"]["config"]["loss_weights"]["sen1"]["nll"],
+                    lpips=cfg["model"]["config"]["loss_weights"]["sen1"]["lpips"],
+                    sam=cfg["model"]["config"]["loss_weights"]["sen1"]["sam"],
+                    gradients=cfg["model"]["config"]["loss_weights"]["sen1"][
+                        "gradients"
+                    ],
                 ),
                 sen2=TranslationLossWeights(
-                    nll=cfg.model.config.loss_weights.sen2.nll,
-                    lpips=cfg.model.config.loss_weights.sen2.lpips,
-                    sam=cfg.model.config.loss_weights.sen2.sam,
-                    gradients=cfg.model.config.loss_weights.sen2.gradients,
+                    nll=cfg["model"]["config"]["loss_weights"]["sen2"]["nll"],
+                    lpips=cfg["model"]["config"]["loss_weights"]["sen2"]["lpips"],
+                    sam=cfg["model"]["config"]["loss_weights"]["sen2"]["sam"],
+                    gradients=cfg["model"]["config"]["loss_weights"]["sen2"][
+                        "gradients"
+                    ],
                 ),
                 s1_s2=TranslationLossWeights(
-                    nll=cfg.model.config.loss_weights.s1_s2.nll,
-                    lpips=cfg.model.config.loss_weights.s1_s2.lpips,
-                    sam=cfg.model.config.loss_weights.s1_s2.sam,
-                    gradients=cfg.model.config.loss_weights.s1_s2.gradients,
+                    nll=cfg["model"]["config"]["loss_weights"]["s1_s2"]["nll"],
+                    lpips=cfg["model"]["config"]["loss_weights"]["s1_s2"]["lpips"],
+                    sam=cfg["model"]["config"]["loss_weights"]["s1_s2"]["sam"],
+                    gradients=cfg["model"]["config"]["loss_weights"]["s1_s2"][
+                        "gradients"
+                    ],
                 ),
                 s2_s1=TranslationLossWeights(
-                    nll=cfg.model.config.loss_weights.s2_s1.nll,
-                    lpips=cfg.model.config.loss_weights.s2_s1.lpips,
-                    sam=cfg.model.config.loss_weights.s2_s1.sam,
-                    gradients=cfg.model.config.loss_weights.s2_s1.gradients,
+                    nll=cfg["model"]["config"]["loss_weights"]["s2_s1"]["nll"],
+                    lpips=cfg["model"]["config"]["loss_weights"]["s2_s1"]["lpips"],
+                    sam=cfg["model"]["config"]["loss_weights"]["s2_s1"]["sam"],
+                    gradients=cfg["model"]["config"]["loss_weights"]["s2_s1"][
+                        "gradients"
+                    ],
                 ),
-                forward=cfg.model.config.loss_weights.forward,
-                cross=cfg.model.config.loss_weights.forward,
-                latent=cfg.model.config.loss_weights.latent,
+                forward=cfg["model"]["config"]["loss_weights"]["forward"],
+                cross=cfg["model"]["config"]["loss_weights"]["cross"],
+                latent=cfg["model"]["config"]["loss_weights"]["latent"],
             ),
         )
         # Init the model
@@ -177,8 +220,8 @@ def get_mmdc_full_model(
         )
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"device detected : {device}")
-        print(f"nb_gpu's detected : {torch.cuda.device_count()}")
+        # print(f"device detected : {device}")
+        # print(f"nb_gpu's detected : {torch.cuda.device_count()}")
         # load state_dict
         lightning_checkpoint = torch.load(checkpoint, map_location=device)
 
@@ -233,7 +276,7 @@ def predict_mmdc_model(
     # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(device)
+    # print(device)
 
     prediction = model.predict(
         s2_ref.to(device),
@@ -247,10 +290,10 @@ def predict_mmdc_model(
         srtm.to(device),
     )
 
-    print("prediction.shape :", prediction[0].latent.latent_s2.mu.shape)
-    print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape)
-    print("prediction.shape :", prediction[0].latent.latent_s1.mu.shape)
-    print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape)
+    # print("prediction.shape :", prediction[0].latent.latent_s2.mean.shape)
+    # print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape)
+    # print("prediction.shape :", prediction[0].latent.latent_s1.mean.shape)
+    # print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape)
 
     # init the output latent spaces as
     # empty dataclass
@@ -261,151 +304,151 @@ def predict_mmdc_model(
     match sensors:
         # Cases
         case ["S2L2A", "S1FULL"]:
-            logger.info("S2 captured & S1 full captured")
+            # logger.info("S2 captured & S1 full captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
             latent_space.latent_s1 = prediction[0].latent.latent_s1
 
             latent_space_stack = torch.cat(
                 (
-                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.mean,
                     latent_space.latent_s2.logvar,
-                    latent_space.latent_s1.mu,
+                    latent_space.latent_s1.mean,
                     latent_space.latent_s1.logvar,
                 ),
                 1,
             )
 
-            print(
-                "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu.shape,
-            )
-            print(
-                "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar.shape,
-            )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
-            print(
-                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            )
+            # print(
+            #     "latent_space.latent_s2.mean :",
+            #     latent_space.latent_s2.mean.shape,
+            # )
+            # print(
+            #     "latent_space.latent_s2.logvar :",
+            #     latent_space.latent_s2.logvar.shape,
+            # )
+            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+            # print(
+            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            # )
 
         case ["S2L2A", "S1ASC"]:
-            logger.info("S2 captured & S1 asc captured")
+            # logger.info("S2 captured & S1 asc captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
             latent_space.latent_s1 = prediction[0].latent.latent_s1
 
             latent_space_stack = torch.cat(
                 (
-                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.mean,
                     latent_space.latent_s2.logvar,
-                    latent_space.latent_s1.mu,
+                    latent_space.latent_s1.mean,
                     latent_space.latent_s1.logvar,
                 ),
                 1,
             )
-            print(
-                "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu.shape,
-            )
-            print(
-                "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar.shape,
-            )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
-            print(
-                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            )
+            # print(
+            #     "latent_space.latent_s2.mean :",
+            #     latent_space.latent_s2.mean.shape,
+            # )
+            # print(
+            #     "latent_space.latent_s2.logvar :",
+            #     latent_space.latent_s2.logvar.shape,
+            # )
+            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+            # print(
+            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            # )
 
         case ["S2L2A", "S1DESC"]:
-            logger.info("S2 captured & S1 desc captured")
+            # logger.info("S2 captured & S1 desc captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
             latent_space.latent_s1 = prediction[0].latent.latent_s1
 
             latent_space_stack = torch.cat(
                 (
-                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.mean,
                     latent_space.latent_s2.logvar,
-                    latent_space.latent_s1.mu,
+                    latent_space.latent_s1.mean,
                     latent_space.latent_s1.logvar,
                 ),
                 1,
             )
-            print(
-                "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu.shape,
-            )
-            print(
-                "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar.shape,
-            )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
-            print(
-                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            )
+            # print(
+            #     "latent_space.latent_s2.mean :",
+            #     latent_space.latent_s2.mean.shape,
+            # )
+            # print(
+            #     "latent_space.latent_s2.logvar :",
+            #     latent_space.latent_s2.logvar.shape,
+            # )
+            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+            # print(
+            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            # )
 
         case ["S2L2A"]:
-            logger.info("Only S2 captured")
+            # logger.info("Only S2 captured")
 
             latent_space.latent_s2 = prediction[0].latent.latent_s2
 
             latent_space_stack = torch.cat(
                 (
-                    latent_space.latent_s2.mu,
+                    latent_space.latent_s2.mean,
                     latent_space.latent_s2.logvar,
                 ),
                 1,
             )
-            print(
-                "latent_space.latent_s2.mu :",
-                latent_space.latent_s2.mu.shape,
-            )
-            print(
-                "latent_space.latent_s2.logvar :",
-                latent_space.latent_s2.logvar.shape,
-            )
+            # print(
+            #     "latent_space.latent_s2.mean :",
+            #     latent_space.latent_s2.mean.shape,
+            # )
+            # print(
+            #     "latent_space.latent_s2.logvar :",
+            #     latent_space.latent_s2.logvar.shape,
+            # )
 
         case ["S1FULL"]:
-            logger.info("Only S1 full captured")
+            # logger.info("Only S1 full captured")
 
             latent_space.latent_s1 = prediction[0].latent.latent_s1
 
             latent_space_stack = torch.cat(
-                (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
+                (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
                 1,
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
-            print(
-                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            )
+            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+            # print(
+            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            # )
 
         case ["S1ASC"]:
-            logger.info("Only S1ASC captured")
+            # logger.info("Only S1ASC captured")
 
             latent_space.latent_s1 = prediction[0].latent.latent_s1
 
             latent_space_stack = torch.cat(
-                (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
+                (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
                 1,
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
-            print(
-                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            )
+            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+            # print(
+            # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            # )
 
         case ["S1DESC"]:
-            logger.info("Only S1DESC captured")
+            # logger.info("Only S1DESC captured")
 
             latent_space.latent_s1 = prediction[0].latent.latent_s1
 
             latent_space_stack = torch.cat(
-                (latent_space.latent_s1.mu, latent_space.latent_s1.logvar),
+                (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
                 1,
             )
-            print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape)
-            print(
-                "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            )
+            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+            # print(
+            # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            # )
 
     return latent_space_stack  # latent_space
-- 
GitLab
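
The patch above switches the model-config reads from OmegaConf-style attribute access (cfg.model.config...) to plain item access (cfg["model"]["config"]...). That is what loading the Hydra YAML with yaml.safe_load requires, since it returns nested dicts rather than attribute namespaces. A minimal sketch of the difference (illustrative only; the inline YAML is hypothetical):

    import yaml

    cfg = yaml.safe_load("model: {config: {s1_encoder: {kernel_size: 3}}}")

    # plain dicts only support item access
    kernel = cfg["model"]["config"]["s1_encoder"]["kernel_size"]  # -> 3

    # attribute access raises on a plain dict:
    # cfg.model.config.s1_encoder.kernel_size
    # AttributeError: 'dict' object has no attribute 'model'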


From a44235ecafdb5025c695d3adb021b0ba99447747 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:32:11 +0000
Subject: [PATCH 52/81] update paths

---
 test/test_mmdc_inference.py | 351 +++++++++++-------------------------
 1 file changed, 104 insertions(+), 247 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 4f7b1f0a..51be142f 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -10,7 +10,7 @@ import pytest
 import rasterio as rio
 import torch
 
-from mmdc_singledate.inference.components.inference_components import (  # predict_tile,
+from mmdc_singledate.inference.components.inference_components import (
     get_mmdc_full_model,
     predict_mmdc_model,
 )
@@ -24,31 +24,77 @@ from mmdc_singledate.inference.mmdc_tile_inference import (
     mmdc_tile_inference,
     predict_single_date_tile,
 )
-from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
+from mmdc_singledate.models.types import VAELatentSpace
 
-# dir
+from .utils import get_scales, setup_data
+
+# useful variables
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif"
+s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif"
+srtm_filename_variable = "srtm_T31TCJ_roi_0.tif"
+wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif"
+
+
+sensors_test = [
+    (["S2L2A", "S1FULL"]),
+    (["S2L2A", "S1ASC"]),
+    (["S2L2A", "S1DESC"]),
+    (["S2L2A"]),
+    (["S1FULL"]),
+    (["S1ASC"]),
+    (["S1DESC"]),
+]
+
+
+datasets = [
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        "",
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        "",
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        "",
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+]
 
 
 # @pytest.mark.skip(reason="Check Time Serie generation")
 def test_GeoTiffDataset():
     """ """
-    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ"
-
     input_data = GeoTiffDataset(
-        s2_filename=os.path.join(
-            datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
-        ),
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
         s1_asc_filename=os.path.join(
-            datapath,
-            "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+            dataset_dir,
+            s1_asc_filename_variable,
         ),
         s1_desc_filename=os.path.join(
-            datapath,
-            "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+            dataset_dir,
+            s1_desc_filename_variable,
         ),
-        srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
-        wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
     )
     # print GetTiffDataset
     print(input_data)
@@ -69,95 +115,33 @@ def test_generate_chunks():
     assert chunks[0] == (0, 0, 10980, 1024)
 
 
-# @pytest.mark.skip(reason="Check Time Serie generation")
-def test_mmdc_full_model():
-    """
-    test instantiate network
-    """
-    # set device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(device)
+@pytest.mark.skip(reason="Check Time Serie generation")
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_predict_mmdc_model(sensors):
+    """ """
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # checkpoints
-    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
-    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename)
     )
-    # move to device
-    mmdc_full_model.to(device)
-
-    assert mmdc_full_model.training == False
-
-    s2_x = torch.rand(1, 10, 256, 256).to(device)
-    s2_m = torch.ones(1, 10, 256, 256).to(device)
-    s2_angles_x = torch.rand(1, 6, 256, 256).to(device)
-    s1_x = torch.rand(1, 6, 256, 256).to(device)
-    s1_vm = torch.ones(1, 1, 256, 256).to(device)
-    s1_asc_angles_x = torch.rand(1, 3).to(device)
-    s1_desc_angles_x = torch.rand(1, 3).to(device)
-    worldclim_x = torch.rand(1, 103, 256, 256).to(device)
-    srtm_x = torch.rand(1, 4, 256, 256).to(device)
-
-    print(srtm_x)
-
-    prediction = mmdc_full_model.predict(
+    (
         s2_x,
         s2_m,
-        s2_angles_x,
+        s2_a,
         s1_x,
         s1_vm,
-        s1_asc_angles_x,
-        s1_desc_angles_x,
-        worldclim_x,
+        s1_a_asc,
+        s1_a_desc,
         srtm_x,
-    )
-
-    latent_variable = prediction[0].latent.latent_s1.mu
-    print(latent_variable.shape)
-
-    assert type(latent_variable) == torch.Tensor
-
-
-sensors_test = [
-    (["S2L2A", "S1FULL"]),
-    (["S2L2A", "S1ASC"]),
-    (["S2L2A", "S1DESC"]),
-    (["S2L2A"]),
-    (["S1FULL"]),
-    (["S1ASC"]),
-    (["S1DESC"]),
-    # ([ "S1FULL","S2L2A"  ]),
-]
-
-
-# @pytest.mark.skip(reason="Check Time Serie generation")
-@pytest.mark.parametrize("sensors", sensors_test)
-def test_predict_mmdc_model(sensors):
-    """ """
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # data input
-    s2_x = torch.rand(1, 10, 256, 256).to(device)
-    s2_m = torch.ones(1, 10, 256, 256).to(device)
-    s2_angles_x = torch.rand(1, 6, 256, 256).to(device)
-    s1_x = torch.rand(1, 6, 256, 256).to(device)
-    s1_vm = torch.ones(1, 1, 256, 256).to(device)
-    s1_asc_angles_x = torch.rand(1, 3).to(device)
-    s1_desc_angles_x = torch.rand(1, 3).to(device)
-    worldclim_x = torch.rand(1, 103, 256, 256).to(device)
-    srtm_x = torch.rand(1, 4, 256, 256).to(device)
-
-    # model
-    # checkpoints
-    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
-    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
-
-    mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
-    )
+        wc_x,
+        device,
+    ) = setup_data()
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
     # move to device
     mmdc_full_model.to(device)
 
@@ -166,107 +150,17 @@ def test_predict_mmdc_model(sensors):
         sensors,
         s2_x,
         s2_m,
-        s2_angles_x,
+        s2_a,
         s1_x,
         s1_vm,
-        s1_asc_angles_x,
-        s1_desc_angles_x,
-        worldclim_x,
+        s1_a_asc,
+        s1_a_desc,
+        wc_x,
         srtm_x,
     )
     print(pred.shape)
 
     assert type(pred) == torch.Tensor
-    # assert pred.shape[0] == 4 or 2
-
-
-def dummy_process(
-    s2_refl,
-    s2_mask,
-    s2_ang,
-    s1_back,
-    s1_vm,
-    s1_asc_lia,
-    s1_desc_lia,
-    wc_patch,
-    srtm_patch,
-):
-    """
-    Create a dummy function
-    """
-    prediction = (
-        s2_refl[:, 0, ...]
-        * s2_mask[:, 0, ...]
-        * s2_ang[:, 0, ...]
-        * s1_back[:, 0, ...]
-        * s1_vm[:, 0, ...]
-        * srtm_patch[:, 0, ...]
-        * wc_patch[:, 0, ...]
-    )
-
-    latent_example = VAELatentSpace(
-        mu=prediction,
-        logvar=prediction,
-    )
-
-    latent_space_stack = torch.cat(
-        (
-            torch.unsqueeze(latent_example.mu, 1),
-            torch.unsqueeze(latent_example.logvar, 1),
-            torch.unsqueeze(latent_example.mu, 1),
-            torch.unsqueeze(latent_example.logvar, 1),
-        ),
-        1,
-    )
-
-    return latent_space_stack
-
-
-# TODO add cases with no data
-# @pytest.mark.skip(reason="Test with dummy process ")
-# @pytest.mark.parametrize("sensors", sensors_test)
-# def test_predict_single_date_tile_with_dummy_process(sensors):
-#     """ """
-#     export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif"
-#     nb_bands = 4 #  len(sensors) * 2
-#     print("nb_bands  :", nb_bands)
-#     process = MMDCProcess(
-#         count=nb_bands, #4,
-#         nb_lines=1024,
-#         patch_size=256,
-#         process=dummy_process,
-#     )
-
-#     datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
-#     input_data = GeoTiffDataset(
-#         s2_filename=os.path.join(
-#             datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
-#         ),
-#         s1_asc_filename=os.path.join(
-#             datapath,
-#             "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
-#         ),
-#         s1_desc_filename=os.path.join(
-#             datapath,
-#             "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
-#         ),
-#         srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
-#         wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
-#     )
-
-#     predict_single_date_tile(
-#         input_data=input_data,
-#         export_path=export_path,
-#         sensors=sensors,  # ["S2L2A"],
-#         process=process,
-#     )
-
-#     assert Path(export_path).exists() == True
-
-#     with rio.open(export_path) as predited_raster:
-#         predicted_metadata = predited_raster.meta.copy()
-
-#         assert predicted_metadata["count"] == nb_bands
 
 
 # TODO add cases with no data
@@ -278,33 +172,33 @@ def test_predict_single_date_tile(sensors):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif"
-    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
     input_data = GeoTiffDataset(
-        s2_filename=os.path.join(
-            datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
-        ),
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
         s1_asc_filename=os.path.join(
-            datapath,
-            "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+            dataset_dir,
+            s1_asc_filename_variable,
         ),
         s1_desc_filename=os.path.join(
-            datapath,
-            "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+            dataset_dir,
+            s1_desc_filename_variable,
         ),
-        srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
-        wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
     )
 
     nb_bands = len(sensors) * 8
     print("nb_bands  :", nb_bands)
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # checkpoints
-    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
-    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"  # "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename)
     )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
     # move to device
     mmdc_full_model.to(device)
     assert mmdc_full_model.training == False
@@ -333,49 +227,10 @@ def test_predict_single_date_tile(sensors):
         assert predicted_metadata["count"] == nb_bands
 
 
-s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
-s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif"
-s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif"
-srtm_filename_variable = "srtm_T31TCJ_roi_0.tif"
-wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif"
-
-datasets = [
-    (
-        s2_filename_variable,
-        s1_asc_filename_variable,
-        s1_desc_filename_variable,
-        srtm_filename_variable,
-        wc_filename_variable,
-    ),
-    (
-        "",
-        s1_asc_filename_variable,
-        s1_desc_filename_variable,
-        srtm_filename_variable,
-        wc_filename_variable,
-    ),
-    (
-        s2_filename_variable,
-        "",
-        s1_desc_filename_variable,
-        srtm_filename_variable,
-        wc_filename_variable,
-    ),
-    (
-        s2_filename_variable,
-        s1_asc_filename_variable,
-        "",
-        srtm_filename_variable,
-        wc_filename_variable,
-    ),
-]
-
-
 @pytest.mark.parametrize(
     "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable",
     datasets,
 )
-# @pytest.mark.skip(reason="Check Time Serie generation")
 def test_predict_single_date_tile_no_data(
     s2_filename_variable,
     s1_asc_filename_variable,
@@ -388,32 +243,35 @@ def test_predict_single_date_tile_no_data(
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif"
-    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+    dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
     input_data = GeoTiffDataset(
-        s2_filename=os.path.join(datapath, s2_filename_variable),
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
         s1_asc_filename=os.path.join(
-            datapath,
+            dataset_dir,
             s1_asc_filename_variable,
         ),
         s1_desc_filename=os.path.join(
-            datapath,
+            dataset_dir,
             s1_desc_filename_variable,
         ),
-        srtm_filename=os.path.join(datapath, srtm_filename_variable),
-        wc_filename=os.path.join(datapath, wc_filename_variable),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
     )
 
     nb_bands = len(sensors) * 8
     print("nb_bands  :", nb_bands)
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # checkpoints
-    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
-    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"  # "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename)
     )
     # move to device
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
     mmdc_full_model.to(device)
     assert mmdc_full_model.training == False
 
@@ -449,7 +307,7 @@ def test_mmdc_tile_inference():
     # feed parameters
     input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/")
     export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/")
-    model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt"
+    model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt"
 
     input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"]
 
@@ -542,7 +400,6 @@ def test_mmdc_tile_inference():
             ],
         }
     )
-
     # inference_dataframe(
     #     samples_dir=input_path,
     #     input_tile_list=input_tile_list,
-- 
GitLab
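
The refactored tests now take their inputs from helpers in test/utils.py (setup_data, get_scales). setup_data() itself is not shown in this series; judging from the inline tensors it replaces, a minimal sketch could look like the following (shapes copied from the removed test code; the real helper may differ):

    import torch

    def setup_data():
        """Random inputs with the shapes the old inline test code used."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        s2_x = torch.rand(1, 10, 256, 256).to(device)   # S2 reflectances
        s2_m = torch.ones(1, 10, 256, 256).to(device)   # S2 mask
        s2_a = torch.rand(1, 6, 256, 256).to(device)    # S2 angles
        s1_x = torch.rand(1, 6, 256, 256).to(device)    # S1 backscatter
        s1_vm = torch.ones(1, 1, 256, 256).to(device)   # S1 validity mask
        s1_a_asc = torch.rand(1, 3).to(device)          # S1 ascending LIA angles
        s1_a_desc = torch.rand(1, 3).to(device)         # S1 descending LIA angles
        srtm_x = torch.rand(1, 4, 256, 256).to(device)  # SRTM
        wc_x = torch.rand(1, 103, 256, 256).to(device)  # WorldClim
        return s2_x, s2_m, s2_a, s1_x, s1_vm, s1_a_asc, s1_a_desc, srtm_x, wc_x, device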


From bebe29b944d582aa423b75425c5d5455dd26202a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:33:04 +0000
Subject: [PATCH 53/81] update to main

---
 .../inference/mmdc_tile_inference.py          | 116 +++++++++---------
 1 file changed, 60 insertions(+), 56 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index 123b4318..ebc619f5 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -16,13 +16,11 @@ import pandas as pd
 import rasterio as rio
 import torch
 from torchutils import patches
+from tqdm import tqdm
 
 from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df
 
-from .components.inference_components import (  # patchify_batch,; unpatchify_batch,
-    get_mmdc_full_model,
-    predict_mmdc_model,
-)
+from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
 from .components.inference_utils import (
     GeoTiffDataset,
     concat_worldclim_components,
@@ -32,6 +30,7 @@ from .components.inference_utils import (
     read_s2_img_tile,
     read_srtm_img_tile,
 )
+from .utils import get_scales
 
 # Configure the logger
 NUMERIC_LEVEL = getattr(logging, "INFO", None)
@@ -121,48 +120,48 @@ def predict_single_date_tile(
         chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
         # get the windows from the chunks
         rois = [rio.windows.Window(*chunk) for chunk in chunks]
-        logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+        # logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
 
         # init the dataset
-        logger.info("Reading S2 data")
+        # logger.info("Reading S2 data")
         s2_data = read_img_tile(
             filename=input_data.s2_filename,
             rois=rois,
             availabitity=input_data.s2_availabitity,
             sensor_func=read_s2_img_tile,
         )
-        logger.info("Reading S1 ASC data")
+        # logger.info("Reading S1 ASC data")
         s1_asc_data = read_img_tile(
             filename=input_data.s1_asc_filename,
             rois=rois,
             availabitity=input_data.s1_asc_availability,
             sensor_func=read_s1_img_tile,
         )
-        logger.info("Reading S1 DESC data")
+        # logger.info("Reading S1 DESC data")
         s1_desc_data = read_img_tile(
             filename=input_data.s1_desc_filename,
             rois=rois,
             availabitity=input_data.s1_desc_availability,
             sensor_func=read_s1_img_tile,
         )
-        logger.info("Reading SRTM data")
+        # logger.info("Reading SRTM data")
         srtm_data = read_img_tile(
             filename=input_data.srtm_filename,
             rois=rois,
             availabitity=True,
             sensor_func=read_srtm_img_tile,
         )
-        logger.info("Reading WorldClim data")
+        # logger.info("Reading WorldClim data")
         worldclim_data = concat_worldclim_components(
             wc_filename=input_data.wc_filename, rois=rois, availabitity=True
         )
-        print("input_data.s2_availabitity=", input_data.s2_availabitity)
-        print("s2_data=", s2_data)
-        print("s1_asc_data=", s1_asc_data)
-        print("s1_desc_data=", s1_desc_data)
-        print("srtm_data=", srtm_data)
-        print("worldclim_data=", worldclim_data)
-        logger.info("Export Init")
+        # print("input_data.s2_availabitity=", input_data.s2_availabitity)
+        # print("s2_data=", s2_data)
+        # print("s1_asc_data=", s1_asc_data)
+        # print("s1_desc_data=", s1_desc_data)
+        # print("srtm_data=", srtm_data)
+        # print("worldclim_data=", worldclim_data)
+        # logger.info("Export Init")
 
         with rio.open(export_path, "w", **meta) as prediction:
             # iterate over the windows
@@ -174,17 +173,17 @@ def predict_single_date_tile(
                 srtm_data,
                 worldclim_data,
             ):
-                print(" original size : ", s2.s2_reflectances.shape)
-                print(" original size : ", s2.s2_angles.shape)
-                print(" original size : ", s2.s2_mask.shape)
-                print(" original size : ", s1_asc.s1_backscatter.shape)
-                print(" original size : ", s1_asc.s1_valmask.shape)
-                print(" original size : ", s1_asc.s1_lia_angles.shape)
-                print(" original size : ", s1_desc.s1_backscatter.shape)
-                print(" original size : ", s1_desc.s1_valmask.shape)
-                print(" original size : ", s1_desc.s1_lia_angles.shape)
-                print(" original size : ", srtm.srtm.shape)
-                print(" original size : ", wc.worldclim.shape)
+                # print(" original size : ", s2.s2_reflectances.shape)
+                # print(" original size : ", s2.s2_angles.shape)
+                # print(" original size : ", s2.s2_mask.shape)
+                # print(" original size : ", s1_asc.s1_backscatter.shape)
+                # print(" original size : ", s1_asc.s1_valmask.shape)
+                # print(" original size : ", s1_asc.s1_lia_angles.shape)
+                # print(" original size : ", s1_desc.s1_backscatter.shape)
+                # print(" original size : ", s1_desc.s1_valmask.shape)
+                # print(" original size : ", s1_desc.s1_lia_angles.shape)
+                # print(" original size : ", srtm.srtm.shape)
+                # print(" original size : ", wc.worldclim.shape)
 
                 # Concat S1 Data
                 s1_backscatter = torch.cat(
@@ -200,8 +199,8 @@ def predict_single_date_tile(
                 ).shape
                 s2_s2_mask_shape = s2.s2_mask.shape  # [1, 1024, 10980]
 
-                print("s2_mask_patch_size :", s2_mask_patch_size)
-                print("s2_s2_mask_shape:", s2_s2_mask_shape)
+                # print("s2_mask_patch_size :", s2_mask_patch_size)
+                # print("s2_s2_mask_shape:", s2_s2_mask_shape)
                 # reshape the data
                 s2_refl_patch = patches.flatten2d(
                     patches.patchify(s2.s2_reflectances, process.patch_size)
@@ -234,23 +233,23 @@ def predict_single_date_tile(
                 )
                 # torch.flatten(start_dim=0, end_dim=1)
 
-                print("s2_patches", s2_refl_patch.shape)
-                print("s2_angles_patches", s2_ang_patch.shape)
-                print("s2_mask_patch", s2_mask_patch.shape)
-                print("s1_patch", s1_patch.shape)
-                print("s1_valmask_patch", s1_valmask_patch.shape)
-                print(
-                    "s1_lia_asc patches",
-                    s1_asc.s1_lia_angles.shape,
-                    s1asc_lia_patch.shape,
-                )
-                print(
-                    "s1_lia_desc patches",
-                    s1_desc.s1_lia_angles.shape,
-                    s1desc_lia_patch.shape,
-                )
-                print("srtm patches", srtm_patch.shape)
-                print("wc patches", wc_patch.shape)
+                # print("s2_patches", s2_refl_patch.shape)
+                # print("s2_angles_patches", s2_ang_patch.shape)
+                # print("s2_mask_patch", s2_mask_patch.shape)
+                # print("s1_patch", s1_patch.shape)
+                # print("s1_valmask_patch", s1_valmask_patch.shape)
+                # print(
+                #     "s1_lia_asc patches",
+                #     s1_asc.s1_lia_angles.shape,
+                #     s1asc_lia_patch.shape,
+                # )
+                # print(
+                #     "s1_lia_desc patches",
+                #     s1_desc.s1_lia_angles.shape,
+                #     s1desc_lia_patch.shape,
+                # )
+                # print("srtm patches", srtm_patch.shape)
+                # print("wc patches", wc_patch.shape)
 
                 # apply predict function
                 # should return a s1s2vaelatentspace
@@ -268,8 +267,8 @@ def predict_single_date_tile(
                     wc_patch,
                     srtm_patch,
                 )
-                print(type(pred_vaelatentspace))
-                print("latent space sizes : ", pred_vaelatentspace.shape)
+                # print(type(pred_vaelatentspace))
+                # print("latent space sizes : ", pred_vaelatentspace.shape)
 
                 # unpatchify
                 pred_vaelatentspace_unpatchify = patches.unpatchify(
@@ -280,8 +279,8 @@ def predict_single_date_tile(
                     )
                 )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
 
-                print("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
-                print("process.count : ", process.count)
+                # print("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
+                # print("process.count : ", process.count)
                 # check the pred and the ask are the same
                 assert process.count == pred_vaelatentspace_unpatchify.shape[0]
 
@@ -290,7 +289,7 @@ def predict_single_date_tile(
                     window=roi,
                     indexes=process.count,
                 )
-                logger.info(("Export tile", f"filename :{export_path}"))
+                # logger.info(("Export tile", f"filename :{export_path}"))
 
 
 def estimate_nb_latent_spaces(
@@ -379,10 +378,15 @@ def mmdc_tile_inference(
     model = get_mmdc_full_model(
         checkpoint=model_checkpoint_path,
     )
+    scales = get_scales()
+    model.set_scales(scales)
+    # move to device
     model.to(device)
 
     # iterate over the dates in the time serie
-    for tuile, df_row in inference_dataframe.iterrows():
+    for tuile, df_row in tqdm(
+        inference_dataframe.iterrows(), total=inference_dataframe.shape[0]
+    ):
         # estimate nb latent spaces
         latent_space_size = estimate_nb_latent_spaces(
             df_row["patch_s2_availability"],
@@ -398,8 +402,8 @@ def mmdc_tile_inference(
         date_ = df_row["date"].strftime("%Y-%m-%d")
         export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif"
         #
-        print(tuile, df_row)
-        print(latent_space_size)
+        # print(tuile, df_row)
+        # print(latent_space_size)
         # define process
         mmdc_process = MMDCProcess(
             count=latent_space_size,
@@ -427,4 +431,4 @@ def mmdc_tile_inference(
             process=mmdc_process,
         )
 
-        logger.info("Export Finish !!!")
+        # logger.info("Export Finish !!!")
-- 
GitLab
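
Patch 53 keeps the tile-inference structure intact: the tile is read in full-width windows of process.nb_lines rows, each window is patchified to process.patch_size, the model predicts on the flattened patch batch, and the prediction is unpatchified, cropped back to the window shape, and written to the output raster. The chunking step is small enough to sketch; the version below reproduces the tuples asserted in test_generate_chunks, although the real implementation may differ:

    import rasterio as rio

    def generate_chunks(width: int, height: int, nb_lines: int):
        """Full-width horizontal strips of at most nb_lines rows,
        as (col_off, row_off, width, height) tuples."""
        return [
            (0, row, width, min(nb_lines, height - row))
            for row in range(0, height, nb_lines)
        ]

    chunks = generate_chunks(10980, 10980, 1024)
    assert chunks[0] == (0, 0, 10980, 1024)
    rois = [rio.windows.Window(*chunk) for chunk in chunks]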


From d6dd8459758931a14936f66aaad557553f9a47e4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:34:00 +0000
Subject: [PATCH 54/81] make sure inference is a module

---
 src/mmdc_singledate/inference/__init__.py     |  1 +
 .../inference/components/__init__.py          |  1 +
 src/mmdc_singledate/inference/utils.py        | 45 +++++++++++++++++++
 3 files changed, 47 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/__init__.py
 create mode 100644 src/mmdc_singledate/inference/components/__init__.py
 create mode 100644 src/mmdc_singledate/inference/utils.py

diff --git a/src/mmdc_singledate/inference/__init__.py b/src/mmdc_singledate/inference/__init__.py
new file mode 100644
index 00000000..e5a0d9b4
--- /dev/null
+++ b/src/mmdc_singledate/inference/__init__.py
@@ -0,0 +1 @@
+#!/usr/bin/env python3
diff --git a/src/mmdc_singledate/inference/components/__init__.py b/src/mmdc_singledate/inference/components/__init__.py
new file mode 100644
index 00000000..e5a0d9b4
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/__init__.py
@@ -0,0 +1 @@
+#!/usr/bin/env python3
diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py
new file mode 100644
index 00000000..53adf910
--- /dev/null
+++ b/src/mmdc_singledate/inference/utils.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+import torch
+
+from mmdc_singledate.datamodules.types import MMDCDataStats, MMDCShiftScales, ShiftScale
+
+
+def get_scales():
+    """
+    Read Scales for Inference
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dataset_dir = "/work/CESBIO/projects/MAESTRIA/training_dataset2/tftwmwc3"
+    tile = "T31TEM"
+    patch_size = "256x256"
+    roi = 1
+    stats = MMDCDataStats(
+        torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s2_{patch_size}_{roi}.pth"),
+        torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s1_{patch_size}_{roi}.pth"),
+        torch.load(
+            f"{dataset_dir}/{tile}/{tile}_stats_worldclim_{patch_size}_{roi}.pth"
+        ),
+        torch.load(f"{dataset_dir}/{tile}/{tile}_stats_srtm_{patch_size}_{roi}.pth"),
+    )
+    scale_regul = torch.nn.Threshold(1e-10, 1.0)
+    shift_scale_s2 = ShiftScale(
+        stats.sen2.median.to(device),
+        scale_regul((stats.sen2.qmax - stats.sen2.qmin) / 2.0).to(device),
+    )
+    shift_scale_s1 = ShiftScale(
+        stats.sen1.median.to(device),
+        scale_regul((stats.sen1.qmax - stats.sen1.qmin) / 2.0).to(device),
+    )
+    shift_scale_wc = ShiftScale(
+        stats.worldclim.median.to(device),
+        scale_regul((stats.worldclim.qmax - stats.worldclim.qmin) / 2.0).to(device),
+    )
+    shift_scale_srtm = ShiftScale(
+        stats.srtm.median.to(device),
+        scale_regul((stats.srtm.qmax - stats.srtm.qmin) / 2.0).to(device),
+    )
+
+    return MMDCShiftScales(
+        shift_scale_s2, shift_scale_s1, shift_scale_wc, shift_scale_srtm
+    )
-- 
GitLab
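
get_scales() builds one ShiftScale pair per modality from precomputed statistics: the shift is the per-band median and the scale is half the (qmax - qmin) quantile range, with near-zero scales clamped to 1.0 by the torch.nn.Threshold regulariser. How the model applies MMDCShiftScales after set_scales() is internal to it, but the intended normalisation is presumably of the usual form below (illustrative sketch; the tensors are hypothetical):

    import torch

    shift = torch.tensor([0.2, 0.3])    # per-band median
    scale = torch.tensor([0.5, 1e-12])  # half the (qmax - qmin) range

    # same guard as in get_scales: values <= 1e-10 are replaced by 1.0
    scale_regul = torch.nn.Threshold(1e-10, 1.0)
    scale = scale_regul(scale)          # -> tensor([0.5000, 1.0000])

    x = torch.rand(4, 2, 8, 8)          # hypothetical batch: B x C x H x W
    x_norm = (x - shift.view(1, -1, 1, 1)) / scale.view(1, -1, 1, 1)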


From 028ce198d20f1aba7785dce1bb9b1f9a29b49a3a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:34:35 +0000
Subject: [PATCH 55/81] iota2 init

---
 iota2_init.sh | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 iota2_init.sh

diff --git a/iota2_init.sh b/iota2_init.sh
new file mode 100644
index 00000000..4d1f386f
--- /dev/null
+++ b/iota2_init.sh
@@ -0,0 +1,3 @@
+module purge
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
-- 
GitLab


From e35a7c0fbe1ccaffb0a87fcdb610ae73bf46cd5c Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:36:14 +0000
Subject: [PATCH 56/81] update to main

---
 src/mmdc_singledate/inference/components/inference_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
index 7963eb1a..1d79f1f0 100644
--- a/src/mmdc_singledate/inference/components/inference_utils.py
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -204,7 +204,7 @@ def read_img_tile(
                 yield sensor_data
     # if not exists create a zeros tensor and yield
     else:
-        print("Creating fake data")
+        # print("Creating fake data")
         for roi in rois:
             null_data = torch.ones(roi.width, roi.height)
             sensor_data = sensor_func(null_data, availabitity, filename)
@@ -238,7 +238,7 @@ def read_s2_img_tile(
         image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1])
         angles_s2 = torch.ones(6, s2_tensor.shape[0], s2_tensor.shape[1])
         mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1])
-        print("passing zero data")
+        # print("passing zero data")
 
     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
 
-- 
GitLab


From f999fd67b7f8a83dc01f89e7dc72707283d79a0a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Mon, 20 Mar 2023 16:48:54 +0000
Subject: [PATCH 57/81] ignore iota2 thirdparty

---
 .gitignore | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 281e6999..207ceac4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,4 @@ iota2_thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
-iota2_thirdparties/
+iota2_thirdparties
-- 
GitLab


From bb0673f73fab4d11b810812fa45c9333b6921cc9 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Mon, 20 Mar 2023 16:52:32 +0000
Subject: [PATCH 58/81] update thirdparty folders

---
 .gitignore | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index 207ceac4..26aaaf09 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,9 +9,9 @@ _lightning_logs/*
 src/models/*.ipynb
 *.swp
 *~
-thirdparties
+thirdparties/*
 iota2_thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
-iota2_thirdparties
+iota2_thirdparties/*
-- 
GitLab


From 52a70cbf1de8defc5a3e1bb08756f8ef5996bfaa Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 15:59:13 +0000
Subject: [PATCH 59/81] update new parameters

---
 jobs/inference_single_date.pbs | 5 +++--
 jobs/inference_time_serie.pbs  | 7 ++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs
index 989f6ec6..1374be03 100644
--- a/jobs/inference_single_date.pbs
+++ b/jobs/inference_single_date.pbs
@@ -13,7 +13,7 @@ export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
 export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/
 export SCRATCH_DIR=/work/scratch/${USER}/
 
-cd ${WORKING_DIR}
+cd ${SRCDIR}
 
 mkdir ${SCRATCH_DIR}/MMDC/inference/
 mkdir ${SCRATCH_DIR}/MMDC/inference/singledate
@@ -27,7 +27,8 @@ python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
     --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \
     --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \
     --export_path ${SCRATCH_DIR}/MMDC/inference/singledate \
-    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
+    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \
+    --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \
     --latent_space_size 16 \
     --patch_size 128 \
     --nb_lines 256 \
diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs
index 582bf198..0aaf57e2 100644
--- a/jobs/inference_time_serie.pbs
+++ b/jobs/inference_time_serie.pbs
@@ -18,14 +18,15 @@ mkdir ${SCRATCH_DIR}/MMDC/inference/
 mkdir ${SCRATCH_DIR}/MMDC/inference/time_serie
 
 
-cd ${WORKING_DIR}
+cd ${SRCDIR}
 source ${MMDC_INIT}
 
 python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \
     --input_path ${DATA_DIR} \
     --export_path ${SCRATCH_DIR}/MMDC/inference/time_serie \
-    --tile_list ${DATA_DIR}/T33TUL \
-    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
+    --tile_list T33TUL \
+    --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \
+    --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \
     --days_gap 0 \
     --patch_size 128 \
     --nb_lines 256 \
-- 
GitLab


From ecea1d9fb9069ccb2484fa7b5fe3206b65729c1e Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:01:28 +0000
Subject: [PATCH 60/81] update new parameters

---
 src/bin/inference_mmdc_singledate.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
index 6d8e179f..5df740b3 100644
--- a/src/bin/inference_mmdc_singledate.py
+++ b/src/bin/inference_mmdc_singledate.py
@@ -74,6 +74,13 @@ def get_parser() -> argparse.ArgumentParser:
         required=True,
     )
 
+    arg_parser.add_argument(
+        "--model_config_path",
+        type=str,
+        help="model configuration path ",
+        required=True,
+    )
+
     arg_parser.add_argument(
         "--latent_space_size",
         type=int,
@@ -134,6 +141,8 @@ def main():
     # model
     mmdc_full_model = get_mmdc_full_model(
         checkpoint=args.model_checkpoint_path,
+        mmdc_full_config_path=args.model_config_path,
+        inference_tile=True,
     )
     scales = get_scales()
     mmdc_full_model.set_scales(scales)
-- 
GitLab
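
With this change the entry point forwards the model configuration file alongside the checkpoint. Pieced together from the hunks above (the paths are placeholders), loading a model for tile inference now looks like:

    from mmdc_singledate.inference.components.inference_components import (
        get_mmdc_full_model,
    )
    from mmdc_singledate.inference.utils import get_scales

    model = get_mmdc_full_model(
        checkpoint="/path/to/checkpoints/last.ckpt",      # placeholder
        mmdc_full_config_path="/path/to/mmdc_full.yaml",  # placeholder
        inference_tile=True,
    )
    model.set_scales(get_scales())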


From faff5329a1b41f303d07d77ebaa1781aec9153cd Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:09:46 +0000
Subject: [PATCH 61/81] add read config functions

---
 src/mmdc_singledate/inference/utils.py | 196 ++++++++++++++++++++++++-
 1 file changed, 195 insertions(+), 1 deletion(-)

diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py
index 53adf910..6b6a3c45 100644
--- a/src/mmdc_singledate/inference/utils.py
+++ b/src/mmdc_singledate/inference/utils.py
@@ -1,8 +1,28 @@
 #!/usr/bin/env python3
 
+from dataclasses import fields
+
+import pydantic
 import torch
+import yaml
+from config import Config
 
-from mmdc_singledate.datamodules.types import MMDCDataStats, MMDCShiftScales, ShiftScale
+from mmdc_singledate.datamodules.types import (
+    MMDCDataChannels,
+    MMDCDataStats,
+    MMDCShiftScales,
+    ShiftScale,
+)
+from mmdc_singledate.models.types import (  # S1S2VAELatentSpace,
+    ConvnetParams,
+    MLPConfig,
+    MMDCDataUse,
+    ModularEmbeddingConfig,
+    MultiVAEConfig,
+    MultiVAELossWeights,
+    TranslationLossWeights,
+    UnetParams,
+)
 
 
 def get_scales():
@@ -43,3 +63,177 @@ def get_scales():
     return MMDCShiftScales(
         shift_scale_s2, shift_scale_s1, shift_scale_wc, shift_scale_srtm
     )
+
+
+# DONE add pydantic validation
+def parse_from_yaml(path_to_yaml: str):
+    """
+    Read a yaml file and return a dict
+    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
+    """
+    with open(path_to_yaml) as f:
+        config = yaml.safe_load(f)
+
+    return config
+
+
+def parse_multivaeconfig_to_pydantic(cfg: dict):
+    """
+    Convert a MultiVAEConfig to a pydantic version
+    to validate the values
+        :param cfg: configuration as dict
+    """
+    # parse the field of the class to a dict
+    multivae_fields = {
+        _field.name: (_field.type, ...) for _field in fields(MultiVAEConfig)
+    }
+    # create pydantic model from a dataclass
+    pydantic_multivaefields = pydantic.create_model("MultiVAEConfig", **multivae_fields)
+    # apply the convertion
+    multivaeconfig_pydantic = pydantic_multivaefields(**cfg["model"]["config"])
+
+    return multivaeconfig_pydantic
+
+
+# DONE add pydantic validation
+def parse_from_cfg(path_to_cfg: str):
+    """
+    Read a cfg as config file validate the values
+    and return a dict
+    """
+    with open(path_to_cfg, encoding="UTF-8") as i2_config:
+        cfg = Config(i2_config)
+        cfg_dict = cfg.as_dict()
+
+    return cfg_dict
+
+
+def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool):
+    """ "
+    Get the parameters for instantiate the mmdc_full module
+    based on the target inference mode
+    """
+
+    # read the yaml hydra config
+    if inference_tile:
+        cfg = parse_from_yaml(path_to_cfg)  # "configs/model/mmdc_full.yaml")
+        cfg = parse_multivaeconfig_to_pydantic(cfg)
+
+    else:
+        cfg = parse_from_cfg(path_to_cfg)
+        cfg = parse_multivaeconfig_to_pydantic(cfg)
+
+    mmdc_full_config = MultiVAEConfig(
+        data_sizes=MMDCDataChannels(
+            sen1=cfg.data_sizes.sen1,
+            s1_angles=cfg.data_sizes.s1_angles,
+            sen2=cfg.data_sizes.sen2,
+            s2_angles=cfg.data_sizes.s2_angles,
+            srtm=cfg.data_sizes.srtm,
+            w_c=cfg.data_sizes.w_c,
+        ),
+        embeddings=ModularEmbeddingConfig(
+            s1_angles=MLPConfig(
+                hidden_layers=cfg.embeddings.s1_angles.hidden_layers,
+                out_channels=cfg.embeddings.s1_angles.out_channels,
+            ),
+            s1_srtm=UnetParams(
+                out_channels=cfg.embeddings.s1_srtm.out_channels,
+                encoder_sizes=cfg.embeddings.s1_srtm.encoder_sizes,
+                kernel_size=cfg.embeddings.s1_srtm.kernel_size,
+                tail_layers=cfg.embeddings.s1_srtm.tail_layers,
+            ),
+            s2_angles=ConvnetParams(
+                out_channels=cfg.embeddings.s2_angles.out_channels,
+                sizes=cfg.embeddings.s2_angles.sizes,
+                kernel_sizes=cfg.embeddings.s2_angles.kernel_sizes,
+            ),
+            s2_srtm=UnetParams(
+                out_channels=cfg.embeddings.s2_srtm.out_channels,
+                encoder_sizes=cfg.embeddings.s2_srtm.encoder_sizes,
+                kernel_size=cfg.embeddings.s2_srtm.kernel_size,
+                tail_layers=cfg.embeddings.s2_srtm.tail_layers,
+            ),
+            w_c=ConvnetParams(
+                out_channels=cfg.embeddings.w_c.out_channels,
+                sizes=cfg.embeddings.w_c.sizes,
+                kernel_sizes=cfg.embeddings.w_c.kernel_sizes,
+            ),
+        ),
+        s1_encoder=UnetParams(
+            out_channels=cfg.s1_encoder.out_channels,
+            encoder_sizes=cfg.s1_encoder.encoder_sizes,
+            kernel_size=cfg.s1_encoder.kernel_size,
+            tail_layers=cfg.s1_encoder.tail_layers,
+        ),
+        s2_encoder=UnetParams(
+            out_channels=cfg.s2_encoder.out_channels,
+            encoder_sizes=cfg.s2_encoder.encoder_sizes,
+            kernel_size=cfg.s2_encoder.kernel_size,
+            tail_layers=cfg.s2_encoder.tail_layers,
+        ),
+        s1_decoder=ConvnetParams(
+            out_channels=cfg.s1_decoder.out_channels,
+            sizes=cfg.s1_decoder.sizes,
+            kernel_sizes=cfg.s1_decoder.kernel_sizes,
+        ),
+        s2_decoder=ConvnetParams(
+            out_channels=cfg.s2_decoder.out_channels,
+            sizes=cfg.s2_decoder.sizes,
+            kernel_sizes=cfg.s2_decoder.kernel_sizes,
+        ),
+        s1_enc_use=MMDCDataUse(
+            s1_angles=cfg.s1_enc_use.s1_angles,
+            s2_angles=cfg.s1_enc_use.s2_angles,
+            srtm=cfg.s1_enc_use.srtm,
+            w_c=cfg.s1_enc_use.w_c,
+        ),
+        s2_enc_use=MMDCDataUse(
+            s1_angles=cfg.s2_enc_use.s1_angles,
+            s2_angles=cfg.s2_enc_use.s2_angles,
+            srtm=cfg.s2_enc_use.srtm,
+            w_c=cfg.s2_enc_use.w_c,
+        ),
+        s1_dec_use=MMDCDataUse(
+            s1_angles=cfg.s1_dec_use.s1_angles,
+            s2_angles=cfg.s1_dec_use.s2_angles,
+            srtm=cfg.s1_dec_use.srtm,
+            w_c=cfg.s1_dec_use.w_c,
+        ),
+        s2_dec_use=MMDCDataUse(
+            s1_angles=cfg.s2_dec_use.s1_angles,
+            s2_angles=cfg.s2_dec_use.s2_angles,
+            srtm=cfg.s2_dec_use.srtm,
+            w_c=cfg.s2_dec_use.w_c,
+        ),
+        loss_weights=MultiVAELossWeights(
+            sen1=TranslationLossWeights(
+                nll=cfg.loss_weights.sen1.nll,
+                lpips=cfg.loss_weights.sen1.lpips,
+                sam=cfg.loss_weights.sen1.sam,
+                gradients=cfg.loss_weights.sen1.gradients,
+            ),
+            sen2=TranslationLossWeights(
+                nll=cfg.loss_weights.sen2.nll,
+                lpips=cfg.loss_weights.sen2.lpips,
+                sam=cfg.loss_weights.sen2.sam,
+                gradients=cfg.loss_weights.sen2.gradients,
+            ),
+            s1_s2=TranslationLossWeights(
+                nll=cfg.loss_weights.s1_s2.nll,
+                lpips=cfg.loss_weights.s1_s2.lpips,
+                sam=cfg.loss_weights.s1_s2.sam,
+                gradients=cfg.loss_weights.s1_s2.gradients,
+            ),
+            s2_s1=TranslationLossWeights(
+                nll=cfg.loss_weights.s2_s1.nll,
+                lpips=cfg.loss_weights.s2_s1.lpips,
+                sam=cfg.loss_weights.s2_s1.sam,
+                gradients=cfg.loss_weights.s2_s1.gradients,
+            ),
+            forward=cfg.loss_weights.forward,
+            cross=cfg.loss_weights.cross,
+            latent=cfg.loss_weights.latent,
+        ),
+    )
+    return mmdc_full_config
-- 
GitLab
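
Note on the pattern above: parse_multivaeconfig_to_pydantic builds a pydantic
model on the fly from the fields of a dataclass, so the YAML values are
type-checked before the network is instantiated. A minimal self-contained
sketch of that pattern, with a hypothetical NetConfig dataclass and inline
YAML standing in for the real configuration:

    from dataclasses import dataclass, fields

    import pydantic
    import yaml

    @dataclass
    class NetConfig:
        out_channels: int
        encoder_sizes: list[int]

    # mirror the dataclass fields as required pydantic fields
    net_fields = {f.name: (f.type, ...) for f in fields(NetConfig)}
    PydanticNetConfig = pydantic.create_model("NetConfig", **net_fields)

    raw = yaml.safe_load("out_channels: 8\nencoder_sizes: [16, 32]")
    cfg = PydanticNetConfig(**raw)  # raises ValidationError on wrong types
    print(cfg.out_channels, cfg.encoder_sizes)

A value of the wrong type (e.g. out_channels: "eight") then fails at load time
instead of surfacing later inside the module constructors.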


From 06690df09ef0f6d70047da21d7b93e60baebab99 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:17:08 +0000
Subject: [PATCH 62/81] update config parameters

---
 src/bin/inference_mmdc_timeserie.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py
index 56626f63..4db518f2 100644
--- a/src/bin/inference_mmdc_timeserie.py
+++ b/src/bin/inference_mmdc_timeserie.py
@@ -58,6 +58,13 @@ def get_parser() -> argparse.ArgumentParser:
         required=True,
     )
 
+    arg_parser.add_argument(
+        "--model_config_path",
+        type=str,
+        help="model configuration path ",
+        required=True,
+    )
+
     arg_parser.add_argument(
         "--tile_list",
         nargs="+",
@@ -120,6 +127,7 @@ def main():
         inference_dataframe=tile_df,
         export_path=Path(args.export_path),
         model_checkpoint_path=Path(args.model_checkpoint_path),
+        model_config_path=Path(args.model_config_path),
         patch_size=args.patch_size,
         nb_lines=args.nb_lines,
     )
-- 
GitLab
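
The new flag follows the same argparse pattern as the existing required
arguments. A reduced, runnable sketch of the parsing step (the sample value
handed to parse_args is hypothetical):

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_config_path",
        type=str,
        help="model configuration path",
        required=True,
    )
    args = parser.parse_args(["--model_config_path", "configs/model/mmdc_full.yaml"])
    # main() forwards the value as a Path, as in the hunk above
    print(Path(args.model_config_path))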


From 6e323282384b0d6f992ac2474ed04ebd00523c63 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:21:05 +0000
Subject: [PATCH 63/81] update config parameters

---
 test/test_mmdc_inference.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 51be142f..fd15d093 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -195,7 +195,9 @@ def test_predict_single_date_tile(sensors):
     checkpoint_filename = "last.ckpt"  # "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
+        os.path.join(checkpoint_path, checkpoint_filename),
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json",
+        inference_tile=True,
     )
     scales = get_scales()
     mmdc_full_model.set_scales(scales)
@@ -267,7 +269,9 @@ def test_predict_single_date_tile_no_data(
     checkpoint_filename = "last.ckpt"  # "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
+        os.path.join(checkpoint_path, checkpoint_filename),
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json",
+        inference_tile=True,
     )
     # move to device
     scales = get_scales()
@@ -308,7 +312,9 @@ def test_mmdc_tile_inference():
     input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/")
     export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/")
     model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt"
-
+    model_config_path = (
+        "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml"
+    )
     input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"]
 
     # dataframe with input data
@@ -422,6 +428,7 @@ def test_mmdc_tile_inference():
         inference_dataframe=tile_df,
         export_path=export_path,
         model_checkpoint_path=model_checkpoint_path,
+        model_config_path=model_config_path,
         patch_size=128,
         nb_lines=256,
     )
-- 
GitLab
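
With the extended signature, loading the model for tile inference takes the
checkpoint, the model config and the inference mode together. A sketch of the
call pattern the updated tests rely on, with hypothetical local paths in place
of the absolute cluster paths used above:

    from mmdc_singledate.inference.components.inference_components import (
        get_mmdc_full_model,
    )
    from mmdc_singledate.inference.utils import get_scales

    model = get_mmdc_full_model(
        "checkpoints/last.ckpt",  # hypothetical checkpoint path
        mmdc_full_config_path="configs/model/mmdc_full.yaml",
        inference_tile=True,  # read the hydra YAML rather than an iota2 cfg
    )
    model.set_scales(get_scales())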


From 2fb8131c18b3edeeecb4ff3c353bb3033329e735 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:22:12 +0000
Subject: [PATCH 64/81] add requirement

---
 requirements-mmdc-sgld.txt | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/requirements-mmdc-sgld.txt b/requirements-mmdc-sgld.txt
index 02defbac..c4b218d1 100644
--- a/requirements-mmdc-sgld.txt
+++ b/requirements-mmdc-sgld.txt
@@ -1,6 +1,7 @@
 affine
 auto-walrus
 black           # code formatting
+config
 findpeaks
 flake8          # code analysis
 geopandas
@@ -18,6 +19,8 @@ numpy
 pandas
 pre-commit      # hooks for applying linters on commit
 pudb            # debugger
+pydantic        # validate configurations in iota2
+pydantic-yaml   # read yml configs for construct the networks
 pytest          # tests
 python-dotenv   # loading env variables from .env file
 python-lsp-server[all]
-- 
GitLab


From 1cfdcdf1e91de973e274aee6785ad48900545ae1 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:32:51 +0000
Subject: [PATCH 65/81] move config logic into a separate function

---
 .../components/inference_components.py        | 527 +++++++-----------
 1 file changed, 195 insertions(+), 332 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index e3a9f52b..30fb377e 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -11,22 +11,12 @@ from pathlib import Path
 from typing import Literal
 
 import torch
-import yaml
 from torch import nn
 
-from mmdc_singledate.datamodules.types import MMDCDataChannels
+# from mmdc_singledate.datamodules.types import MMDCDataChannels
+from mmdc_singledate.inference.utils import get_mmdc_full_config
 from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
-from mmdc_singledate.models.types import (
-    ConvnetParams,
-    MLPConfig,
-    MMDCDataUse,
-    ModularEmbeddingConfig,
-    MultiVAEConfig,
-    MultiVAELossWeights,
-    S1S2VAELatentSpace,
-    TranslationLossWeights,
-    UnetParams,
-)
+from mmdc_singledate.models.types import S1S2VAELatentSpace
 
 # define sensors variable for typing
 SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
@@ -39,189 +29,32 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def parse_from_yaml(path_to_yaml):
-    """
-    Read a yaml file and return a dict
-    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
-    """
-    with open(path_to_yaml) as f:
-        config = yaml.safe_load(f)
-
-    return config
-
-
 @torch.no_grad()
 def get_mmdc_full_model(
     checkpoint: str,
+    mmdc_full_config_path: str,
+    inference_tile: bool,
 ) -> nn.Module:
     """
     Instantiate the network
 
     """
-
     if not Path(checkpoint).exists():
         logger.info("No checkpoint path was given.")
         return 1
     else:
-        cfg = parse_from_yaml("configs/model/mmdc_full.yaml")
-
-        mmdc_full_config = MultiVAEConfig(
-            data_sizes=MMDCDataChannels(
-                sen1=cfg["model"]["config"]["data_sizes"]["sen1"],
-                s1_angles=cfg["model"]["config"]["data_sizes"]["s1_angles"],
-                sen2=cfg["model"]["config"]["data_sizes"]["sen2"],
-                s2_angles=cfg["model"]["config"]["data_sizes"]["s2_angles"],
-                srtm=cfg["model"]["config"]["data_sizes"]["srtm"],
-                w_c=cfg["model"]["config"]["data_sizes"]["w_c"],
-            ),
-            embeddings=ModularEmbeddingConfig(
-                s1_angles=MLPConfig(
-                    hidden_layers=cfg["model"]["config"]["embeddings"]["s1_angles"][
-                        "hidden_layers"
-                    ],
-                    out_channels=cfg["model"]["config"]["embeddings"]["s1_angles"][
-                        "out_channels"
-                    ],
-                ),
-                s1_srtm=UnetParams(
-                    out_channels=cfg["model"]["config"]["embeddings"]["s1_srtm"][
-                        "out_channels"
-                    ],
-                    encoder_sizes=cfg["model"]["config"]["embeddings"]["s1_srtm"][
-                        "encoder_sizes"
-                    ],
-                    kernel_size=cfg["model"]["config"]["embeddings"]["s1_srtm"][
-                        "kernel_size"
-                    ],
-                    tail_layers=cfg["model"]["config"]["embeddings"]["s1_srtm"][
-                        "tail_layers"
-                    ],
-                ),
-                s2_angles=ConvnetParams(
-                    out_channels=cfg["model"]["config"]["embeddings"]["s2_angles"][
-                        "out_channels"
-                    ],
-                    sizes=cfg["model"]["config"]["embeddings"]["s2_angles"]["sizes"],
-                    kernel_sizes=cfg["model"]["config"]["embeddings"]["s2_angles"][
-                        "kernel_sizes"
-                    ],
-                ),
-                s2_srtm=UnetParams(
-                    out_channels=cfg["model"]["config"]["embeddings"]["s2_srtm"][
-                        "out_channels"
-                    ],
-                    encoder_sizes=cfg["model"]["config"]["embeddings"]["s2_srtm"][
-                        "encoder_sizes"
-                    ],
-                    kernel_size=cfg["model"]["config"]["embeddings"]["s2_srtm"][
-                        "kernel_size"
-                    ],
-                    tail_layers=cfg["model"]["config"]["embeddings"]["s2_srtm"][
-                        "tail_layers"
-                    ],
-                ),
-                w_c=ConvnetParams(
-                    out_channels=cfg["model"]["config"]["embeddings"]["w_c"][
-                        "out_channels"
-                    ],
-                    sizes=cfg["model"]["config"]["embeddings"]["w_c"]["sizes"],
-                    kernel_sizes=cfg["model"]["config"]["embeddings"]["w_c"][
-                        "kernel_sizes"
-                    ],
-                ),
-            ),
-            s1_encoder=UnetParams(
-                out_channels=cfg["model"]["config"]["s1_encoder"]["out_channels"],
-                encoder_sizes=cfg["model"]["config"]["s1_encoder"]["encoder_sizes"],
-                kernel_size=cfg["model"]["config"]["s1_encoder"]["kernel_size"],
-                tail_layers=cfg["model"]["config"]["s1_encoder"]["tail_layers"],
-            ),
-            s2_encoder=UnetParams(
-                out_channels=cfg["model"]["config"]["s2_encoder"]["out_channels"],
-                encoder_sizes=cfg["model"]["config"]["s2_encoder"]["encoder_sizes"],
-                kernel_size=cfg["model"]["config"]["s2_encoder"]["kernel_size"],
-                tail_layers=cfg["model"]["config"]["s2_encoder"]["tail_layers"],
-            ),
-            s1_decoder=ConvnetParams(
-                out_channels=cfg["model"]["config"]["s1_decoder"]["out_channels"],
-                sizes=cfg["model"]["config"]["s1_decoder"]["sizes"],
-                kernel_sizes=cfg["model"]["config"]["s1_decoder"]["kernel_sizes"],
-            ),
-            s2_decoder=ConvnetParams(
-                out_channels=cfg["model"]["config"]["s2_decoder"]["out_channels"],
-                sizes=cfg["model"]["config"]["s2_decoder"]["sizes"],
-                kernel_sizes=cfg["model"]["config"]["s2_decoder"]["kernel_sizes"],
-            ),
-            s1_enc_use=MMDCDataUse(
-                s1_angles=cfg["model"]["config"]["s1_enc_use"]["s1_angles"],
-                s2_angles=cfg["model"]["config"]["s1_enc_use"]["s2_angles"],
-                srtm=cfg["model"]["config"]["s1_enc_use"]["srtm"],
-                w_c=cfg["model"]["config"]["s1_enc_use"]["w_c"],
-            ),
-            s2_enc_use=MMDCDataUse(
-                s1_angles=cfg["model"]["config"]["s2_enc_use"]["s1_angles"],
-                s2_angles=cfg["model"]["config"]["s2_enc_use"]["s2_angles"],
-                srtm=cfg["model"]["config"]["s2_enc_use"]["srtm"],
-                w_c=cfg["model"]["config"]["s2_enc_use"]["w_c"],
-            ),
-            s1_dec_use=MMDCDataUse(
-                s1_angles=cfg["model"]["config"]["s1_dec_use"]["s1_angles"],
-                s2_angles=cfg["model"]["config"]["s1_dec_use"]["s2_angles"],
-                srtm=cfg["model"]["config"]["s1_dec_use"]["srtm"],
-                w_c=cfg["model"]["config"]["s1_dec_use"]["w_c"],
-            ),
-            s2_dec_use=MMDCDataUse(
-                s1_angles=cfg["model"]["config"]["s2_dec_use"]["s1_angles"],
-                s2_angles=cfg["model"]["config"]["s2_dec_use"]["s2_angles"],
-                srtm=cfg["model"]["config"]["s2_dec_use"]["srtm"],
-                w_c=cfg["model"]["config"]["s2_dec_use"]["w_c"],
-            ),
-            loss_weights=MultiVAELossWeights(
-                sen1=TranslationLossWeights(
-                    nll=cfg["model"]["config"]["loss_weights"]["sen1"]["nll"],
-                    lpips=cfg["model"]["config"]["loss_weights"]["sen1"]["lpips"],
-                    sam=cfg["model"]["config"]["loss_weights"]["sen1"]["sam"],
-                    gradients=cfg["model"]["config"]["loss_weights"]["sen1"][
-                        "gradients"
-                    ],
-                ),
-                sen2=TranslationLossWeights(
-                    nll=cfg["model"]["config"]["loss_weights"]["sen2"]["nll"],
-                    lpips=cfg["model"]["config"]["loss_weights"]["sen2"]["lpips"],
-                    sam=cfg["model"]["config"]["loss_weights"]["sen2"]["sam"],
-                    gradients=cfg["model"]["config"]["loss_weights"]["sen2"][
-                        "gradients"
-                    ],
-                ),
-                s1_s2=TranslationLossWeights(
-                    nll=cfg["model"]["config"]["loss_weights"]["s1_s2"]["nll"],
-                    lpips=cfg["model"]["config"]["loss_weights"]["s1_s2"]["lpips"],
-                    sam=cfg["model"]["config"]["loss_weights"]["s1_s2"]["sam"],
-                    gradients=cfg["model"]["config"]["loss_weights"]["s1_s2"][
-                        "gradients"
-                    ],
-                ),
-                s2_s1=TranslationLossWeights(
-                    nll=cfg["model"]["config"]["loss_weights"]["s2_s1"]["nll"],
-                    lpips=cfg["model"]["config"]["loss_weights"]["s2_s1"]["lpips"],
-                    sam=cfg["model"]["config"]["loss_weights"]["s2_s1"]["sam"],
-                    gradients=cfg["model"]["config"]["loss_weights"]["s2_s1"][
-                        "gradients"
-                    ],
-                ),
-                forward=cfg["model"]["config"]["loss_weights"]["forward"],
-                cross=cfg["model"]["config"]["loss_weights"]["forward"],
-                latent=cfg["model"]["config"]["loss_weights"]["latent"],
-            ),
+        # read the config file
+        mmdc_full_config = get_mmdc_full_config(
+            mmdc_full_config_path, inference_tile=inference_tile
         )
-        # Init the model
+        # Init the model passing the config
         mmdc_full_model = MMDCFullModule(
             mmdc_full_config,
         )
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # print(f"device detected : {device}")
-        # print(f"nb_gpu's detected : {torch.cuda.device_count()}")
+        logger.debug(f"device detected : {device}")
+        logger.debug(f"nb_gpu's detected : {torch.cuda.device_count()}")
         # load state_dict
         lightning_checkpoint = torch.load(checkpoint, map_location=device)
 
@@ -276,7 +109,7 @@ def predict_mmdc_model(
     # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # print(device)
+    logger.debug(device)
 
     prediction = model.predict(
         s2_ref.to(device),
@@ -290,10 +123,10 @@ def predict_mmdc_model(
         srtm.to(device),
     )
 
-    # print("prediction.shape :", prediction[0].latent.latent_s2.mean.shape)
-    # print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape)
-    # print("prediction.shape :", prediction[0].latent.latent_s1.mean.shape)
-    # print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.latent_s2.mean.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.latent_s1.mean.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape)
 
     # init the output latent spaces as
     # empty dataclass
@@ -301,154 +134,184 @@ def predict_mmdc_model(
 
     # dispatch on the available sensor combination
-    match sensors:
-        # Cases
-        case ["S2L2A", "S1FULL"]:
-            # logger.info("S2 captured & S1 full captured")
-
-            latent_space.latent_s2 = prediction[0].latent.latent_s2
-            latent_space.latent_s1 = prediction[0].latent.latent_s1
-
-            latent_space_stack = torch.cat(
-                (
-                    latent_space.latent_s2.mean,
-                    latent_space.latent_s2.logvar,
-                    latent_space.latent_s1.mean,
-                    latent_space.latent_s1.logvar,
-                ),
-                1,
-            )
-
-            # print(
-            #     "latent_space.latent_s2.mean :",
-            #     latent_space.latent_s2.mean.shape,
-            # )
-            # print(
-            #     "latent_space.latent_s2.logvar :",
-            #     latent_space.latent_s2.logvar.shape,
-            # )
-            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-            # print(
-            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            # )
-
-        case ["S2L2A", "S1ASC"]:
-            # logger.info("S2 captured & S1 asc captured")
-
-            latent_space.latent_s2 = prediction[0].latent.latent_s2
-            latent_space.latent_s1 = prediction[0].latent.latent_s1
-
-            latent_space_stack = torch.cat(
-                (
-                    latent_space.latent_s2.mean,
-                    latent_space.latent_s2.logvar,
-                    latent_space.latent_s1.mean,
-                    latent_space.latent_s1.logvar,
-                ),
-                1,
-            )
-            # print(
-            #     "latent_space.latent_s2.mean :",
-            #     latent_space.latent_s2.mean.shape,
-            # )
-            # print(
-            #     "latent_space.latent_s2.logvar :",
-            #     latent_space.latent_s2.logvar.shape,
-            # )
-            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-            # print(
-            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            # )
-
-        case ["S2L2A", "S1DESC"]:
-            # logger.info("S2 captured & S1 desc captured")
-
-            latent_space.latent_s2 = prediction[0].latent.latent_s2
-            latent_space.latent_s1 = prediction[0].latent.latent_s1
-
-            latent_space_stack = torch.cat(
-                (
-                    latent_space.latent_s2.mean,
-                    latent_space.latent_s2.logvar,
-                    latent_space.latent_s1.mean,
-                    latent_space.latent_s1.logvar,
-                ),
-                1,
-            )
-            # print(
-            #     "latent_space.latent_s2.mean :",
-            #     latent_space.latent_s2.mean.shape,
-            # )
-            # print(
-            #     "latent_space.latent_s2.logvar :",
-            #     latent_space.latent_s2.logvar.shape,
-            # )
-            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-            # print(
-            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            # )
-
-        case ["S2L2A"]:
-            # logger.info("Only S2 captured")
-
-            latent_space.latent_s2 = prediction[0].latent.latent_s2
-
-            latent_space_stack = torch.cat(
-                (
-                    latent_space.latent_s2.mean,
-                    latent_space.latent_s2.logvar,
-                ),
-                1,
-            )
-            # print(
-            #     "latent_space.latent_s2.mean :",
-            #     latent_space.latent_s2.mean.shape,
-            # )
-            # print(
-            #     "latent_space.latent_s2.logvar :",
-            #     latent_space.latent_s2.logvar.shape,
-            # )
-
-        case ["S1FULL"]:
-            # logger.info("Only S1 full captured")
-
-            latent_space.latent_s1 = prediction[0].latent.latent_s1
-
-            latent_space_stack = torch.cat(
-                (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
-                1,
-            )
-            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-            # print(
-            #     "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            # )
-
-        case ["S1ASC"]:
-            # logger.info("Only S1ASC captured")
-
-            latent_space.latent_s1 = prediction[0].latent.latent_s1
-
-            latent_space_stack = torch.cat(
-                (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
-                1,
-            )
-            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-            # print(
-            # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            # )
-
-        case ["S1DESC"]:
-            # logger.info("Only S1DESC captured")
-
-            latent_space.latent_s1 = prediction[0].latent.latent_s1
-
-            latent_space_stack = torch.cat(
-                (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
-                1,
-            )
-            # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-            # print(
-            # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-            # )
+    # if/elif dispatch on the sensors list (match-case needs Python >= 3.10;
+    # see the commented sketch at the end of this module)
+    if sensors == ["S2L2A", "S1FULL"]:
+        logger.debug("S2 captured & S1 full captured")
+
+        latent_space.latent_s2 = prediction[0].latent.latent_s2
+        latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.latent_s2.mean,
+                latent_space.latent_s2.logvar,
+                latent_space.latent_s1.mean,
+                latent_space.latent_s1.logvar,
+            ),
+            1,
+        )
+
+        logger.debug(
+            "latent_space.latent_s2.mean : %s",
+            latent_space.latent_s2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s2.logvar : %s",
+            latent_space.latent_s2.logvar.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s1.mean : %s", latent_space.latent_s1.mean.shape
+        )
+        logger.debug(
+            "latent_space.latent_s1.logvar : %s", latent_space.latent_s1.logvar.shape
+        )
+
+    elif sensors == ["S2L2A", "S1ASC"]:
+        # logger.debug("S2 captured & S1 asc captured")
+
+        latent_space.latent_s2 = prediction[0].latent.latent_s2
+        latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.latent_s2.mean,
+                latent_space.latent_s2.logvar,
+                latent_space.latent_s1.mean,
+                latent_space.latent_s1.logvar,
+            ),
+            1,
+        )
+        logger.debug(
+            "latent_space.latent_s2.mean : %s",
+            latent_space.latent_s2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s2.logvar : %s",
+            latent_space.latent_s2.logvar.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s1.mean : %s", latent_space.latent_s1.mean.shape
+        )
+        logger.debug(
+            "latent_space.latent_s1.logvar : %s", latent_space.latent_s1.logvar.shape
+        )
+
+    elif sensors == ["S2L2A", "S1DESC"]:
+        logger.debug("S2 captured & S1 desc captured")
+
+        latent_space.latent_s2 = prediction[0].latent.latent_s2
+        latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.latent_s2.mean,
+                latent_space.latent_s2.logvar,
+                latent_space.latent_s1.mean,
+                latent_space.latent_s1.logvar,
+            ),
+            1,
+        )
+        logger.debug(
+            "latent_space.latent_s2.mean : %s",
+            latent_space.latent_s2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s2.logvar : %s",
+            latent_space.latent_s2.logvar.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s1.mean : %s", latent_space.latent_s1.mean.shape
+        )
+        logger.debug(
+            "latent_space.latent_s1.logvar : %s", latent_space.latent_s1.logvar.shape
+        )
+
+    elif sensors == ["S2L2A"]:
+        # logger.debug("Only S2 captured")
+
+        latent_space.latent_s2 = prediction[0].latent.latent_s2
+
+        latent_space_stack = torch.cat(
+            (
+                latent_space.latent_s2.mean,
+                latent_space.latent_s2.logvar,
+            ),
+            1,
+        )
+        logger.debug(
+            "latent_space.latent_s2.mean : %s",
+            latent_space.latent_s2.mean.shape,
+        )
+        logger.debug(
+            "latent_space.latent_s2.logvar : %s",
+            latent_space.latent_s2.logvar.shape,
+        )
+
+    elif sensors == ["S1FULL"]:
+        # logger.debug("Only S1 full captured")
+
+        latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+        latent_space_stack = torch.cat(
+            (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
+            1,
+        )
+        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+        logger.debug(
+            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+        )
+
+    elif sensors == ["S1ASC"]:
+        # logger.debug("Only S1ASC captured")
+
+        latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+        latent_space_stack = torch.cat(
+            (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
+            1,
+        )
+        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+        logger.debug(
+            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+        )
+
+    elif sensors == ["S1DESC"]:
+        # logger.debug("Only S1DESC captured")
+
+        latent_space.latent_s1 = prediction[0].latent.latent_s1
+
+        latent_space_stack = torch.cat(
+            (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
+            1,
+        )
+        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
+        logger.debug(
+            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+        )
+    else:
+        raise Exception(f"Not a valid sensors ({sensors}) are given")
 
     return latent_space_stack  # latent_space
+
+
+# Switch to a match-case once Python >= 3.10 is required,
+# matching on the sensors list directly:
+# match sensors:
+#     # Cases
+#     case ["S2L2A", "S1FULL"]:
+#         print("S1 full captured & S2 captured")
+
+#     case ["S2L2A", "S1ASC"]:
+#         print("S1 asc captured & S2 captured")
+
+#     case ["S2L2A", "S1DESC"]:
+#         print("S1 desc captured & S2 captured")
+
+#     case ["S2L2A"]:
+#         print("Only S2 captured")
+
+#     case ["S1FULL"]:
+#         print("Only S1 full captured")
+
+#     case ["S1ASC"]:
+#         print("Only S1ASC captured")
+
+#     case ["S1DESC"]:
+#         print("Only S1DESC captured")
-- 
GitLab
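
A side effect of promoting the old print calls to logger.debug: the logging
module treats extra positional arguments as values for %-style placeholders in
the message, so a call that passes a bare label plus a value has no placeholder
to receive it and fails when the record is formatted. A minimal sketch of the
failure and of the lazy-formatting fix (the shape tuple is a stand-in for a
tensor shape):

    import logging

    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)-15s %(levelname)s: %(message)s"
    )
    logger = logging.getLogger(__name__)

    shape = (4, 16, 256, 256)  # stand-in for latent_space.latent_s2.mean.shape

    # logger.debug("latent mean shape :", shape)   # no %s -> logging error
    logger.debug("latent mean shape : %s", shape)  # formatted only if emitted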


From cd5117f4fbaf60fe879d83a946bdc6ed7ec5ec0e Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:33:48 +0000
Subject: [PATCH 66/81] remove leftover commented import

---
 src/mmdc_singledate/inference/components/inference_components.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index 30fb377e..c514a8fc 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -13,7 +13,6 @@ from typing import Literal
 import torch
 from torch import nn
 
-# from mmdc_singledate.datamodules.types import MMDCDataChannels
 from mmdc_singledate.inference.utils import get_mmdc_full_config
 from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
 from mmdc_singledate.models.types import S1S2VAELatentSpace
-- 
GitLab


From 854df0bac2dc3d0cc4d5fff7a4cba3c3a9fad3f2 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:37:42 +0000
Subject: [PATCH 67/81] update config, drop match-case

---
 .../inference/mmdc_tile_inference.py          | 164 +++++++++---------
 1 file changed, 86 insertions(+), 78 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index ebc619f5..307e3007 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -120,70 +120,75 @@ def predict_single_date_tile(
         chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
         # get the windows from the chunks
         rois = [rio.windows.Window(*chunk) for chunk in chunks]
-        # logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+        logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
 
         # init the dataset
-        # logger.info("Reading S2 data")
+        logger.debug("Reading S2 data")
         s2_data = read_img_tile(
             filename=input_data.s2_filename,
             rois=rois,
             availabitity=input_data.s2_availabitity,
             sensor_func=read_s2_img_tile,
         )
-        # logger.info("Reading S1 ASC data")
+        logger.debug("Reading S1 ASC data")
         s1_asc_data = read_img_tile(
             filename=input_data.s1_asc_filename,
             rois=rois,
             availabitity=input_data.s1_asc_availability,
             sensor_func=read_s1_img_tile,
         )
-        # logger.info("Reading S1 DESC data")
+        logger.debug("Reading S1 DESC data")
         s1_desc_data = read_img_tile(
             filename=input_data.s1_desc_filename,
             rois=rois,
             availabitity=input_data.s1_desc_availability,
             sensor_func=read_s1_img_tile,
         )
-        # logger.info("Reading SRTM data")
+        logger.debug("Reading SRTM data")
         srtm_data = read_img_tile(
             filename=input_data.srtm_filename,
             rois=rois,
             availabitity=True,
             sensor_func=read_srtm_img_tile,
         )
-        # logger.info("Reading WorldClim data")
+        logger.debug("Reading WorldClim data")
         worldclim_data = concat_worldclim_components(
             wc_filename=input_data.wc_filename, rois=rois, availabitity=True
         )
-        # print("input_data.s2_availabitity=", input_data.s2_availabitity)
-        # print("s2_data=", s2_data)
-        # print("s1_asc_data=", s1_asc_data)
-        # print("s1_desc_data=", s1_desc_data)
-        # print("srtm_data=", srtm_data)
-        # print("worldclim_data=", worldclim_data)
-        # logger.info("Export Init")
+        logger.debug("input_data.s2_availabitity=", input_data.s2_availabitity)
+        logger.debug("s2_data=", s2_data)
+        logger.debug("s1_asc_data=", s1_asc_data)
+        logger.debug("s1_desc_data=", s1_desc_data)
+        logger.debug("srtm_data=", srtm_data)
+        logger.debug("worldclim_data=", worldclim_data)
+        logger.debug("Export Init")
 
         with rio.open(export_path, "w", **meta) as prediction:
             # iterate over the windows
-            for roi, s2, s1_asc, s1_desc, srtm, wc in zip(
-                rois,
-                s2_data,
-                s1_asc_data,
-                s1_desc_data,
-                srtm_data,
-                worldclim_data,
+            for roi, s2, s1_asc, s1_desc, srtm, wc in tqdm(
+                zip(
+                    rois,
+                    s2_data,
+                    s1_asc_data,
+                    s1_desc_data,
+                    srtm_data,
+                    worldclim_data,
+                ),
+                total=len(rois),
+                position=1,
             ):
-                # print(" original size : ", s2.s2_reflectances.shape)
-                # print(" original size : ", s2.s2_angles.shape)
-                # print(" original size : ", s2.s2_mask.shape)
-                # print(" original size : ", s1_asc.s1_backscatter.shape)
-                # print(" original size : ", s1_asc.s1_valmask.shape)
-                # print(" original size : ", s1_asc.s1_lia_angles.shape)
-                # print(" original size : ", s1_desc.s1_backscatter.shape)
-                # print(" original size : ", s1_desc.s1_valmask.shape)
-                # print(" original size : ", s1_desc.s1_lia_angles.shape)
-                # print(" original size : ", srtm.srtm.shape)
-                # print(" original size : ", wc.worldclim.shape)
+                # export one tile
+                logger.debug(" original size : ", s2.s2_reflectances.shape)
+                logger.debug(" original size : ", s2.s2_angles.shape)
+                logger.debug(" original size : ", s2.s2_mask.shape)
+                logger.debug(" original size : ", s1_asc.s1_backscatter.shape)
+                logger.debug(" original size : ", s1_asc.s1_valmask.shape)
+                logger.debug(" original size : ", s1_asc.s1_lia_angles.shape)
+                logger.debug(" original size : ", s1_desc.s1_backscatter.shape)
+                logger.debug(" original size : ", s1_desc.s1_valmask.shape)
+                logger.debug(" original size : ", s1_desc.s1_lia_angles.shape)
+                logger.debug(" original size : ", srtm.srtm.shape)
+                logger.debug(" original size : ", wc.worldclim.shape)
 
                 # Concat S1 Data
                 s1_backscatter = torch.cat(
@@ -199,8 +204,8 @@ def predict_single_date_tile(
                 ).shape
                 s2_s2_mask_shape = s2.s2_mask.shape  # [1, 1024, 10980]
 
-                # print("s2_mask_patch_size :", s2_mask_patch_size)
-                # print("s2_s2_mask_shape:", s2_s2_mask_shape)
+                logger.debug("s2_mask_patch_size :", s2_mask_patch_size)
+                logger.debug("s2_s2_mask_shape:", s2_s2_mask_shape)
                 # reshape the data
                 s2_refl_patch = patches.flatten2d(
                     patches.patchify(s2.s2_reflectances, process.patch_size)
@@ -231,25 +236,24 @@ def predict_single_date_tile(
                 s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat(
                     s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
                 )
-                # torch.flatten(start_dim=0, end_dim=1)
 
-                # print("s2_patches", s2_refl_patch.shape)
-                # print("s2_angles_patches", s2_ang_patch.shape)
-                # print("s2_mask_patch", s2_mask_patch.shape)
-                # print("s1_patch", s1_patch.shape)
-                # print("s1_valmask_patch", s1_valmask_patch.shape)
-                # print(
-                #     "s1_lia_asc patches",
-                #     s1_asc.s1_lia_angles.shape,
-                #     s1asc_lia_patch.shape,
-                # )
-                # print(
-                #     "s1_lia_desc patches",
-                #     s1_desc.s1_lia_angles.shape,
-                #     s1desc_lia_patch.shape,
-                # )
-                # print("srtm patches", srtm_patch.shape)
-                # print("wc patches", wc_patch.shape)
+                logger.debug("s2_patches", s2_refl_patch.shape)
+                logger.debug("s2_angles_patches", s2_ang_patch.shape)
+                logger.debug("s2_mask_patch", s2_mask_patch.shape)
+                logger.debug("s1_patch", s1_patch.shape)
+                logger.debug("s1_valmask_patch", s1_valmask_patch.shape)
+                logger.debug(
+                    "s1_lia_asc patches",
+                    s1_asc.s1_lia_angles.shape,
+                    s1asc_lia_patch.shape,
+                )
+                logger.debug(
+                    "s1_lia_desc patches",
+                    s1_desc.s1_lia_angles.shape,
+                    s1desc_lia_patch.shape,
+                )
+                logger.debug("srtm patches", srtm_patch.shape)
+                logger.debug("wc patches", wc_patch.shape)
 
                 # apply predict function
                 # should return a s1s2vaelatentspace
@@ -267,8 +271,8 @@ def predict_single_date_tile(
                     wc_patch,
                     srtm_patch,
                 )
-                # print(type(pred_vaelatentspace))
-                # print("latent space sizes : ", pred_vaelatentspace.shape)
+                logger.debug(type(pred_vaelatentspace))
+                logger.debug("latent space sizes : ", pred_vaelatentspace.shape)
 
                 # unpatchify
                 pred_vaelatentspace_unpatchify = patches.unpatchify(
@@ -279,8 +283,8 @@ def predict_single_date_tile(
                     )
                 )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
 
-                # print("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
-                # print("process.count : ", process.count)
+                logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
+                logger.debug("process.count : ", process.count)
                 # check the pred and the mask are the same
                 assert process.count == pred_vaelatentspace_unpatchify.shape[0]
 
@@ -289,7 +293,7 @@ def predict_single_date_tile(
                     window=roi,
                     indexes=process.count,
                 )
-                # logger.info(("Export tile", f"filename :{export_path}"))
+                logger.debug(("Export tile", f"filename :{export_path}"))
 
 
 def estimate_nb_latent_spaces(
@@ -326,22 +330,24 @@ def determinate_sensor_list(
         patchasc_s1_availability,
         patchdesc_s1_availability,
     ]
-    match sensors_availability:
-        case [True, True, True]:
-            sensors = ["S2L2A", "S1FULL"]
-        case [True, True, False]:
-            sensors = ["S2L2A", "S1ASC"]
-        case [True, False, True]:
-            sensors = ["S2L2A", "S1DESC"]
-        case [False, True, True]:
-            sensors = ["S1FULL"]
-        case [True, False, False]:
-            sensors = ["S2L2A"]
-        case [False, True, False]:
-            sensors = ["S1ASC"]
-        case [False, False, True]:
-            sensors = ["S1DESC"]
-
+    if sensors_availability == [True, True, True]:
+        sensors = ["S2L2A", "S1FULL"]
+    elif sensors_availability == [True, True, False]:
+        sensors = ["S2L2A", "S1ASC"]
+    elif sensors_availability == [True, False, True]:
+        sensors = ["S2L2A", "S1DESC"]
+    elif sensors_availability == [False, True, True]:
+        sensors = ["S1FULL"]
+    elif sensors_availability == [True, False, False]:
+        sensors = ["S2L2A"]
+    elif sensors_availability == [False, True, False]:
+        sensors = ["S1ASC"]
+    elif sensors_availability == [False, False, True]:
+        sensors = ["S1DESC"]
+    else:
+        raise ValueError(
+            f"Sensor availability pattern {sensors_availability} is not handled"
+        )
     return sensors
 
 
@@ -349,6 +355,7 @@ def mmdc_tile_inference(
     inference_dataframe: pd.DataFrame,
     export_path: Path,
     model_checkpoint_path: Path,
+    model_config_path: Path,
     patch_size: int = 256,
     nb_lines: int = 1024,
 ) -> None:
@@ -372,11 +379,13 @@ def mmdc_tile_inference(
 
     # TODO Uncomment for test purposes
     # inference_dataframe = inference_dataframe.head()
-    # print(inference_dataframe.shape)
+    # logger.debug(inference_dataframe.shape)
 
     # instance the model and get the pred func
     model = get_mmdc_full_model(
         checkpoint=model_checkpoint_path,
+        mmdc_full_config_path=model_config_path,
+        inference_tile=True,
     )
     scales = get_scales()
     model.set_scales(scales)
@@ -385,7 +394,7 @@ def mmdc_tile_inference(
 
     # iterate over the dates in the time serie
     for tuile, df_row in tqdm(
-        inference_dataframe.iterrows(), total=inference_dataframe.shape[0]
+        inference_dataframe.iterrows(), total=inference_dataframe.shape[0], position=0
     ):
         # estimate nb latent spaces
         latent_space_size = estimate_nb_latent_spaces(
@@ -401,9 +410,8 @@ def mmdc_tile_inference(
         # export path
         date_ = df_row["date"].strftime("%Y-%m-%d")
         export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif"
-        #
-        # print(tuile, df_row)
-        # print(latent_space_size)
+        logger.debug("tile : %s ; row : %s", tuile, df_row)
+        logger.debug(latent_space_size)
         # define process
         mmdc_process = MMDCProcess(
             count=latent_space_size,
@@ -431,4 +439,4 @@ def mmdc_tile_inference(
             process=mmdc_process,
         )
 
-        # logger.info("Export Finish !!!")
+        logger.debug("Export Finish !!!")
-- 
GitLab
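
The chunk loop above wraps a zip of per-window readers in tqdm: zip exposes no
length, so the bar needs an explicit total, and position=1 keeps the per-chunk
bar on its own line below the per-date bar (position=0) in mmdc_tile_inference.
A stand-alone sketch with stand-in values:

    from tqdm import tqdm

    rois = list(range(4))  # stand-ins for the rasterio windows
    s2_data = (f"s2_{i}" for i in rois)  # stand-in for a lazy reader

    # total=len(rois) supplies the length zip cannot;
    # position=1 nests this bar under an outer bar at position=0
    for roi, s2 in tqdm(zip(rois, s2_data), total=len(rois), position=1):
        pass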


From 49ad2680225b503e14c3fd2f03402295c1f46a87 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 22 Mar 2023 15:32:37 +0000
Subject: [PATCH 68/81] delete iota2 functions

---
 .../inference/components/inference_utils.py   | 47 -------------------
 1 file changed, 47 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
index 1d79f1f0..ef85dbdb 100644
--- a/src/mmdc_singledate/inference/components/inference_utils.py
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -243,27 +243,6 @@ def read_s2_img_tile(
     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
 
 
-# def read_s2_img_iota2(
-#     s2_tensor: torch.Tensor,
-#     *args: Any,
-#     **kwargs: Any,
-# ) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
-#     """
-#     read a patch of sentinel 2 data
-#     contruct the masks and yield the patch
-#     of data
-#     """
-#     # copy the masks
-#     # cloud_mask =
-#     # sat_mask =
-#     # edge_mask =
-#     mask = torch.concat((cloud_mask, sat_mask, edge_mask), axis=0)
-#     angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...])
-#     image_s2 = s2_tensor[:10, ...]
-
-#     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
-
-
 # TODO get out the s1_lia computation from this function
 def read_s1_img_tile(
     s1_tensor: torch.Tensor,
@@ -302,32 +281,6 @@ def read_s1_img_tile(
     )
 
 
-# def read_s1_img_iota2(
-#     s1_tensor: torch.Tensor,
-#     *args: Any,
-#     **kwargs: Any,
-# ) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-#     """
-#     read a patch of s1 construct the datasets associated
-#     and yield the result
-#     """
-#     # get the vv,vh, vv/vh bands
-#     img_s1 = torch.cat(
-#         (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
-#     )
-#     s1_lia = get_s1_acquisition_angles(s1_filename)
-#     # compute validity mask
-#     s1_valmask = torch.ones(img_s1.shape)
-#     # compute edge mask
-#     s1_backscatter = apply_log_to_s1(img_s1)
-
-#     return S1Components(
-#         s1_backscatter=s1_backscatter,
-#         s1_valmask=s1_valmask,
-#         s1_lia_angles=s1_lia,
-#     )
-
-
 def read_srtm_img_tile(
     srtm_tensor: torch.Tensor,
     *args: Any,
-- 
GitLab


From 0c6d60660a596d903a343cc692760546efb76373 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 22 Mar 2023 18:24:27 +0000
Subject: [PATCH 69/81] WIP parsing data and dates

---
 .../inference/mmdc_iota2_inference.py         | 475 +++++++++++++++---
 1 file changed, 409 insertions(+), 66 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
index 4a40ff74..5bd385e4 100644
--- a/src/mmdc_singledate/inference/mmdc_iota2_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
@@ -6,29 +6,58 @@ Infereces API with Iota-2
 Inspired by:
 https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
 """
+# imports
+import logging
+import os
 
+import numpy as np
+import pandas as pd
 import torch
+from dateutil import parser as date_parser
+from iota2.configuration_files import read_config_file as rcf
 from torchutils import patches
 
-# from mmdc_singledate.datamodules.components.datamodule_utils import (
-#     apply_log_to_s1,
-#     join_even_odd_s2_angles,
-#     srtm_height_aspect,
-# )
+# from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+# from mmdc_singledate.models.types import S1S2VAELatentSpace
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1,
+    srtm_height_aspect,
+)
 
+from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
+from .components.inference_utils import (
+    GeoTiffDataset,
+    S1Components,
+    S2Components,
+    SRTMComponents,
+    WorldClimComponents,
+    read_s1_img_tile,
+    read_s2_img_tile,
+    read_srtm_img_tile,
+    read_worldclim_img_tile,
+)
+from .utils import get_mmdc_full_config, get_scales
 
-def apply_mmdc_full_mode(
+# from tqdm import tqdm
+# from pathlib import Path
+
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def read_s2_img_iota2(
     self,
-    checkpoint_path: str,
-    checkpoint_epoch: int = 100,
-    patch_size: int = 256,
-):
+) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Apply MMDC with Iota-2.py
+    read a patch of sentinel 2 data
+    construct the masks and yield the patch
+    of data
     """
-    # How manage the S1 acquisition dates for construct the mask validity
-    # in time
-
     # DONE Get the data in the same order as
     # sentinel2.Sentinel2.GROUP_10M +
     # sensorsio.sentinel2.GROUP_20M
@@ -46,9 +75,37 @@ def apply_mmdc_full_mode(
         self.get_interpolated_Sentinel2_B12(),
     ]
 
-    # TODO Masks contruction for S2
-    list_s2_mask = self.get_Sentinel2_binary_masks()
+    # DONE Masks construction for S2
+    list_s2_mask = [
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+    ]
+
+    # TODO Add S2 Angles
+    list_bands_s2_angles = [
+        self.get_Sentinel2_angles(),
+    ]
+
+    # extend the band list with the masks and the angles
+    list_bands_s2.extend(list_s2_mask)
+    list_bands_s2.extend(list_bands_s2_angles)
+
+    bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
 
+    return read_s2_img_tile(bands_s2)
+
+
+def read_s1_img_iota2(
+    self,
+) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of S1 data, construct the associated
+    datasets and yield the result
+    """
     # TODO Manage S1 ASC and S1 DESC ?
     list_bands_s1 = [
         self.get_interpolated_Sentinel1_ASC_vh(),
@@ -61,58 +118,344 @@ def apply_mmdc_full_mode(
         / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
     ]
 
-    # TODO Read SRTM data
-    # TODO Read Worldclim Data
-
-    with torch.no_grad():
-        # Permute dimensions to fit patchify
-        # Shape before permutation is C,H,W,D. D being the dates
-        # Shape after permutation is D,C,H,W
-        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
-        bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
-        bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
-        bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
-
-        # TODO Masks contruction for S1
-        # build_s1_image_and_masks function for datamodules components datamodule components ?
-
-        # Replace nan by 0
-        bands_s1 = bands_s1.nan_to_num()
-        # These dimensions are useful for unpatchify
-        # Keep the dimensions of height and width found in chunk
-        band_h, band_w = bands_s2.shape[-2:]
-        # Keep the number of patches of patch_size
-        # in rows and cols found in chunk
-        h, w = patches.patchify(
-            bands_s2[0, ...], patch_size=patch_size, margin=patch_margin
-        ).shape[:2]
-
-    # TODO Apply patchify
-
-    # Get the model
+    # Shape before permutation is C,H,W,D; after it is D,C,H,W
+    bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+    bands_s1 = apply_log_to_s1(bands_s1)
+
+    # get the vv,vh, vv/vh bands
+    # img_s1 = torch.cat(
+    # (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
+    # )
+
+    list_bands_s1_lia = [
+        self.get_interpolated_,
+        self.get_interpolated_,
+        self.get_interpolated_,
+    ]
+
+    # s1_lia = get_s1_acquisition_angles(s1_filename)
+    # compute validity mask
+    s1_valmask = torch.ones(bands_s1.shape)
+    # bands_s1 was already log-scaled above
+    s1_backscatter = bands_s1
+
+    return read_s1_img_tile()
+
+
+def read_srtm_img_iota2(
+    self,
+) -> SRTMComponents:  # [torch.Tensor]:
+    """
+    read srtm patch
+    """
+    # read the patch as tensor
+
+    list_bands_srtm = [
+        self.get_interpolated_userFeatures_aspect(),
+        self.get_interpolated_,
+        self.get_interpolated_,
+    ]
+
+    srtm_tensor = torch.Tensor(list_bands_srtm).permute(-1, 0, 1, 2)
+
+    return read_srtm_img_tile(srtm_height_aspect(srtm_tensor))
+
+
+def read_worldclim_img_iota2(
+    self,
+) -> WorldClimComponents:  # [torch.tensor]:
+    """
+    read worldclim subset
+    """
+    wc_tensor = torch.Tensor(self.get_interpolated_).permute(-1, 0, 1, 2)
+    return read_worldclim_img_tile(wc_tensor)
+
+
+def predict_iota2(
+    self,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+
+    # retrieve the metadata
+    # meta = input_data.metadata.copy()
+    # get nb outputs
+    # meta.update({"count": process.count})
+    # calculate rois
+    # chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+    # get the windows from the chunks
+    # rois = [rio.windows.Window(*chunk) for chunk in chunks]
+    # logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+    # init the dataset
+    logger.debug("Reading S2 data")
+    s2_data = read_s2_img_iota2(self)
+    logger.debug("Reading S1 ASC data")
+    s1_asc_data = read_s1_img_iota2(self)
+    logger.debug("Reading S1 DESC data")
+    s1_desc_data = read_s1_img_iota2(self)
+    logger.debug("Reading SRTM data")
+    srtm_data = read_srtm_img_iota2(self)
+    logger.debug("Reading WorldClim data")
+    worldclim_data = read_worldclim_img_iota2(self)
+
+    # Concat S1 Data
+    s1_backscatter = torch.cat(
+        (s1_asc_data.s1_backscatter, s1_desc_data.s1_backscatter), 0
+    )
+    # Multiply the masks
+    s1_valmask = torch.mul(s1_asc_data.s1_valmask, s1_desc_data.s1_valmask)
+    # keep the sizes to recover the original size
+    s2_mask_patch_size = patches.patchify(s2_data.s2_mask, process.patch_size).shape
+    s2_s2_mask_shape = s2_data.s2_mask.shape  # [1, 1024, 10980]
+
+    # reshape the data
+    s2_refl_patch = patches.flatten2d(
+        patches.patchify(s2_data.s2_reflectances, process.patch_size)
+    )
+    s2_ang_patch = patches.flatten2d(
+        patches.patchify(s2_data.s2_angles, process.patch_size)
+    )
+    s2_mask_patch = patches.flatten2d(
+        patches.patchify(s2_data.s2_mask, process.patch_size)
+    )
+    s1_patch = patches.flatten2d(patches.patchify(s1_backscatter, process.patch_size))
+    s1_valmask_patch = patches.flatten2d(
+        patches.patchify(s1_valmask, process.patch_size)
+    )
+    srtm_patch = patches.flatten2d(patches.patchify(srtm_data.srtm, process.patch_size))
+    wc_patch = patches.flatten2d(
+        patches.patchify(worldclim_data.worldclim, process.patch_size)
+    )
+    # Expand the angles to fit the sizes
+    s1asc_lia_patch = s1_asc_data.s1_lia_angles.unsqueeze(0).repeat(
+        s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+    )
+    s1desc_lia_patch = s1_desc_data.s1_lia_angles.unsqueeze(0).repeat(
+        s2_mask_patch_size[0] * s2_mask_patch_size[1], 1
+    )
+
+    logger.debug("s2_patches", s2_refl_patch.shape)
+    logger.debug("s2_angles_patches", s2_ang_patch.shape)
+    logger.debug("s2_mask_patch", s2_mask_patch.shape)
+    logger.debug("s1_patch", s1_patch.shape)
+    logger.debug("s1_valmask_patch", s1_valmask_patch.shape)
+    logger.debug(
+        "s1_lia_asc patches",
+        s1_asc_data.s1_lia_angles.shape,
+        s1asc_lia_patch.shape,
+    )
+    logger.debug(
+        "s1_lia_desc patches",
+        s1_desc_data.s1_lia_angles.shape,
+        s1desc_lia_patch.shape,
+    )
+    logger.debug("srtm patches", srtm_patch.shape)
+    logger.debug("wc patches", wc_patch.shape)
+
+    # apply the predict function; should return an S1S2VAELatentSpace object
+    pred_vaelatentspace = process.process(
+        process.model,
+        sensors,
+        s2_refl_patch,
+        s2_mask_patch,
+        s2_ang_patch,
+        s1_patch,
+        s1_valmask_patch,
+        s1asc_lia_patch,
+        s1desc_lia_patch,
+        wc_patch,
+        srtm_patch,
+    )
+    logger.debug(type(pred_vaelatentspace))
+    logger.debug("latent space sizes : ", pred_vaelatentspace.shape)
+
+    # unpatchify
+    pred_vaelatentspace_unpatchify = patches.unpatchify(
+        patches.unflatten2d(
+            pred_vaelatentspace,
+            s2_mask_patch_size[0],
+            s2_mask_patch_size[1],
+        )
+    )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]]
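+    # Illustrative sketch (assuming the torchutils.patches API used above):
+    # patchify cuts a [C,H,W] tensor into a [P,Q,C,p,p] grid, flatten2d makes
+    # it a [P*Q,C,p,p] batch for inference, and unflatten2d/unpatchify invert
+    # both steps before cropping any padding back to the original H and W:
+    #   t = torch.zeros(4, 500, 700)
+    #   g = patches.patchify(t, 256)                  # [P, Q, 4, 256, 256]
+    #   b = patches.flatten2d(g)                      # [P*Q, 4, 256, 256]
+    #   r = patches.unpatchify(patches.unflatten2d(b, g.shape[0], g.shape[1]))
+    #   r = r[:, :500, :700]                          # back to [4, 500, 700]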
+
+    logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape)
+    logger.debug("process.count : ", process.count)
+    # check that the prediction channel count matches the expected band count
+    assert process.count == pred_vaelatentspace_unpatchify.shape[0]
+
+    prediction_array = np.array(pred_vaelatentspace_unpatchify[0, ...])
+
+    logger.debug(("Export tile", f"filename :{export_path}"))
+
+    return prediction_array, labels
+
+
+def acquisition_dataframe(
+    self,
+) -> pd.DataFrame:
+    """
+    Construct a DataFrame with the acquisition dates
+    for the predictions
+    """
+    # use Iota2 API
+    # acquisition_dates = self.get_raw_dates()
+
+    # Get the acquisition dates
+    acquisition_dates = {
+        "Sentinel1_DES_vv": ["20151231"],
+        "Sentinel1_DES_vh": ["20151231"],
+        "Sentinel1_ASC_vv": ["20170518", "20170519"],
+        "Sentinel1_ASC_vh": ["20170518", "20170519"],
+        "Sentinel2": ["20151130", "20151203"],
+    }
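+    # NOTE: the dict above is a hardcoded stand-in; the real dates should come
+    # from the Iota2 API call commented out above (self.get_raw_dates()).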
+
+    # get the acquisition dates for each sensor-mode
+    sentinel2_acquisition_dates = pd.DataFrame(
+        {"Sentinel2": acquisition_dates["Sentinel2"]}, dtype=str
+    )
+    sentinel1_asc_acquisition_dates = pd.DataFrame(
+        {
+            "Sentinel1_ASC_vv": acquisition_dates["Sentinel1_ASC_vv"],
+            "Sentinel1_ASC_vh": acquisition_dates["Sentinel1_ASC_vh"],
+        },
+        dtype=str,
+    )
+    sentinel1_des_acquisition_dates = pd.DataFrame(
+        {
+            "Sentinel1_DES_vv": acquisition_dates["Sentinel1_DES_vv"],
+            "Sentinel1_DES_vh": acquisition_dates["Sentinel1_DES_vh"],
+        },
+        dtype=str,
+    )
+
+    # convert acquisition dates from str to datetime object
+    sentinel2_acquisition_dates = sentinel2_acquisition_dates["Sentinel2"].apply(
+        lambda x: date_parser.parse(x)
+    )
+    sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[
+        "Sentinel1_ASC_vv"
+    ].apply(lambda x: date_parser.parse(x))
+    sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[
+        "Sentinel1_ASC_vh"
+    ].apply(lambda x: date_parser.parse(x))
+    sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[
+        "Sentinel1_DES_vv"
+    ].apply(lambda x: date_parser.parse(x))
+    sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[
+        "Sentinel1_DES_vh"
+    ].apply(lambda x: date_parser.parse(x))
+
+    # merge the dataframes
+
+    # concat_acquisitions = pd.concat()
+
+    # return
+
+
+def apply_mmdc_full_mode(
+    # checkpoint_path: str, # read from the cfg file
+    # checkpoint_epoch: int = 100, # read from the cfg file
+    # patch_size: int = 256, # read from the cfg file
+):
+    """
+    Entry Point
+    """
+    # Set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # read the configuration file
+    config = rcf("configs/iota2/externalfeatures_with_userfeatures.cfg")
+    # return a list with a tuple of functions and arguments
+    mmdc_parameters = config.getParam("external_features", "functions")
+    padding_size_x = config.getParam("python_data_managing", "padding_size_x")
+    padding_size_y = config.getParam("python_data_managing", "padding_size_y")
+
+    checkpoint_path = mmdc_parameters[0][1]  # TODO: fix this parameter handling
+
+    # Get the acquisition dates
+    # dict_dates  = {
+    # 'Sentinel1_DES_vv': ['20151231'],
+    # 'Sentinel1_DES_vh': ['20151231'],
+    # 'Sentinel1_ASC_vv': ['20170518', '20170519'],
+    # 'Sentinel1_ASC_vh': ['20170518', '20170519'],
+    # 'Sentinel2': ['20151130', '20151203']}
+
+    # model
     mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
-    )
-
-    # apply model
-    # latent_variable = mmdc_full_model.predict(
-    #     s2_x,
-    #     s2_m,
-    #     s2_angles_x,
-    #     s1_x,
-    #     s1_vm,
-    #     s1_asc_angles_x,
-    #     s1_desc_angles_x,
-    #     worldclim_x,
-    #     srtm_x,
+        checkpoint=checkpoint_path,
+        mmdc_full_config_path=model_config_path,  # TODO: read from the cfg file
+        inference_tile=False,
+    )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    mmdc_full_model.to(device)
+
+    # process class
+    mmdc_process = MMDCProcess(
+        count=latent_space_size,
+        nb_lines=nb_lines,
+        patch_size=patch_size,
+        process=predict_mmdc_model,
+        model=mmdc_full_model,
+    )
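+    # NOTE: MMDCProcess bundles what the prediction step expects: the number
+    # of output bands (count), the chunking parameters (nb_lines, patch_size)
+    # and the inference callback (process) applied to each patch batch with
+    # the loaded model. latent_space_size, nb_lines and patch_size are still
+    # to be read from the cfg file (see the commented parameters above).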
+    # create export filename (export_path and sensors still to be read from the cfg)
+    export_path = f"{export_path}/latent_singledate_{'_'.join(sensors)}.tif"
+
+    # # How to manage the S1 acquisition dates to construct the validity mask
+    # # in time
+
+    # # TODO Read SRTM data
+    # # TODO Read Worldclim Data
+
+    # # with torch.no_grad():
+    # # Permute dimensions to fit patchify
+    # # Shape before permutation is C,H,W,D. D being the dates
+    # # Shape after permutation is D,C,H,W
+    # bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
+    # bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+    # bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+    # # TODO Masks construction for S1
+    # # build_s1_image_and_masks function from the datamodule components?
+
+    # # Replace nan by 0
+    # bands_s1 = bands_s1.nan_to_num()
+    # # These dimensions are useful for unpatchify
+    # # Keep the dimensions of height and width found in chunk
+    # band_h, band_w = bands_s2.shape[-2:]
+    # # Keep the number of patches of patch_size
+    # # in rows and cols found in chunk
+    # h, w = patches.patchify(
+    #     bands_s2[0, ...], patch_size=patch_size, margin=patch_margin
+    # ).shape[:2]
+
+    # # TODO Apply patchify
+
+    # # Get the model
+    # mmdc_full_model = get_mmdc_full_model(
+    #     os.path.join(checkpoint_path, checkpoint_filename)
     # )
 
-    # TODO unpatchify
+    # # apply model
+    # # latent_variable = mmdc_full_model.predict(
+    # #     s2_x,
+    # #     s2_m,
+    # #     s2_angles_x,
+    # #     s1_x,
+    # #     s1_vm,
+    # #     s1_asc_angles_x,
+    # #     s1_desc_angles_x,
+    # #     worldclim_x,
+    # #     srtm_x,
+    # # )
+
+    # # TODO unpatchify
 
-    # TODO crop padding
+    # # TODO crop padding
 
-    # TODO Depending of the configuration return a unique latent variable
-    # or a stack of laten variables
-    # coef = pass
-    # labels = pass
-    # return coef, labels
+    # # TODO Depending on the configuration, return a unique latent variable
+    # # or a stack of latent variables
+    # # coef = pass
+    # # labels = pass
+    # # return coef, labels
-- 
GitLab


From 4ae4317bf4be86d5be3e74147fc8f707c8e35899 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 28 Mar 2023 15:20:38 +0000
Subject: [PATCH 70/81] add jobs

---
 jobs/iota2_clasif_mmdc_full.pbs  | 120 +++++++++++++++++++++++++++++++
 jobs/iota2_tile_aux_features.pbs |  18 +++++
 2 files changed, 138 insertions(+)
 create mode 100644 jobs/iota2_clasif_mmdc_full.pbs
 create mode 100644 jobs/iota2_tile_aux_features.pbs

diff --git a/jobs/iota2_clasif_mmdc_full.pbs b/jobs/iota2_clasif_mmdc_full.pbs
new file mode 100644
index 00000000..e3ec1e97
--- /dev/null
+++ b/jobs/iota2_clasif_mmdc_full.pbs
@@ -0,0 +1,120 @@
+#!/bin/bash
+#PBS -N iota2-class
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=10:00:00
+
+# make sure no modules are loaded
+conda deactivate
+module purge
+
+
+ulimit -u 5000
+
+# export ROOT_PATH="$(dirname $( cd "$(dirname "${BASH_SOURCE[0]}")"; pwd -P ))"
+# export OTB_APPLICATION_PATH=$OTB_APPLICATION_PATH:/work/scratch/$USER/virtualenv/mmdc-iota2/lib/otb/applications//
+# export PATH=$PATH:$ROOT_PATH/bin:$OTB_APPLICATION_PATH
+# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$ROOT_PATH/lib:$OTB_APPLICATION_PATH #
+#
+# load modules
+module load conda/4.12.0
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+# Apply Iota2
+Iota2.py \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_mmdc_full.cfg \
+    -scheduler_type PBS \
+    -config_ressources /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_ressources.cfg
+
+# other useful flags:
+#     -scheduler_type debug
+#     -only_summary
+#     -nb_parallel_tasks 1
+
+
+# Functions available on the Iota2 external-features object (dir() output):
+#['__class__',
+#'__delattr__',
+#'__dict__',
+#'__dir__',
+#'__doc__',
+#'__eq__',
+#'__format__',
+#'__ge__',
+#'__getattribute__',
+#'__gt__',
+#'__hash__',
+#'__init__',
+#'__init_subclass__',
+#'__le__',
+#'__lt__',
+#'__module__',
+#'__ne__',
+#'__new__',
+#'__reduce__',
+#'__reduce_ex__',
+#'__repr__',
+#'__setattr__',
+#'__sizeof__',
+#'__str__',
+#'__subclasshook__',
+#'__weakref__',
+#'all_dates',
+#'allow_nans',
+#'binary_masks',
+#'concat_mode',
+#'dim_ts',
+#'enabled_gap',
+#'exogeneous_data_array',
+#'exogeneous_data_name',
+#'external_functions',
+#'fill_missing_dates',
+#'func_param',
+#'get_Sentinel1_ASC_vh_binary_masks',
+#'get_Sentinel1_ASC_vv_binary_masks',
+#'get_Sentinel1_DES_vh_binary_masks',
+#'get_Sentinel1_DES_vv_binary_masks',
+#'get_Sentinel2_binary_masks',
+#'get_filled_masks',
+#'get_filled_stack',
+#'get_interpolated_Sentinel1_ASC_vh',
+#'get_interpolated_Sentinel1_ASC_vv',
+#'get_interpolated_Sentinel1_DES_vh',
+#'get_interpolated_Sentinel1_DES_vv',
+#'get_interpolated_Sentinel2_B11',
+#'get_interpolated_Sentinel2_B12',
+#'get_interpolated_Sentinel2_B2',
+#'get_interpolated_Sentinel2_B3',
+#'get_interpolated_Sentinel2_B4',
+#'get_interpolated_Sentinel2_B5',
+#'get_interpolated_Sentinel2_B6',
+#'get_interpolated_Sentinel2_B7',
+#'get_interpolated_Sentinel2_B8',
+#'get_interpolated_Sentinel2_B8A',
+#'get_interpolated_Sentinel2_Brightness',
+#'get_interpolated_Sentinel2_NDVI',
+#'get_interpolated_Sentinel2_NDWI',
+#'get_interpolated_dates',
+#'get_interpolated_userFeatures_aspect',
+#'get_raw_Sentinel1_ASC_vh',
+#'get_raw_Sentinel1_ASC_vv',
+#'get_raw_Sentinel1_DES_vh', 'get_raw_Sentinel1_DES_vv',
+#'get_raw_Sentinel2_B11',
+#'get_raw_Sentinel2_B12',
+#'get_raw_Sentinel2_B2',
+#'get_raw_Sentinel2_B3',
+#'get_raw_Sentinel2_B4',
+#'get_raw_Sentinel2_B5',
+#'get_raw_Sentinel2_B6',
+#'get_raw_Sentinel2_B7',
+#'get_raw_Sentinel2_B8',
+#'get_raw_Sentinel2_B8A',
+#'get_raw_dates',
+#'get_raw_userFeatures_aspect',
+#'interpolated_data',
+#'interpolated_dates',
+#'missing_masks_values',
+#'missing_refl_values',
+#'out_data',
+#'process',
+#'raw_data',
+#'raw_dates',
+#'test_user_feature_with_fake_data']
diff --git a/jobs/iota2_tile_aux_features.pbs b/jobs/iota2_tile_aux_features.pbs
new file mode 100644
index 00000000..2507bf36
--- /dev/null
+++ b/jobs/iota2_tile_aux_features.pbs
@@ -0,0 +1,18 @@
+#!/bin/bash
+#PBS -N iota2-tiler
+#PBS -l select=1:ncpus=8:mem=120G
+#PBS -l walltime=4:00:00
+
+# make sure no modules are loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_tiler.cfg \
+    -nb_parallel_tasks 1 \
+    -only_summary
-- 
GitLab


From d9f063ccabfbc409d7e9f96c397754d09f45da42 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 28 Mar 2023 15:21:26 +0000
Subject: [PATCH 71/81] downgrade libsvm

---
 create-iota2-env.sh | 95 ++-------------------------------------------
 1 file changed, 4 insertions(+), 91 deletions(-)

diff --git a/create-iota2-env.sh b/create-iota2-env.sh
index 06e9ba1f..5f3b99ee 100644
--- a/create-iota2-env.sh
+++ b/create-iota2-env.sh
@@ -45,6 +45,9 @@ conda install mamba
 which python
 python --version
 
+# install libsvm
+mamba install -c conda-forge libsvm=325
+
 # Install iota2
 #mamba install iota2_develop=257da617 -c iota2
 mamba install iota2 -c iota2
@@ -67,7 +70,7 @@ ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
 
 cd ~/src/MMDC/mmdc-singledate
 
-conda install -c conda-forge pydantic
+mamba install -c conda-forge pydantic
 
 # install missing dependencies
 pip install -r requirements-mmdc-iota2.txt
@@ -87,93 +90,3 @@ pip install -e .[testing]
 
 # End
 conda deactivate
-
-# #!/usr/bin/env bash
-
-# export python_version="3.9"
-# export name="mmdc-iota2"
-# if ! [ -z "$1" ]
-# then
-#     export name=$1
-# fi
-
-# source ~/set_proxy_iota2.sh
-# if [ -z "$https_proxy" ]
-# then
-#     echo "Please set https_proxy environment variable before running this script"
-#     exit 1
-# fi
-
-# export target=/work/scratch/$USER/virtualenv/$name
-
-# if ! [ -z "$2" ]
-# then
-#     export target="$2/$name"
-# fi
-
-# echo "Installing $name in $target ..."
-
-# if [ -d "$target" ]; then
-#    echo "Cleaning previous conda env"
-#    rm -rf $target
-# fi
-
-# # Create blank virtualenv
-# module purge
-# module load conda
-# module load gcc
-# conda activate
-# conda create --yes --prefix $target python==${python_version} pip
-
-# # Enter virtualenv
-# conda deactivate
-# conda activate $target
-
-# # install mamba
-# conda install mamba
-
-# which python
-# python --version
-
-# # Install iota2
-# #mamba install iota2_develop=257da617 -c iota2
-# mamba install iota2 -c iota2
-
-# clone the lastest version of iota2
-# rm -rf iota2_thirdparties/iota2
-# git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2
-# cd iota2_thirdparties/iota2
-# git checkout issue#600/tile_exogenous_features
-# git pull origin issue#600/tile_exogenous_features
-# cd ../../
-# pwd
-
-# # create a backup
-# cd  /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
-# mv iota2 iota2_backup
-
-# # create a symbolic link
-# ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
-
-# cd ~/src/MMDC/mmdc-singledate
-
-# conda install -c conda-forge pydantic
-
-# # install missing dependancies
-# pip install -r requirements-mmdc-iota2.txt #--upgrade --no-cache-dir
-
-# # Install sensorsio
-# rm -rf iota2_thirdparties/sensorsio
-# git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio
-# pip install iota2_thirdparties/sensorsio
-
-# # Install torchutils
-# rm -rf iota2_thirdparties/torchutils
-# git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils
-# pip install iota2_thirdparties/torchutils
-
-# # Install the current project in edit mode
-# pip install -e .[testing]
-
-# # End
-# conda deactivate
-- 
GitLab


From 56f9992dcb6b1603cfcd89aacbe00c88cdb51245 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 28 Mar 2023 15:21:58 +0000
Subject: [PATCH 72/81] add time series availability

---
 .../inference/mmdc_iota2_inference.py         | 143 ++++++++++++------
 1 file changed, 97 insertions(+), 46 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
index 5bd385e4..742f2892 100644
--- a/src/mmdc_singledate/inference/mmdc_iota2_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
@@ -18,25 +18,25 @@ from torchutils import patches
 
 # from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
 # from mmdc_singledate.models.types import S1S2VAELatentSpace
-from mmdc_singledate.datamodules.components.datamodule_utils import (
-    apply_log_to_s1,
-    srtm_height_aspect,
-)
-
-from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
-from .components.inference_utils import (
-    GeoTiffDataset,
-    S1Components,
-    S2Components,
-    SRTMComponents,
-    WorldClimComponents,
-    get_mmdc_full_config,
-    read_s1_img_tile,
-    read_s2_img_tile,
-    read_srtm_img_tile,
-    read_worldclim_img_tile,
-)
-from .utils import get_scales
+# from mmdc_singledate.datamodules.components.datamodule_utils import (
+#     apply_log_to_s1,
+#     srtm_height_aspect,
+# )
+
+# from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
+# from .components.inference_utils import (
+#     GeoTiffDataset,
+#     S1Components,
+#     S2Components,
+#     SRTMComponents,
+#     WorldClimComponents,
+#     get_mmdc_full_config,
+#     read_s1_img_tile,
+#     read_s2_img_tile,
+#     read_srtm_img_tile,
+#     read_worldclim_img_tile,
+# )
+# from .utils import get_scales
 
 # from tqdm import tqdm
 # from pathlib import Path
@@ -51,8 +51,9 @@ logger = logging.getLogger(__name__)
 
 
 def read_s2_img_iota2(
-    self,
-) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
+    self,
+):
+# ) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     read a patch of sentinel 2 data
     construct the masks and yield the patch
@@ -83,14 +84,19 @@ def read_s2_img_iota2(
         self.get_Sentinel2_binary_masks(),
     ]
 
-    # TODO Add S2 Angles
+    # DONE Add S2 Angles
     list_bands_s2_angles = [
-        self.get_Sentinel2_angles(),
+        self.get_interpolated_userFeatures_ZENITH(),
+        self.get_interpolated_userFeatures_AZIMUTH(),
+        self.get_interpolated_userFeatures_EVEN_ZENITH(),
+        self.get_interpolated_userFeatures_ODD_ZENITH(),
+        self.get_interpolated_userFeatures_EVEN_AZIMUTH(),
+        self.get_interpolated_userFeatures_ODD_AZIMUTH(),
     ]
 
     # extend the list
     list_bands_s2.extend(list_s2_mask)
-    list_bands_s2.extend(list_s2_mask)
+    list_bands_s2.extend(list_bands_s2_angles)
 
     bands_s2 = torch.Tensor(
         list_bands_s2, dtype=torch.torch.float32, device=device
@@ -101,7 +107,8 @@ def read_s2_img_iota2(
 
 def read_s1_img_iota2(
     self,
-) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+):
+# ) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     read a patch of S1, construct the associated datasets,
     and yield the result
@@ -143,30 +150,39 @@ def read_s1_img_iota2(
 
 def read_srtm_img_iota2(
     self,
-) -> SRTMComponents:  # [torch.Tensor]:
+):
+# ) -> SRTMComponents:
     """
-    read srtm patch
+    read srtm patch from Iota2 API
     """
-    # read the patch as tensor
+    # read the patch as numpy array
 
     list_bands_srtm = [
+        self.get_interpolated_userFeatures_elevation(),
+        self.get_interpolated_userFeatures_slope(),
         self.get_interpolated_userFeatures_aspect(),
-        self.get_interpolated_,
-        self.get_interpolated_,
     ]
-
+    # convert to tensor and permute the bands
     srtm_tensor = torch.Tensor(list_bands_srtm).permute(-1, 0, 1, 2)
 
-    return read_srtm_img_tile(srtm_height_aspect(srtm_tensor))
+    return read_srtm_img_tile(
+        srtm_height_aspect(srtm_tensor)
+    )
 
 
 def read_worldclim_img_iota2(
     self,
-) -> WorldClimComponents:  # [torch.tensor]:
+):
+# ) -> WorldClimComponents:  # [torch.tensor]:
     """
-    read worldclim subset
+    read worldclim patch from Iota2 API
     """
-    wc_tensor = torch.Tensor(self.get_interpolated_).permute(-1, 0, 1, 2)
+    wc_tensor = torch.Tensor(
+        self.get_interpolated_userFeatures_worldclim()
+    ).permute(-1, 0, 1, 2)
     return read_worldclim_img_tile(wc_tensor)
 
 
@@ -330,26 +346,34 @@ def acquisition_dataframe(
     )
 
     # convert acquisition dates from str to datetime object
-    sentinel2_acquisition_dates = sentinel2_acquisition_dates["Sentinel2"].apply(
+    sentinel2_acquisition_dates["Sentinel2"] = sentinel2_acquisition_dates["Sentinel2"].apply(
         lambda x: date_parser.parse(x)
     )
-    sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[
+    sentinel1_asc_acquisition_dates["Sentinel1_ASC_vv"] = sentinel1_asc_acquisition_dates[
         "Sentinel1_ASC_vv"
-    ].apply(lambda x: date_parser.parse(x))
-    sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[
+        ].apply(lambda x: date_parser.parse(x))
+    sentinel1_asc_acquisition_dates["Sentinel1_ASC_vh"] = sentinel1_asc_acquisition_dates[
         "Sentinel1_ASC_vh"
-    ].apply(lambda x: date_parser.parse(x))
-    sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[
+        ].apply(lambda x: date_parser.parse(x))
+    sentinel1_des_acquisition_dates["Sentinel1_DES_vv"] = sentinel1_des_acquisition_dates[
         "Sentinel1_DES_vv"
-    ].apply(lambda x: date_parser.parse(x))
-    sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[
+        ].apply(lambda x: date_parser.parse(x))
+    sentinel1_des_acquisition_dates["Sentinel1_DES_vh"] = sentinel1_des_acquisition_dates[
         "Sentinel1_DES_vh"
-    ].apply(lambda x: date_parser.parse(x))
+        ].apply(lambda x: date_parser.parse(x))
 
     # merge the dataframes
+    sensors_join = pd.concat([
+        sentinel2_acquisition_dates,
+        sentinel1_asc_acquisition_dates,
+        sentinel1_des_acquisition_dates,
+        ]
+    )
 
-    # concat_acquisitions = pd.concat()
-
+    # get acquisition availability (True when the sensor has data for the date)
+    sensors_join["s2_availability"] = sensors_join["Sentinel2"].notna()
+    sensors_join["s1_availability_asc"] = sensors_join["Sentinel1_ASC_vv"].notna()
+    sensors_join["s1_availability_desc"] = sensors_join["Sentinel1_DES_vv"].notna()
     # return
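+    # Assumed completion: return the joined table so callers can check which
+    # sensors have data for each date, e.g. (hypothetical usage)
+    # sensors_join[["Sentinel2", "s2_availability"]]
+    return sensors_join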
 
 
@@ -459,3 +483,30 @@ def apply_mmdc_full_mode(
     # # coef = pass
     # # labels = pass
     # # return coef, labels
+
+
+def test_mnt(self, arg1):
+    """
+    Debug helper: probe the Iota2 external-features API.
+    (The old "Soil Composition Index" docstring was left over from the
+    soi.py example referenced in the Iota2 config.)
+    """
+    print(dir(self))
+    print("raw dates :", self.get_raw_dates())
+    print("interpolated dates :", self.get_interpolated_dates())
+    print("interpolated dates :", self.interpolated_dates)
+    print("exogeneous_data_name :", self.exogeneous_data_name)
+    print("fill_missing_dates :", self.fill_missing_dates)
+
+    mnt = self.get_Sentinel2_binary_masks()  # self.get_raw_userFeatures_mnt_band_0() # self.get_raw_userFeatures_mnt2()
+    # mnt =  self.get_raw_userFeatures_mnt2_band_0()
+    # mnt =  self.get_raw_userFeatures_slope()
+    print(f"type mnt : {type(mnt)}")
+    print(f"size mnt : {mnt.shape}")
+    labels = [f"mnt_{cpt}" for cpt in range(mnt.shape[-1])]
+    print(f"labels mnt : {labels}")
+
+    # s2_b5 =  self.get_interpolated_Sentinel2_B5()
+    # print(f"type s2_b5 : {type(s2_b5)}") # 4, 3, 2  # 26,50,2
+    # print(f"size s2_b5 : {s2_b5.shape}") # 4, 3, 1  # 26,50,2
+
+    return mnt, labels
-- 
GitLab


From a160400bb6dc0618780b2024b22bfde7b1745294 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 11:39:14 +0000
Subject: [PATCH 73/81] config ressources

---
 configs/iota2/config_ressources.cfg | 434 ++++++++++++++++++++++++++++
 1 file changed, 434 insertions(+)
 create mode 100644 configs/iota2/config_ressources.cfg

diff --git a/configs/iota2/config_ressources.cfg b/configs/iota2/config_ressources.cfg
new file mode 100644
index 00000000..9ef37b3e
--- /dev/null
+++ b/configs/iota2/config_ressources.cfg
@@ -0,0 +1,434 @@
+################################################################################
+#               Configuration file used to set HPC resource requests
+
+# All steps are mandatory
+#
+#                  TEMPLATE
+# step_name:{
+#           name : "IOTA2_dir" #no space characters
+#           nb_cpu : 1
+#           ram : "5gb"
+#           walltime : "00:10:00"#HH:MM:SS
+#           process_min : 1 #not mandatory (default = 1)
+#           process_max : 9 #not mandatory (default = number of tasks)
+#           }
+#
+################################################################################
+
+iota2_chain:{
+             name : "IOTA2"
+             nb_cpu : 1
+             ram : "5gb"
+             walltime : "10:00:00"
+            }
+
+iota2_dir:{
+           name : "IOTA2_dir"
+           nb_cpu : 1
+           ram : "5gb"
+           walltime : "00:10:00"
+           process_min : 1
+          }
+
+preprocess_data : {
+                   name:"preprocess_data"
+                   nb_cpu:40
+                   ram:"180gb"
+                   walltime:"10:00:00"
+                   process_min : 1
+                  }
+
+get_common_mask : {
+                   name:"CommonMasks"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+
+coregistration : {
+                   name:"Coregistration"
+                   nb_cpu:4
+                   ram:"20gb"
+                   walltime:"05:00:00"
+                   process_min : 1
+                  }
+
+get_pixValidity : {
+                   name:"NbView"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                   }
+
+envelope : {
+            name:"Envelope"
+            nb_cpu:3
+            ram:"15gb"
+            walltime:"00:10:00"
+            process_min : 1
+            }
+
+regionShape : {
+               name:"regionShape"
+               nb_cpu:3
+               ram:"15gb"
+               walltime:"00:10:00"
+               process_min : 1
+               }
+
+splitRegions : {
+                name:"splitRegions"
+                nb_cpu:3
+                ram:"15gb"
+                walltime:"00:10:00"
+                process_min : 1
+                }
+
+extract_data_region_tiles : {
+                             name:"extract_data_region_tiles"
+                             nb_cpu:3
+                             ram:"15gb"
+                             walltime:"00:10:00"
+                             process_min : 1
+                             }
+
+split_learning_val : {
+                      name:"split_learning_val"
+                      nb_cpu:3
+                      ram:"15gb"
+                      walltime:"00:10:00"
+                      process_min : 1
+                      }
+
+split_learning_val_sub : {
+                          name:"split_learning_val_sub"
+                          nb_cpu:3
+                          ram:"15gb"
+                          walltime:"00:10:00"
+                          process_min : 1
+                          }
+split_samples : {
+                 name:"split_samples"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+samplesFormatting : {
+                     name:"samples_formatting"
+                     nb_cpu:1
+                     ram:"5gb"
+                     walltime:"00:10:00"
+                     process_min : 1
+                     }
+
+samplesMerge: {
+                name:"samples_merges"
+                nb_cpu:2
+                ram:"10gb"
+                walltime:"00:10:00"
+                process_min : 1
+               }
+
+samplesStatistics : {
+                   name:"samples_stats"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+
+samplesSelection : {
+                   name:"samples_selection"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+samplesselection_tiles : {
+                   name:"samplesSelection_tiles"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+samplesManagement : {
+                   name:"samples_management"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                  }
+
+vectorSampler : {
+                 name:"vectorSampler"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"10:00:00"
+                 process_min : 1
+                 }
+
+dimensionalityReduction : {
+               name:"dimensionalityReduction"
+               nb_cpu:3
+               nb_MPI_process:2
+               ram:"15gb"
+               nb_chunk:1
+               walltime:"00:20:00"}
+
+samplesAugmentation : {
+                       name:"samplesAugmentation"
+                       nb_cpu:3
+                       ram:"15gb"
+                       walltime:"00:10:00"
+                       }
+
+mergeSample : {
+               name:"mergeSample"
+               nb_cpu:3
+               ram:"15gb"
+               walltime:"00:10:00"
+               process_min : 1
+               }
+
+stats_by_models : {
+                   name:"stats_by_models"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                   }
+
+training : {
+            name:"training"
+            nb_cpu:40
+            ram:"180gb"
+            walltime:"10:00:00"
+            process_min : 1
+            }
+
+#training regression
+TrainRegressionPytorch : {
+                 name:"training"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"30:00:00"
+                 }
+
+PredictRegressionScikit: {
+                 name:"prediction"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"30:00:00"
+                 }
+
+#training regression
+PredictRegressionPytorch : {
+                 name:"choice"
+                 nb_cpu:40
+                 ram:"180gb"
+                 walltime:"30:00:00"
+                 }
+
+
+cmdClassifications : {
+                      name:"cmdClassifications"
+                      nb_cpu:10
+                      ram:"60gb"
+                      walltime:"10:00:00"
+                      process_min : 1
+                      }
+
+classifications : {
+                   name:"classifications"
+                   nb_cpu:10
+                   ram:"60gb"
+                   walltime:"10:00:00"
+                   process_min : 1
+                   }
+
+classifShaping : {
+                  name:"classifShaping"
+                  nb_cpu:10
+                  ram:"60gb"
+                  walltime:"10:00:00"
+                  process_min : 1
+                  }
+
+gen_confusionMatrix : {
+                       name:"genCmdconfusionMatrix"
+                       nb_cpu:3
+                       ram:"15gb"
+                       walltime:"00:10:00"
+                       process_min : 1
+                       }
+
+confusionMatrix : {
+                   name:"confusionMatrix"
+                   nb_cpu:3
+                   ram:"15gb"
+                   walltime:"00:10:00"
+                   process_min : 1
+                   }
+
+#regression
+GenerateRegressionMetrics : {
+                 name:"classification"
+                 nb_cpu:10
+                 ram:"50gb"
+                 walltime:"40:00:00"
+                 }
+
+
+fusion : {
+          name:"fusion"
+          nb_cpu:3
+          ram:"15gb"
+          walltime:"00:10:00"
+          process_min : 1
+          }
+
+noData : {
+          name:"noData"
+          nb_cpu:3
+          ram:"15gb"
+          walltime:"00:10:00"
+          process_min : 1
+          }
+
+statsReport : {
+               name:"statsReport"
+               nb_cpu:3
+               ram:"15gb"
+               walltime:"00:10:00"
+               process_min : 1
+               }
+
+confusionMatrixFusion : {
+                         name:"confusionMatrixFusion"
+                         nb_cpu:3
+                         ram:"15gb"
+                         walltime:"00:10:00"
+                         process_min : 1
+                         }
+
+merge_final_classifications : {
+                           name:"merge_final_classifications"
+                           nb_cpu:1
+                           ram:"5gb"
+                           walltime:"00:10:00"
+                          }
+
+reportGen : {
+             name:"reportGeneration"
+             nb_cpu:3
+             ram:"15gb"
+             walltime:"00:10:00"
+             process_min : 1
+             }
+
+mergeOutStats : {
+                 name:"mergeOutStats"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+SAROptConfusionMatrix : {
+                 name:"SAROptConfusionMatrix"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+SAROptConfusionMatrixFusion : {
+                 name:"SAROptConfusionMatrixFusion"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+
+SAROptFusion : {
+                 name:"SAROptFusion"
+                 nb_cpu:3
+                 ram:"15gb"
+                 walltime:"00:10:00"
+                 process_min : 1
+                 }
+clump : {
+         name:"clump"
+         nb_cpu:1
+         ram:"5gb"
+         walltime:"00:10:00"
+         process_min:1
+         }
+
+grid : {
+         name:"grid"
+         nb_cpu:1
+         ram:"5gb"
+         walltime:"00:10:00"
+         process_min:1
+         }
+
+vectorisation : {
+                 name:"vectorisation"
+                 nb_cpu:1
+                 ram:"5gb"
+                 walltime:"00:10:00"
+                 process_min:1
+                 }
+
+crownsearch : {
+                name:"crownsearch"
+                nb_cpu:1
+                ram:"5gb"
+                walltime:"00:10:00"
+                process_min:1
+                }
+
+crownbuild : {
+                name:"crownbuild"
+                nb_cpu:1
+                ram:"5gb"
+                walltime:"00:10:00"
+                process_min:1}
+
+statistics : {
+                name:"statistics"
+                nb_cpu:1
+                ram:"5gb"
+                walltime:"00:10:00"
+                process_min:1
+                }
+
+join : {
+         name:"join"
+         nb_cpu:1
+         ram:"5gb"
+         walltime:"00:10:00"
+         process_min:1
+         }
+
+tiler : {
+         name:"tiler"
+         nb_cpu:8
+         ram:"120gb"
+         walltime:"01:00:00"
+         process_min:1
+
+}
+features_tiler : {
+         name:"features_tiler"
+         nb_cpu:8
+         ram:"120gb"
+         walltime:"01:00:00"
+         process_min:1
+
+}
-- 
GitLab


From 4f39a09b7c8e3f2f626959a3f9f0356032ff4fcc Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 11:40:41 +0000
Subject: [PATCH 74/81] add configs

---
 configs/iota2/config_sar.cfg      | 19 ++++++++
 configs/iota2/i2_tiler.cfg        | 22 +++++++++
 configs/iota2/iota2_mmdc_full.cfg | 76 +++++++++++++++++++++++++++++++
 3 files changed, 117 insertions(+)
 create mode 100644 configs/iota2/config_sar.cfg
 create mode 100644 configs/iota2/i2_tiler.cfg
 create mode 100644 configs/iota2/iota2_mmdc_full.cfg

diff --git a/configs/iota2/config_sar.cfg b/configs/iota2/config_sar.cfg
new file mode 100644
index 00000000..07512be0
--- /dev/null
+++ b/configs/iota2/config_sar.cfg
@@ -0,0 +1,19 @@
+[Paths]
+output = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_preprocess
+s1images = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake/
+srtm = /datalake/static_aux/MNT/SRTM_30_hgt
+geoidfile = /home/uz/vinascj/src/MMDC/mmdc-data/Geoid/egm96.grd
+
+[Processing]
+tiles = 31TCJ
+tilesshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/mgrs_tiles.shp
+referencesfolder = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/
+srtmshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/srtm.shp
+rasterpattern = T31TCJ.tif
+gapFilling_interpolation = spline
+temporalresolution = 10
+borderthreshold = 1e-3
+ramperprocess = 5000
+
+[Filtering]
+window_radius = 2
diff --git a/configs/iota2/i2_tiler.cfg b/configs/iota2/i2_tiler.cfg
new file mode 100644
index 00000000..193d757d
--- /dev/null
+++ b/configs/iota2/i2_tiler.cfg
@@ -0,0 +1,22 @@
+chain :
+{
+  output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data'
+  spatial_resolution : 10
+  first_step : 'tiler'
+  last_step : 'tiler'
+
+  proj : 'EPSG:2154'
+  rasters_grid_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/T31TCJ'
+  tile_field : 'Name'
+  list_tile : 'T31TCJ'
+  srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+  worldclim_path:'/datalake/static_aux/worldclim-2.0'
+  s2_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s2_data_datalake'
+  # add S1 data
+  s1_dir:"/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake"
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/configs/iota2/iota2_mmdc_full.cfg b/configs/iota2/iota2_mmdc_full.cfg
new file mode 100644
index 00000000..fb651a84
--- /dev/null
+++ b/configs/iota2/iota2_mmdc_full.cfg
@@ -0,0 +1,76 @@
+# Entry Point Iota2 Config
+chain :
+{
+  # Input and Output data
+  output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/iota2_results'
+  remove_output_path : False
+  check_inputs : False
+  list_tile : 'T31TCJ'
+  data_field : 'code'
+  s2_path : '/work/scratch/vinascj/MMDC/iota2/i2_training_data/s2_data_full'
+  ground_truth : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/reference_data.shp'
+
+  # Add S1 data
+  s1_path : "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_sar.cfg"
+
+  # Classification related parameters
+  spatial_resolution : 10
+  color_table : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/colorFile.txt'
+  nomenclature_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/nomenclature.txt'
+  first_step : 'init'
+  last_step : 'validation'
+  proj : 'EPSG:2154'
+
+  # Path to the User data
+  user_feat_path: '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data' #/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data/T31TCJ
+}
+
+# Parametrize the classifier
+arg_train :
+{
+  random_seed : 0 # Set the random seed to split the ground truth
+  runs : 1
+  classifier : 'sharkrf'
+  otb_classifier_options : {'classifier.sharkrf.nodesize': 25}
+  sample_selection :
+  {
+    sampler : 'random'
+    strategy : 'percent'
+    'strategy.percent.p' : 0.1
+  }
+}
+
+# Use User Provided data
+userFeat:
+{
+  arbo:"/*"
+  patterns:"elevation,aspect,slope,worldclim,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending,"
+}
+
+# Parameters for the user provided data
+python_data_managing :
+{
+    padding_size_x : 1
+    padding_size_y : 1
+    chunk_size_mode:"split_number"
+    number_of_chunks:4
+    data_mode_access: "both"
+}
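+# NOTE: padding_size_x / padding_size_y above are the values read back in
+# apply_mmdc_full_mode via config.getParam("python_data_managing", ...).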
+
+# Process the User and Iota2 data
+external_features :
+{
+  # module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
+  module:"/home/uz/vinascj/src/MMDC/mmdc-singledate/src/mmdc_singledate/inference/mmdc_iota2_inference.py"
+  functions : [['test_mnt', {"arg1":1}]]
+  concat_mode : True
+  external_features_flag:True
+}
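+# NOTE (assumed contract, inferred from test_mnt above): each entry in
+# `functions` names a module-level function f(self, **kwargs) returning an
+# (array, labels) pair; with concat_mode the results are appended to the
+# feature stack.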
+
+# Parametrize nodes
+task_retry_limits :
+{
+  allowed_retry : 0
+  maximum_ram : 180.0
+  maximum_cpu : 40
+}
-- 
GitLab


From 3d75244afe891beb613a30e3547bfd55547c9095 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 12:32:36 +0000
Subject: [PATCH 75/81] fix import

---
 .../inference/components/inference_components.py                | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index c514a8fc..6331e26f 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -14,7 +14,7 @@ import torch
 from torch import nn
 
 from mmdc_singledate.inference.utils import get_mmdc_full_config
-from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+from mmdc_singledate.models.torch.full import MMDCFullModule
 from mmdc_singledate.models.types import S1S2VAELatentSpace
 
 # define sensors variable for typing
-- 
GitLab


From 8b7aa4065072748c94fc3b250372968b1b74e855 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 12:47:36 +0000
Subject: [PATCH 76/81] update config file

---
 test/test_mmdc_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index fd15d093..0f5085a9 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -270,7 +270,7 @@ def test_predict_single_date_tile_no_data(
 
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename),
-        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json",
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/test/full_module_config.json",
         inference_tile=True,
     )
     # move to device
-- 
GitLab


From 5671620dacf2a476efcf09ce13267591dc3cc0dc Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 12:48:43 +0000
Subject: [PATCH 77/81] update config file

---
 test/test_mmdc_inference.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 0f5085a9..6d8a5f4f 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -18,16 +18,17 @@ from mmdc_singledate.inference.components.inference_utils import (
     GeoTiffDataset,
     generate_chunks,
 )
-from mmdc_singledate.inference.mmdc_tile_inference import (
+from mmdc_singledate.inference.mmdc_tile_inference import (  # inference_dataframe,
     MMDCProcess,
-    inference_dataframe,
     mmdc_tile_inference,
     predict_single_date_tile,
 )
-from mmdc_singledate.models.types import VAELatentSpace
 
 from .utils import get_scales, setup_data
 
+# from mmdc_singledate.models.types import VAELatentSpace
+
+
 # usefull variables
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
 s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
-- 
GitLab


From 7ee2a504503af9e97cda31b89ae60da511819efc Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 13:42:51 +0000
Subject: [PATCH 78/81] WIP minor fix

---
 Makefile                     |   3 +-
 test/full_module_config.json | 197 +++++++++++++++++++++++++++++++++++
 2 files changed, 199 insertions(+), 1 deletion(-)
 create mode 100644 test/full_module_config.json

diff --git a/Makefile b/Makefile
index bc64a93b..d112a0ae 100644
--- a/Makefile
+++ b/Makefile
@@ -74,10 +74,11 @@ test_no_PIL:
 test_mask_loss:
 	$(CONDA) && pytest -vv test/test_masked_losses.py
 
-PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py"
 test_inference:
 	$(CONDA) && pytest -vv test/test_mmdc_inference.py
 
+PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py"
+
 #.PHONY:
 pylint:
 	$(CONDA) && pylint --ignore=$(PYLINT_IGNORED) src/
diff --git a/test/full_module_config.json b/test/full_module_config.json
new file mode 100644
index 00000000..db4ed172
--- /dev/null
+++ b/test/full_module_config.json
@@ -0,0 +1,197 @@
+{
+  "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullLitModule",
+  "lr": 0.0001,
+  "model": {
+    "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullModule",
+    "config": {
+      "_target_": "mmdc_singledate.models.types.MultiVAEConfig",
+      "data_sizes": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataChannels",
+        "sen1": 6,
+        "s1_angles": 6,
+        "sen2": 10,
+        "s2_angles": 6,
+        "srtm": 4,
+        "w_c": 103
+      },
+      "embeddings": {
+        "_target_": "mmdc_singledate.models.types.ModularEmbeddingConfig",
+        "s1_angles": {
+          "_target_": "mmdc_singledate.models.types.MLPConfig",
+          "hidden_layers": [
+            9,
+            12,
+            9
+          ],
+          "out_channels": 3
+        },
+        "s1_srtm": {
+          "_target_": "mmdc_singledate.models.types.UnetParams",
+          "out_channels": 4,
+          "encoder_sizes": [
+            32,
+            64,
+            128
+          ],
+          "kernel_size": 3,
+          "tail_layers": 3
+        },
+        "s2_angles": {
+          "_target_": "mmdc_singledate.models.types.ConvnetParams",
+          "out_channels": 3,
+          "sizes": [
+            16,
+            8
+          ],
+          "kernel_sizes": [
+            1,
+            1,
+            1
+          ]
+        },
+        "s2_srtm": {
+          "_target_": "mmdc_singledate.models.types.UnetParams",
+          "out_channels": 4,
+          "encoder_sizes": [
+            32,
+            64,
+            128
+          ],
+          "kernel_size": 3,
+          "tail_layers": 3
+        },
+        "w_c": {
+          "_target_": "mmdc_singledate.models.types.ConvnetParams",
+          "out_channels": 3,
+          "sizes": [
+            64,
+            32,
+            16
+          ],
+          "kernel_sizes": [
+            1,
+            1,
+            1,
+            1
+          ]
+        }
+      },
+      "s1_encoder": {
+        "_target_": "mmdc_singledate.models.types.UnetParams",
+        "out_channels": 4,
+        "encoder_sizes": [
+          64,
+          128,
+          256,
+          512
+        ],
+        "kernel_size": 3,
+        "tail_layers": 3
+      },
+      "s2_encoder": {
+        "_target_": "mmdc_singledate.models.types.UnetParams",
+        "out_channels": 4,
+        "encoder_sizes": [
+          64,
+          128,
+          256,
+          512
+        ],
+        "kernel_size": 3,
+        "tail_layers": 3
+      },
+      "s1_decoder": {
+        "_target_": "mmdc_singledate.models.types.ConvnetParams",
+        "out_channels": 0,
+        "sizes": [
+          64,
+          32,
+          16
+        ],
+        "kernel_sizes": [
+          3,
+          3,
+          3,
+          3
+        ]
+      },
+      "s2_decoder": {
+        "_target_": "mmdc_singledate.models.types.ConvnetParams",
+        "out_channels": 0,
+        "sizes": [
+          64,
+          32,
+          16
+        ],
+        "kernel_sizes": [
+          3,
+          3,
+          3,
+          3
+        ]
+      },
+      "s1_enc_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "s2_enc_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "s1_dec_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "s2_dec_use": {
+        "_target_": "mmdc_singledate.models.types.MMDCDataUse",
+        "s1_angles": true,
+        "s2_angles": true,
+        "srtm": true,
+        "w_c": true
+      },
+      "loss_weights": {
+        "_target_": "mmdc_singledate.models.types.MultiVAELossWeights",
+        "sen1": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "sen2": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "s1_s2": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "s2_s1": {
+          "_target_": "mmdc_singledate.models.types.TranslationLossWeights",
+          "nll": 1,
+          "lpips": 1,
+          "sam": 1,
+          "gradients": 1
+        },
+        "forward": 1,
+        "cross": 1,
+        "latent": 1
+      }
+    }
+  }
+}
-- 
GitLab


From d412ff10112cce56c020f949e69639b166c3ded3 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 13:47:42 +0000
Subject: [PATCH 79/81] WIP minor fix

---
 src/mmdc_singledate/inference/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py
index 6b6a3c45..4c24cb63 100644
--- a/src/mmdc_singledate/inference/utils.py
+++ b/src/mmdc_singledate/inference/utils.py
@@ -25,6 +25,7 @@ from mmdc_singledate.models.types import (  # S1S2VAELatentSpace,
 )
 
 
+# TODO Complexify the scales used for the inference
 def get_scales():
     """
     Read Scales for Inference
@@ -108,7 +109,7 @@ def parse_from_cfg(path_to_cfg: str):
     return cfg_dict
 
 
-def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool):
+def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool) -> MultiVAEConfig:
     """ "
     Get the parameters for instantiate the mmdc_full module
     based on the target inference mode
-- 
GitLab


From 93b63ae2ba60d2c6f2abf7fb41420831ae297327 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 13:50:48 +0000
Subject: [PATCH 80/81] WIP minor fix

---
 .../components/inference_components.py        | 132 ++++++++----------
 1 file changed, 60 insertions(+), 72 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index 6331e26f..42cdc604 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -122,14 +122,14 @@ def predict_mmdc_model(
         srtm.to(device),
     )
 
-    logger.debug("prediction.shape :", prediction[0].latent.latent_s2.mean.shape)
-    logger.debug("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape)
-    logger.debug("prediction.shape :", prediction[0].latent.latent_s1.mean.shape)
-    logger.debug("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.sen2.mean.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.sen2.logvar.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.sen1.mean.shape)
+    logger.debug("prediction.shape :", prediction[0].latent.sen1.logvar.shape)
 
     # init the output latent spaces as
     # empty dataclass
-    latent_space = S1S2VAELatentSpace(latent_s1=None, latent_s2=None)
+    latent_space = S1S2VAELatentSpace(sen1=None, sen2=None)
 
     # TODO: refactor this if/elif chain with a match statement
@@ -138,150 +138,138 @@ def predict_mmdc_model(
     if sensors == ["S2L2A", "S1FULL"]:
         logger.debug("S2 captured & S1 full captured")
 
-        latent_space.latent_s2 = prediction[0].latent.latent_s2
-        latent_space.latent_s1 = prediction[0].latent.latent_s1
+        latent_space.sen2 = prediction[0].latent.sen2
+        latent_space.sen1 = prediction[0].latent.sen1
 
         latent_space_stack = torch.cat(
             (
-                latent_space.latent_s2.mean,
-                latent_space.latent_s2.logvar,
-                latent_space.latent_s1.mean,
-                latent_space.latent_s1.logvar,
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+                latent_space.sen1.mean,
+                latent_space.sen1.logvar,
             ),
             1,
         )
 
         logger.debug(
-            "latent_space.latent_s2.mean :",
-            latent_space.latent_s2.mean.shape,
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
         )
         logger.debug(
-            "latent_space.latent_s2.logvar :",
-            latent_space.latent_s2.logvar.shape,
-        )
-        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-        logger.debug(
-            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
         )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
 
     elif sensors == ["S2L2A", "S1ASC"]:
         # logger.debug("S2 captured & S1 asc captured")
 
-        latent_space.latent_s2 = prediction[0].latent.latent_s2
-        latent_space.latent_s1 = prediction[0].latent.latent_s1
+        latent_space.sen2 = prediction[0].latent.sen2
+        latent_space.sen1 = prediction[0].latent.sen1
 
         latent_space_stack = torch.cat(
             (
-                latent_space.latent_s2.mean,
-                latent_space.latent_s2.logvar,
-                latent_space.latent_s1.mean,
-                latent_space.latent_s1.logvar,
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+                latent_space.sen1.mean,
+                latent_space.sen1.logvar,
             ),
             1,
         )
         logger.debug(
-            "latent_space.latent_s2.mean :",
-            latent_space.latent_s2.mean.shape,
-        )
-        logger.debug(
-            "latent_space.latent_s2.logvar :",
-            latent_space.latent_s2.logvar.shape,
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
         )
-        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
         logger.debug(
-            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
         )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
 
     elif sensors == ["S2L2A", "S1DESC"]:
         logger.debug("S2 captured & S1 desc captured")
 
-        latent_space.latent_s2 = prediction[0].latent.latent_s2
-        latent_space.latent_s1 = prediction[0].latent.latent_s1
+        latent_space.sen2 = prediction[0].latent.sen2
+        latent_space.sen1 = prediction[0].latent.sen1
 
         latent_space_stack = torch.cat(
             (
-                latent_space.latent_s2.mean,
-                latent_space.latent_s2.logvar,
-                latent_space.latent_s1.mean,
-                latent_space.latent_s1.logvar,
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
+                latent_space.sen1.mean,
+                latent_space.sen1.logvar,
             ),
             1,
         )
         logger.debug(
-            "latent_space.latent_s2.mean :",
-            latent_space.latent_s2.mean.shape,
-        )
-        logger.debug(
-            "latent_space.latent_s2.logvar :",
-            latent_space.latent_s2.logvar.shape,
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
         )
-        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
         logger.debug(
-            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
         )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
 
     elif sensors == ["S2L2A"]:
         # logger.debug("Only S2 captured")
 
-        latent_space.latent_s2 = prediction[0].latent.latent_s2
+        latent_space.sen2 = prediction[0].latent.sen2
 
         latent_space_stack = torch.cat(
             (
-                latent_space.latent_s2.mean,
-                latent_space.latent_s2.logvar,
+                latent_space.sen2.mean,
+                latent_space.sen2.logvar,
             ),
             1,
         )
         logger.debug(
-            "latent_space.latent_s2.mean :",
-            latent_space.latent_s2.mean.shape,
+            "latent_space.sen2.mean :",
+            latent_space.sen2.mean.shape,
         )
         logger.debug(
-            "latent_space.latent_s2.logvar :",
-            latent_space.latent_s2.logvar.shape,
+            "latent_space.sen2.logvar :",
+            latent_space.sen2.logvar.shape,
         )
 
     elif sensors == ["S1FULL"]:
         # logger.debug("Only S1 full captured")
 
-        latent_space.latent_s1 = prediction[0].latent.latent_s1
+        latent_space.sen1 = prediction[0].latent.sen1
 
         latent_space_stack = torch.cat(
-            (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
+            (latent_space.sen1.mean, latent_space.sen1.logvar),
             1,
         )
-        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-        logger.debug(
-            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
 
     elif sensors == ["S1ASC"]:
         # logger.debug("Only S1ASC captured")
 
-        latent_space.latent_s1 = prediction[0].latent.latent_s1
+        latent_space.sen1 = prediction[0].latent.sen1
 
         latent_space_stack = torch.cat(
-            (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
+            (latent_space.sen1.mean, latent_space.sen1.logvar),
             1,
         )
-        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-        logger.debug(
-            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
 
     elif sensors == ["S1DESC"]:
         # logger.debug("Only S1DESC captured")
 
-        latent_space.latent_s1 = prediction[0].latent.latent_s1
+        latent_space.sen1 = prediction[0].latent.sen1
 
         latent_space_stack = torch.cat(
-            (latent_space.latent_s1.mean, latent_space.latent_s1.logvar),
+            (latent_space.sen1.mean, latent_space.sen1.logvar),
             1,
         )
-        logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape)
-        logger.debug(
-            "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape
-        )
+        logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape)
+        logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape)
     else:
         raise Exception(f"Not a valid sensors ({sensors}) are given")
 
-- 
GitLab
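
Every sensor branch in the hunk above assembles latent_space_stack the same
way: take whichever of sen2/sen1 is present and concatenate its mean and
logvar tensors along the channel dimension. Below is a minimal sketch of how
that repetition could be factored into a helper; the Latent dataclass and the
stack_latent function are illustrative stand-ins, not types from this
repository:

    from dataclasses import dataclass

    import torch


    @dataclass
    class Latent:
        # Per-sensor latent: mean and log-variance, both (B, C, H, W).
        mean: torch.Tensor
        logvar: torch.Tensor


    def stack_latent(*parts: Latent) -> torch.Tensor:
        # Concatenate mean/logvar of every available latent along the
        # channel dimension, mirroring the torch.cat(..., 1) calls above.
        tensors: list[torch.Tensor] = []
        for part in parts:
            tensors.extend((part.mean, part.logvar))
        return torch.cat(tensors, 1)


    # Usage mirroring the ["S2L2A", "S1ASC"] branch:
    sen2 = Latent(torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 8, 8))
    sen1 = Latent(torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 8, 8))
    stacked = stack_latent(sen2, sen1)  # shape (1, 16, 8, 8)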


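The logger.debug calls in the hunk above rely on the logging module's
deferred %-style formatting: the message template and its arguments are
passed separately, and interpolation only happens when a handler actually
emits the record. A self-contained sketch (the logger name and tensor shape
are made up for illustration):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("mmdc.inference")

    shape = (1, 8, 256, 256)
    # "%s" is substituted lazily, only when a DEBUG record is emitted;
    # at higher log levels the formatting cost is skipped entirely.
    logger.debug("latent_space.sen2.mean : %s", shape)
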
From 7ab7d9e0dc900e2d3dea37014b6d3f08f9e3c58d Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 13:59:11 +0000
Subject: [PATCH 81/81] WIP adding entry points

---
 setup.cfg | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/setup.cfg b/setup.cfg
index eb6eb8f6..3cce8fa6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -72,6 +72,8 @@ console_scripts =
      mmdc_generate_dataset = bin.generate_mmdc_dataset:main
      mmdc_visualize = bin.visualize_mmdc_ds:main
      mmdc_split_dataset = bin.split_mmdc_dataset:main
+     mmdc_inference_singledate = bin.inference_mmdc_singledate:main
+     mmdc_inference_time_serie = bin.inference_mmdc_timeserie:main
 
 # mmdc_inference = bin.infer_mmdc_ds:main
 # Add here console scripts like:
-- 
GitLab
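
Once the package is installed (for example with pip install -e .), setuptools
exposes each name = module:function line under console_scripts as an
executable on PATH. A quick sketch for checking that the newly declared
scripts were registered, assuming Python 3.10+ for the group keyword of
importlib.metadata.entry_points:

    from importlib.metadata import entry_points

    # List the console scripts of the active environment and keep the
    # ones declared by this patch series under the "mmdc_" prefix.
    for script in entry_points(group="console_scripts"):
        if script.name.startswith("mmdc_"):
            print(f"{script.name} -> {script.value}")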