From 4cdc39392ebbb565b138051f1453d354c408eecc Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Feb 2023 16:14:12 +0000
Subject: [PATCH 01/81] add README.rst for entry points setup

---
 README.rst                       | 1 +
 setup.cfg                        | 6 ++++++
 src/bin/generate_mmdc_dataset.py | 6 +++++-
 3 files changed, 12 insertions(+), 1 deletion(-)
 create mode 100644 README.rst

diff --git a/README.rst b/README.rst
new file mode 100644
index 00000000..5bc06818
--- /dev/null
+++ b/README.rst
@@ -0,0 +1 @@
+# MMDC-SingleDate
diff --git a/setup.cfg b/setup.cfg
index 9c5469ae..08f3a061 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -68,6 +68,12 @@ testing =
     pytest-cov
 
 [options.entry_points]
+console_scripts =
+    mmdc_data = bin.generate_mmdc_dataset:main
+    mmdc_visualize = bin.visualize_mmdc_ds:main
+    mmdc_split_dataset = bin.split_mmdc_dataset:main
+
+# mmdc_inference = bin.infer_mmdc_ds:main
 # Add here console scripts like:
 # console_scripts =
 #     script_name = mmdc-singledate.module:function
diff --git a/src/bin/generate_mmdc_dataset.py b/src/bin/generate_mmdc_dataset.py
index d801c566..ada7c11b 100644
--- a/src/bin/generate_mmdc_dataset.py
+++ b/src/bin/generate_mmdc_dataset.py
@@ -81,7 +81,7 @@ def get_parser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-if __name__ == "__main__":
+def main():
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -124,3 +124,7 @@ if __name__ == "__main__":
             args.threshold,
         ),
     )
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab

From 6a38a318699ec4d871d7d9f98e38fe700401d138 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 10:08:36 +0000
Subject: [PATCH 02/81] solve conflicts

---
 src/bin/split_mmdc_dataset.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/bin/split_mmdc_dataset.py b/src/bin/split_mmdc_dataset.py
index de429bf5..b205a249 100644
--- a/src/bin/split_mmdc_dataset.py
+++ b/src/bin/split_mmdc_dataset.py
@@ -141,16 +141,10 @@ def split_tile_rois(
     roi_intersect.writelines(sorted([f"{roi}\n" for roi in roi_list]))
 
 
-def main(
-    input_dir: Path, input_file: Path, test_percentage: int, random_state: int
-) -> None:
+def main():
     """
     Split the tiles/rois between train/val/test
     """
-    split_tile_rois(input_dir, input_file, test_percentage, random_state)
-
-
-if __name__ == "__main__":
     # Parser arguments
     parser = get_parser()
     args = parser.parse_args()
@@ -171,10 +165,14 @@ if __name__ == "__main__":
     logging.info("test percentage selected : %s", args.test_percentage)
     logging.info("random state value : %s", args.random_state)
 
-    # Go to main
-    main(
+    # Go to entry point
+    split_tile_rois(
         Path(args.tensors_dir),
         Path(args.roi_intersections_file),
         args.test_percentage,
         args.random_state,
     )
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab

From 52f90b6df8091f9db3ba1021e3a906a6706b8ed2 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 10:10:15 +0000
Subject: [PATCH 03/81] ignore iota2 copy

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 2015cd5d..676c9c84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
+iota2_thirdparties/
-- 
GitLab

From f5ee64c78921dec37fe73524f1ae6e7f0d949212 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Feb 2023 16:16:00 +0000
Subject: [PATCH 04/81] update for entry point

---
 src/bin/visualize_mmdc_ds.py | 55 +++++++++++++++---------------------
 1 file changed, 22 insertions(+), 33 deletions(-)

diff --git a/src/bin/visualize_mmdc_ds.py b/src/bin/visualize_mmdc_ds.py
index 01da86c6..a69bc924 100644
--- a/src/bin/visualize_mmdc_ds.py
+++ b/src/bin/visualize_mmdc_ds.py
@@ -71,34 +71,10 @@ def get_parser() -> argparse.ArgumentParser:
     return arg_parser
 
 
-def main(
-    input_directory: Path,
-    input_tiles: Path,
-    patch_index: list[int] | None,
-    nb_patches: int | None,
-    export_path: Path,
-) -> None:
+def main():
     """
     Entry point for the visualization
     """
-    try:
-        export_visualization_graphs(
-            input_directory=input_directory,
-            input_tiles=input_tiles,
-            patch_index=patch_index,
-            nb_patches=nb_patches,
-            export_path=export_path,
-        )
-
-    except FileNotFoundError:
-        if not input_directory.exists():
-            print(f"Folder {input_directory} does not exist!!")
-
-        if not export_path.exists():
-            print(f"Folder {export_path} does not exist!!")
-
-
-if __name__ == "__main__":
     # Parser arguments
     parser = get_parser()
    args = parser.parse_args()
@@ -126,13 +102,26 @@ if __name__ == "__main__":
     logging.info(" index patches : %s", args.patch_index)
     logging.info(" output directory : %s", args.export_path)
 
-    # Go to main
-    main(
-        input_directory=Path(args.input_path),
-        input_tiles=Path(args.input_tiles),
-        patch_index=args.patch_index,
-        nb_patches=args.nb_patches,
-        export_path=Path(args.export_path),
-    )
+    # Go to entry point
+
+    try:
+        export_visualization_graphs(
+            input_directory=Path(args.input_path),
+            input_tiles=Path(args.input_tiles),
+            patch_index=args.patch_index,
+            nb_patches=args.nb_patches,
+            export_path=Path(args.export_path),
+        )
+
+    except FileNotFoundError:
+        if not Path(args.input_path).exists():
+            print(f"Folder {args.input_path} does not exist!!")
+
+        if not Path(args.export_path).exists():
+            print(f"Folder {args.export_path} does not exist!!")
 
     logging.info("Visualization export finished")
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab

From 1db085675ab92d477b7e1a9b5e567806f95bf871 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 08:23:58 +0000
Subject: [PATCH 05/81] update convention

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index 08f3a061..eb6eb8f6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -69,7 +69,7 @@ testing =
 [options.entry_points]
 console_scripts =
-    mmdc_data = bin.generate_mmdc_dataset:main
+    mmdc_generate_dataset = bin.generate_mmdc_dataset:main
     mmdc_visualize = bin.visualize_mmdc_ds:main
     mmdc_split_dataset = bin.split_mmdc_dataset:main
 
-- 
GitLab

From b1236f397e7b2e87717ada8b6f6bc465ce831fe4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 29 Mar 2023 11:46:23 +0000
Subject: [PATCH 06/81] merge makefile

---
 Makefile | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 5692cfc8..bc64a93b 100644
--- a/Makefile
+++ b/Makefile
@@ -74,8 +74,10 @@ test_no_PIL:
 test_mask_loss:
 	$(CONDA) && pytest -vv test/test_masked_losses.py
-
 PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py"
+test_inference:
+	$(CONDA) && pytest -vv test/test_mmdc_inference.py
+
 #.PHONY:
 pylint:
 	$(CONDA) && pylint --ignore=$(PYLINT_IGNORED) src/
-- 
GitLab

From bee5c3218cb7d34527bf919d13fac79aec765678 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:13:48 +0000
Subject: [PATCH 07/81] WIP add inference components utils

---
 .../components/inference_components.py        | 210 ++++++++++++
 .../inference/components/inference_utils.py   | 324 ++++++++++++++++++
 2 files changed, 534 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/components/inference_components.py
 create mode 100644 src/mmdc_singledate/inference/components/inference_utils.py

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
new file mode 100644
index 00000000..f480b032
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Function components to get the
+latent spaces for a given tile
+"""
+
+import logging
+
+# imports
+from pathlib import Path
+from typing import Literal
+
+import torch
+from torch import nn
+from torchutils import patches
+
+from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
+from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+
+# define sensors variable for typing
+SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+@torch.no_grad()
+def get_mmdc_full_model(
+    checkpoint: str,
+) -> nn.Module:
+
+    """
+    Instantiate the network
+
+    """
+
+    if not Path(checkpoint).exists():
+        logger.info("No valid checkpoint path was given.")
+        raise FileNotFoundError(checkpoint)
+
+    mmdc_full_model = MMDCFullModule(
+        s2_angles_conv_in_channels=6,
+        s2_angles_conv_out_channels=3,
+        s2_angles_conv_encoder_sizes=[16, 8],
+        s2_angles_conv_kernel_size=1,
+        #
+        srtm_s2_angles_conv_in_channels=10,  # 4 srtm + 6 angles
+        srtm_s2_angles_conv_out_channels=3,
+        srtm_s2_angles_unet_encoder_sizes=[32, 64, 128],
+        srtm_s2_angles_kernel_size=3,
+        #
+        s1_angles_mlp_in_size=6,
+        s1_angles_mlp_hidden_layers=[9, 12, 9],
+        s1_angles_mlp_out_size=3,
+        #
+        srtm_s1_encoder_sizes=[32, 64, 128],
+        srtm_kernel_size=3,
+        srtm_s1_in_channels=4,
+        srtm_s1_out_channels=3,
+        #
+        wc_enc_sizes=[64, 32, 16],
+        wc_kernel_size=1,
+        wc_in_channels=103,
+        wc_out_channels=4,
+        #
+        s1_input_size=6,
+        s1_encoder_sizes=[64, 128, 256, 512, 1024],
+        s1_enc_kernel_sizes=[3],
+        s1_decoder_sizes=[32, 16, 8],
+        s1_dec_kernel_sizes=[3, 3, 3, 3],
+        code_sizes=[0, 4, 0],
+        s2_input_size=10,
+        s2_encoder_sizes=[64, 128, 256, 512, 1024],
+        s2_enc_kernel_sizes=[3],
+        s2_decoder_sizes=[64, 32, 16],
+        s2_dec_kernel_sizes=[3, 3, 3, 3],
+        s1_ang_in_decoder=True,
+        s2_ang_in_decoder=True,
+        srtm_in_decoder=True,
+        wc_in_decoder=True,
+        w_d1_e1_s1=1,
+        w_d2_e2_s2=1,
+        w_d1_e2_s2=1,
+        w_d2_e1_s1=1,
+        w_code_s1s2=1,
+    )

+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(device)
+    # load state_dict
+    lightning_checkpoint = torch.load(checkpoint, map_location=device)
+
+    # strip the "model." prefix from the loaded checkpoint keys
+    checkpoint = {
+        key.split("model.")[-1]: item
+        for key, item in lightning_checkpoint["state_dict"].items()
+        if key.startswith("model.")
+    }
+
+    # load the state dict
+    mmdc_full_model.load_state_dict(checkpoint)
+
+    # disable randomness, dropout, etc.
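+    # eval() switches layers such as dropout and batch-norm to their
+    # inference behaviour; gradient tracking itself is already disabled
+    # by the @torch.no_grad() decorator above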
+    mmdc_full_model.eval()
+
+    return mmdc_full_model
+
+
+@torch.no_grad()
+def predict_mmdc_model(
+    model: nn.Module,
+    sensors: SENSORS,
+    s2_ref,
+    s2_angles,
+    s1_asc,
+    s1_desc,
+    s1_asc_angles,
+    s1_desc_angles,
+    srtm,
+    worldclim,
+) -> VAELatentSpace:
+    """
+    Apply the predict method to a batch
+    of mmdc data and get the latent space
+    for the requested sensor
+
+    :args: model : mmdc_full model instance
+    :args: sensors : sensor acquisitions to extract
+    :args: s2_ref : S2 reflectances
+    :args: s2_angles : S2 acquisition angles
+    :args: s1_asc : S1 ascending backscatter
+    :args: s1_desc : S1 descending backscatter
+    :args: s1_asc_angles : S1 ascending acquisition angles
+    :args: s1_desc_angles : S1 descending acquisition angles
+    :args: srtm : SRTM bands
+    :args: worldclim : WorldClim bands
+
+    :return: VAELatentSpace
+    """
+    s1_back = torch.cat(
+        (
+            s1_asc,
+            s1_desc,
+            (s1_asc / s1_desc),
+        ),
+        1,
+    )
+    prediction = model.predict(
+        s2_ref,
+        s2_angles,
+        s1_back,
+        s1_asc_angles,
+        s1_desc_angles,
+        worldclim,
+        srtm,
+    )
+
+    if "S2L2A" in sensors:
+        # get latent
+        latent_space = prediction[0].latent.latent_s2
+
+    if "S1FULL" in sensors:
+        latent_space = prediction[0].latent.latent_s1
+
+    if "S1ASC" in sensors:
+        latent_space = prediction[0].latent.latent_s1
+
+    if "S1DESC" in sensors:
+        latent_space = prediction[0].latent.latent_s1
+
+    return latent_space
+
+
+def patchify_batch(
+    tensor: torch.Tensor,
+    patch_size: int,
+) -> torch.tensor:
+    """
+    reshape the geotiff data read from disk to the
+    shape expected by the network
+
+    :param tensor: input tensor (C, H, W)
+    :param patch_size: size of the square patches
+    """
+    patch = patches.patchify(tensor, patch_size)
+    flatten_patch = patches.flatten2d(patch)
+
+    return flatten_patch
+
+
+def unpatchify_batch(
+    flatten_patch: torch.tensor, patch_shape: torch.Size, tensor_shape: torch.Size
+) -> torch.tensor:
+    """
+    Inverse operation of patchify_batch
+    :param flatten_patch: flattened batch of patches
+    :param patch_shape: shape of the patchified tensor
+    :param tensor_shape: shape of the original tensor (C, H, W)
+    :return: tensor cropped back to the original shape
+    """
+    unflatten_patch = patches.unflatten2d(flatten_patch, patch_shape[0], patch_shape[1])
+    unpatch = patches.unpatchify(unflatten_patch)
+    unpatch_crop = unpatch[:, : tensor_shape[1], : tensor_shape[2]]
+
+    return unpatch_crop
diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
new file mode 100644
index 00000000..3ec181c5
--- /dev/null
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -0,0 +1,324 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference utils functions
+
+"""
+from collections.abc import Callable, Generator
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import rasterio as rio
+import torch
+from rasterio.windows import Window as rio_window
+
+from mmdc_singledate.datamodules.components.datamodule_components import (
+    get_worldclim_filenames,
+)
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1,
+    get_s1_acquisition_angles,
+    join_even_odd_s2_angles,
+    srtm_height_aspect,
+)
+
+# dataclasses
+Chunk = tuple[int, int, int, int]
+
+# hold the data
+
+
+@dataclass
+class GeoTiffDataset:
+
+    """
+    hold the tif filenames
+    stored on disk and their metadata
+    """
+
+    s2_filename: str
+    s1_asc_filename: str
+    s1_desc_filename: str
+    srtm_filename: str
+    wc_filename: str
+    metadata: dict = field(init=False)
+
+    def __post_init__(self) -> None:
+        # check if the data exists
+        # if not Path(self.s2_filename).exists():
+        #     raise Exception(f"{self.s2_filename} does not exist!")
+        # if not Path(self.s1_asc_filename).exists():
+        #     raise Exception(f"{self.s1_asc_filename} does not exist!")
+        # if not Path(self.s1_desc_filename).exists():
+        #     raise Exception(f"{self.s1_desc_filename} does not exist!")
+        if not Path(self.srtm_filename).exists():
+            raise Exception(f"{self.srtm_filename} does not exist!")
+        if not Path(self.wc_filename).exists():
+            raise Exception(f"{self.wc_filename} does not exist!")
+
+        # check that the filenames
+        # provided are in the same
+        # footprint and spatial resolution
+        # rio.open(self.s2_filename) as s2, rio.open(
+        #     self.s1_asc_filename
+        # ) as s1_asc, rio.open(self.s1_desc_filename) as s1_desc,
+
+        with rio.open(self.srtm_filename) as srtm, rio.open(self.wc_filename) as wc:
+            # recover the metadata
+            self.metadata = srtm.meta
+
+            # list to check
+            crss = {
+                # s2.meta["crs"],
+                # s1_asc.meta["crs"],
+                # s1_desc.meta["crs"],
+                srtm.meta["crs"],
+                wc.meta["crs"],
+            }
+
+            heights = {
+                # s2.meta["height"],
+                # s1_asc.meta["height"],
+                # s1_desc.meta["height"],
+                srtm.meta["height"],
+                wc.meta["height"],
+            }
+
+            widths = {
+                # s2.meta["width"],
+                # s1_asc.meta["width"],
+                # s1_desc.meta["width"],
+                srtm.meta["width"],
+                wc.meta["width"],
+            }
+
+            # check crs
+            if len(crss) > 1:
+                raise Exception("Data should be in the same CRS")
+
+            # check height
+            if len(heights) > 1:
+                raise Exception("Data should have the same height")
+
+            # check width
+            if len(widths) > 1:
+                raise Exception("Data should have the same width")
+
+
+@dataclass
+class S2Components:
+    """
+    dataclass holding the S2-related data
+    """
+
+    s2_reflectances: torch.Tensor
+    s2_angles: torch.Tensor
+    s2_mask: torch.Tensor
+
+
+@dataclass
+class S1Components:
+    """
+    dataclass holding the S1-related data
+    """
+
+    s1_backscatter: torch.Tensor
+    s1_valmask: torch.Tensor
+    s1_edgemask: torch.Tensor
+    s1_lia_angles: torch.Tensor
+
+
+@dataclass
+class SRTMComponents:
+    """
+    dataclass holding the SRTM-related data
+    """
+
+    srtm: torch.Tensor
+
+
+@dataclass
+class WorldClimComponents:
+    """
+    dataclass holding the WorldClim-related data
+    """
+
+    worldclim: torch.Tensor
+
+
+# chunks logic
+def generate_chunks(
+    width: int, height: int, nlines: int, test_area: tuple[int, int] | None = None
+) -> list[Chunk]:
+    """Given the width and height of an image and a number of lines per chunk,
+    generate the list of chunks for the image. The chunks span the full width.
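+    (for example, width=30, height=25 and nlines=10 yield three chunks
+    of 10, 10 and 5 lines)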
A + chunk is encoded as 4 values: (x0, y0, width, height) + :param width: image width + :param height: image height + :param nlines: number of lines per chunk + :param test_area: generate only the chunks between (ya, yb) + :returns: a list of chunks + """ + chunks = [] + for iter_height in range(height // nlines - 1): + y0 = iter_height * nlines + y1 = y0 + nlines + if y1 > height: + y1 = height + append_chunk = (test_area is None) or (y0 > test_area[0] and y1 < test_area[1]) + if append_chunk: + chunks.append((0, y0, width, y1 - y0)) + return chunks + + +# read data +def read_img_tile( + filename: str, + rois: list[rio_window], + sensor_func: Callable[[torch.Tensor, Any], Any], +) -> Generator[Any, None, None]: + """ + read a patch of a abstract sensor data + contruct the auxiliary data and yield the patch + of data + """ + # check if the filename exists + # if exist proceed to yield the data + if Path(filename).exists: + with rio.open(filename) as raster: + for roi in rois: + # read the patch as tensor + tensor = torch.tensor(raster.read(window=roi), requires_grad=False) + # compute some transformation to the tensor and return dataclass + sensor_data = sensor_func(tensor, filename) + yield sensor_data + # if not exists create a zeros tensor and yield + else: + for roi in rois: + null_data = torch.zeros(roi.width, roi.height) + yield null_data + + +def read_s2_img_tile( + s2_tensor: torch.Tensor, + *args: Any, + **kwargs: Any, +) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: + """ + read a patch of sentinel 2 data + contruct the masks and yield the patch + of data + """ + if s2_tensor.shape[0] == 20: + # extract masks + cloud_mask = s2_tensor[11, ...].to(torch.uint8) + cloud_mask[cloud_mask > 0] = 1 + sat_mask = s2_tensor[10, ...].to(torch.uint8) + edge_mask = s2_tensor[12, ...].to(torch.uint8) + # create the validity mask + mask = torch.logical_or( + cloud_mask, torch.logical_or(edge_mask, sat_mask) + ).unsqueeze(0) + angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...]) + image_s2 = s2_tensor[:10, ...] + + else: + image_s2 = torch.zeros(10, s2_tensor.shape[0], s2_tensor.shape[1]) + angles_s2 = torch.zeros(6, s2_tensor.shape[0], s2_tensor.shape[1]) + mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1]) + + return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) + + +def read_s1_img_tile( + s1_tensor: torch.Tensor, + s1_filename: str, + *args: Any, + **kwargs: Any, +) -> S1Components: # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + read a patch of s1 construct the datasets associated + and yield the result + """ + if s1_tensor.shape[0] == 2: + # get the vv,vh, vv/vh bands + img_s1 = torch.cat( + (s1_tensor, (s1_tensor[1, ...] 
/ s1_tensor[0, ...]).unsqueeze(0)) + ) + # get the local incidencce angle (lia) + s1_lia = get_s1_acquisition_angles(s1_filename) + # compute validity mask + s1_valmask = torch.ones(img_s1.shape) + # compute edge mask + s1_edgemask = img_s1.to(int) + s1_backscatter = apply_log_to_s1(img_s1) + else: + s1_backscatter = torch.zeros(3, s1_tensor.shape[0], s1_tensor.shape[1]) + s1_valmask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1]) + s1_edgemask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1]) + s1_lia = torch.zeros(3) + + return S1Components( + s1_backscatter=s1_backscatter, + s1_valmask=s1_valmask, + s1_edgemask=s1_edgemask, + s1_lia_angles=s1_lia, + ) + + +def read_srtm_img_tile( + srtm_tensor: torch.Tensor, + *args: Any, + **kwargs: Any, +) -> SRTMComponents: # [torch.Tensor]: + """ + read srtm patch + """ + # read the patch as tensor + img_srtm = srtm_height_aspect(srtm_tensor) + return SRTMComponents(img_srtm) + + +def read_worldclim_img_tile( + wc_tensor: torch.Tensor, + *args: Any, + **kwargs: Any, +) -> WorldClimComponents: # [torch.tensor]: + """ + read worldclim subset + """ + return WorldClimComponents(wc_tensor) + + +def expand_worldclim_filenames(wc_filename: str) -> list[str]: + """ + given the firts worldclim filename expand the filenames + to the others files + """ + name_wc = [ + get_worldclim_filenames(wc_filename, str(idx), True) for idx in range(1, 13, 1) + ] + + name_wc_bio = get_worldclim_filenames(wc_filename, None, False) + name_wc.append(name_wc_bio) + + return name_wc + + +def concat_worldclim_components( + wc_filename: str, rois: list[rio_window] +) -> Generator[Any, None, None]: + """ + Compose function for apply the read_img_tile general function + to the different worldclim files + """ + # get all filenames + wc_filenames = expand_worldclim_filenames(wc_filename) + # read the tensors + wc_tensor = [ + next(read_img_tile(wc_filename, rois, read_worldclim_img_tile)).worldclim + for idx, wc_filename in enumerate(wc_filenames) + ] + yield WorldClimComponents(torch.cat(wc_tensor)) -- GitLab From 8372a1537ddbcc2cff8e82dcbf86b00bb7d41ab9 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Feb 2023 09:14:15 +0000 Subject: [PATCH 08/81] WIP add inference components utils --- src/mmdc_singledate/inference/components/inference_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index f480b032..fe3212b2 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -34,7 +34,6 @@ logger = logging.getLogger(__name__) def get_mmdc_full_model( checkpoint: str, ) -> nn.Module: - """ Instantiate the network -- GitLab From 30e98cd1437d1a1a460d40163fdd85b3ad2f9cd3 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Feb 2023 09:15:50 +0000 Subject: [PATCH 09/81] add test inference --- test/test_mmdc_inference.py | 273 ++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 test/test_mmdc_inference.py diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py new file mode 100644 index 00000000..90d654be --- /dev/null +++ b/test/test_mmdc_inference.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales + +import os + +import pytest +import 
torch + +from mmdc_singledate.inference.components.inference_components import ( # predict_tile, + get_mmdc_full_model, +) +from mmdc_singledate.inference.components.inference_utils import ( + GeoTiffDataset, + generate_chunks, +) +from mmdc_singledate.inference.mmdc_tile_inference import MMDCProcess, predict_tile + +# dir +dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" + + +def test_GeoTiffDataset(): + """ """ + datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ" + input_data = GeoTiffDataset( + s2_filename=os.path.join( + datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" + ), + s1_asc_filename=os.path.join( + datapath, + "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif", + ), + s1_desc_filename=os.path.join( + datapath, + "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif", + ), + srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"), + wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"), + ) + assert input_data + assert type(input_data.metadata) == dict + + +def test_generate_chunks(): + """ + test chunk functionality + """ + chunks = generate_chunks(width=10980, height=10980, nlines=1024) + + assert type(chunks) == list + assert chunks[0] == (0, 0, 10980, 1024) + + +def test_mmdc_full_model(): + """ + test instantiate network + """ + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" + checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" + + # lightning_dict = torch.load("epoch_008.ckpt", map_location=torch.device("cpu")) + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + + print(mmdc_full_model) + + assert mmdc_full_model.training == False + + s2_x = torch.rand(1, 10, 256, 256) + s2_m = torch.ones(1, 10, 256, 256) + s2_angles_x = torch.rand(1, 6, 256, 256) + s1_x = torch.rand(1, 6, 256, 256) + s1_vm = torch.ones(1, 6, 256, 256) + s1_asc_angles_x = torch.rand(1, 3) + s1_desc_angles_x = torch.rand(1, 3) + worldclim_x = torch.rand(1, 103, 256, 256) + srtm_x = torch.rand(1, 4, 256, 256) + + prediction = mmdc_full_model.predict( + s2_x, + s2_m, + s2_angles_x, + s1_x, + s1_vm, + s1_asc_angles_x, + s1_desc_angles_x, + worldclim_x, + srtm_x, + ) + + latent_variable = prediction[0].latent.latent_s1.mu + + assert type(latent_variable) == torch.Tensor + + +def dummy_process( + s2_refl, + s2_ang, + s1_asc, + s1_desc, + s1_asc_lia, + s1_desc_lia, + srtm_patch, + wc_patch, +): + """ + Create a dummy function + """ + prediction = ( + s2_refl[:, 0, ...] + * s2_ang[:, 0, ...] + * s1_asc[:, 0, ...] + * s1_desc[:, 0, ...] + * srtm_patch[:, 0, ...] + * wc_patch[:, 0, ...] 
+    )
+    return prediction
+
+
+@pytest.mark.skip(reason="Have to update to the new data format")
+def test_predict_tile():
+    """ """
+    export_path = (
+        "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent.tif"
+    )
+    process = MMDCProcess(
+        count=1,
+        nb_lines=1024,
+        process=dummy_process,
+    )
+
+    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+    input_data = GeoTiffDataset(
+        s2_filename=os.path.join(
+            datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+        ),
+        s1_asc_filename=os.path.join(
+            datapath,
+            "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+        ),
+        s1_desc_filename=os.path.join(
+            datapath,
+            "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+        ),
+        srtm_filename=os.path.join(datapath, "srtm_T31TDJ_roi_0.tif"),
+        wc_filename=os.path.join(datapath, "wc_clim_1_T31TDJ_roi_0.tif"),
+    )
+
+    predict_tile(
+        input_data=input_data,
+        export_path=export_path,
+        process=process,
+    )
+
+
+# def test_read_s2_img_tile():
+#     """
+#     test read a window of s2
+#     """
+#     dataset_sample = os.path.join(
+#         dataset_dir, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif"
+#     )
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     s2_data = read_s2_img_tile(s2_filename=dataset_sample, rois=roi)
+#     # get the first element
+#     s2_data_window = next(s2_data)
+
+#     assert s2_data_window.s2_reflectances.shape == torch.Size([10, 1024, 10980])
+#     assert s2_data_window.s2_angles.shape == torch.Size([6, 1024, 10980])
+#     assert s2_data_window.s2_mask.shape == torch.Size([1, 1024, 10980])
+
+
+# def test_read_s1_img_tile():
+#     """
+#     test read a window of s1
+#     """
+#     dataset_sample = os.path.join(
+#         dataset_dir,
+#         "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_41.98988334384891_15.429479982324139_roi_0.tif",
+#     )
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     s1_data = read_s1_img_tile(s1_filename=dataset_sample, rois=roi)
+#     # get the first element
+#     s1_data_window = next(s1_data)
+
+#     assert s1_data_window.s1_backscatter.shape == torch.Size([3, 1024, 10980])
+#     assert s1_data_window.s1_valmask.shape == torch.Size([3, 1024, 10980])
+#     assert s1_data_window.s1_edgemask.shape == torch.Size([3, 1024, 10980])
+#     assert s1_data_window.s1_lia_angles.shape == torch.Size([3])
+
+
+# def test_read_srtm_img_tile():
+#     """
+#     test read a window of srtm
+#     """
+#     dataset_sample = os.path.join(dataset_dir, "srtm_T31TDJ_roi_0.tif")
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     srtm_data = read_srtm_img_tile(srtm_filename=dataset_sample, rois=roi)
+#     # get the first element
+#     srtm_data_window = next(srtm_data)
+
+#     assert srtm_data_window.shape == torch.Size([4, 1024, 10980])
+
+
+# def test_read_worlclim_img_tile():
+#     """
+#     test read a window of worldclim
+#     """
+#     dataset_sample = os.path.join(dataset_dir, "wc_clim_1_T31TDJ_roi_0.tif")
+
+#     chunks = generate_chunks(width=10980, height=10980, nlines=1024)
+#     roi = [rio.windows.Window(*chunk) for chunk in chunks]
+
+#     wc_data = read_worldclim_img_tile(wc_filename=dataset_sample, rois=roi)
+#     # get the first element
+#     wc_data_window = next(wc_data)
+
+#     assert wc_data_window.shape == torch.Size([103, 1024, 10980])
+
+
+# def test_predict_tile():
+#     """ """
+
+#     input_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TDJ/"
+#     export_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/"
+
+#     mmdc_filenames = GeoTiffDataset(
+#         s2_filename=os.path.join(
+#             input_path, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif"
+#         ),
+#         s1_filename=os.path.join(
+#             input_path,
+#             "S1B_IW_GRDH_1SDV_20180304T173831_20180304T173856_009885_011E2F_CEA4_32.528381567125045_15.330754606990766_roi_0.tif",
+#         ),
+#         srtm_filename=os.path.join(input_path, "srtm_T31TDJ_roi_0.tif"),
+#         wc_filename=os.path.join(input_path, "wc_clim_1_T31TDJ_roi_0.tif"),
+#     )
+
+#     # get grid mesh
+#     rois, meta = get_rois_from_prediction_mesh(
+#         input_filename=mmdc_filenames.s2_filename, nlines=1024
+#     )
+
+#     export_filename = "latent_variable.tif"
+#     exported_file = predict_tile(
+#         input_path=input_path,
+#         export_path=export_path,
+#         mmdc_filenames=mmdc_filenames,
+#         export_filename=export_filename,
+#         rois=rois,
+#         meta=meta,
+#     )
+#     # assert existence of exported file
+#     assert Path(exported_file).exists()
+
+
+# #
+
+
+# # asc = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_39.6607434587337_15.429479982324139_roi_0.tif"
+# # desc = "S1B_IW_GRDH_1SDV_20180303T055142_20180303T055207_009863_011D75_B010_45.08497158404881_165.07686884776342_roi_4.tif
+# "
-- 
GitLab

From bab2b88046f139e6e925b862fe7e6f0107569f4a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:16:48 +0000
Subject: [PATCH 10/81] add inference tile

---
 .../inference/mmdc_tile_inference.py | 273 ++++++++++++++++++
 1 file changed, 273 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/mmdc_tile_inference.py

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
new file mode 100644
index 00000000..b4aec447
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -0,0 +1,273 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference API with Rasterio
+"""
+
+# imports
+import logging
+import os
+from collections.abc import Callable
+from dataclasses import dataclass
+from pathlib import Path
+
+import rasterio as rio
+import torch
+
+from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df
+
+from .components.inference_components import (
+    get_mmdc_full_model,
+    patchify_batch,
+    predict_mmdc_model,
+    unpatchify_batch,
+)
+from .components.inference_utils import (
+    GeoTiffDataset,
+    concat_worldclim_components,
+    generate_chunks,
+    read_img_tile,
+    read_s1_img_tile,
+    read_s2_img_tile,
+    read_srtm_img_tile,
+)
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def inference_dataframe(
+    samples_dir: str,
+    input_tile_list: list[str],
+    days_gap: int,
+    nb_tiles: int,
+    nb_rois: int,
+    nb_files: int | None = None,
+):
+    """
+    Read the input directory and create
+    a dataframe with the acquisitions
+    to be inferred
+    """
+    # read the metadata and construct the time series
+    (
+        tile_df,
+        asc_orbit_without_acquisitions,
+        desc_orbit_without_acquisitions,
+    ) = prepare_data_df(
+        samples_dir=samples_dir,
+        input_tile_list=input_tile_list,  # or pass,
+        days_gap=days_gap,
+        nb_files=nb_files,
+        nb_tiles=nb_tiles,
+        nb_rois=nb_rois,
+    )
+
+    # manage the absence of S1 ASC or S1 DESC acquisitions
+    if asc_orbit_without_acquisitions:
+        tile_df["patchasc_s1"] = None
+
+    if desc_orbit_without_acquisitions:
+        tile_df["patchdesc_s1"] = None
+
+    return tile_df
+
+
+# functions and classes
+@dataclass
+class MMDCProcess:
+    """
+    Hold the number of output bands, the chunk size
+    and the prediction function to apply
+    """
+
+    count: int
+    nb_lines: int
+    process: Callable[
+        [
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+            torch.tensor,
+        ],
+        torch.tensor,
+    ]
+
+
+def predict_tile(
+    input_data: GeoTiffDataset,
+    export_path: Path,
+    sensors: list[str],
+    process: MMDCProcess,
+):
+    """
+    Predict a tile of data
+    """
+    # retrieve the metadata
+    meta = input_data.metadata.copy()
+    # get nb outputs
+    meta.update({"count": process.count})
+    # calculate rois
+    chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+    # get the windows from the chunks
+    rois = [rio.windows.Window(*chunk) for chunk in chunks]
+    logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+    # init the dataset
+    logger.info("Reading S2 data")
+    s2_data = read_img_tile(input_data.s2_filename, rois, read_s2_img_tile)
+    logger.info("Reading S1 ASC data")
+    s1_asc_data = read_img_tile(input_data.s1_asc_filename, rois, read_s1_img_tile)
+    logger.info("Reading S1 DESC data")
+    s1_desc_data = read_img_tile(input_data.s1_desc_filename, rois, read_s1_img_tile)
+    logger.info("Reading SRTM data")
+    srtm_data = read_img_tile(input_data.srtm_filename, rois, read_srtm_img_tile)
+    logger.info("Reading WorldClim data")
+    worldclim_data = concat_worldclim_components(input_data.wc_filename, rois)
+
+    logger.info("Export Init")
+
+    # separate the latent spaces by sensor
+    sensors = [s for s in sensors if s in ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]]
+    # build the export filenames
+    export_names = ["mmdc_latent_" + s.casefold() + ".tif" for s in sensors]
+    for idx, export_name in enumerate(export_names):
+        # export latent spaces
+        with rio.open(
+            os.path.join(export_path, export_name), "w", **meta
+        ) as prediction:
+            # iterate over the windows
+            for roi, s2, s1_asc, s1_desc, srtm, wc in zip(
+                rois,
+                s2_data,
+                s1_asc_data,
+                s1_desc_data,
+                srtm_data,
+                worldclim_data,
+            ):
+                print(" original size : ", s2.s2_reflectances.shape)
+                # reshape the data
+                s2_refl_patch = patchify_batch(s2.s2_reflectances, 256)
+                s2_ang_patch = patchify_batch(s2.s2_angles, 256)
+                s1_asc_patch = patchify_batch(s1_asc.s1_backscatter, 256)
+                s1_asc_lia_patch = s1_asc.s1_lia_angles
+                s1_desc_patch = patchify_batch(s1_desc.s1_backscatter, 256)
+                s1_desc_lia_patch = s1_desc.s1_lia_angles
+                srtm_patch = patchify_batch(srtm.srtm, 256)
+                wc_patch = patchify_batch(wc.worldclim, 256)
+                print(
+                    s2_refl_patch.shape,
+                    s2_ang_patch.shape,
+                    s1_asc_patch.shape,
+                    s1_desc_patch.shape,
+                    s1_asc_lia_patch.shape,
+                    s1_desc_lia_patch.shape,
+                    srtm_patch.shape,
+                    wc_patch.shape,
+                )
+
+                # apply predict function
+                pred_vaelatentspace = process.process(
+                    s2_refl_patch,
+                    s2_ang_patch,
+                    s1_asc_patch,
+                    s1_desc_patch,
+                    s1_asc_lia_patch,
+                    s1_desc_lia_patch,
+                    srtm_patch,
+                    wc_patch,
+                )
+                pred_tensor = torch.cat(
+                    (
+                        pred_vaelatentspace.mu,
+                        pred_vaelatentspace.logvar,
+                    ),
+                    1,
+                )
+                print(pred_tensor.shape)
+                prediction.write(
+                    unpatchify_batch(
+                        flatten_patch=pred_tensor,
+                        patch_shape=s2_refl_patch.shape,
+                        tensor_shape=s2.s2_reflectances.shape,
+                    ),
+                    window=roi,
+                    indexes=1,
+                )
+        logger.info(("Export tile", f"filename :{export_path}"))
+
+
+def mmdc_tile_inference(
+    input_path: Path,
+    export_path: Path,
+    model_checkpoint_path: Path,
+    sensor: list[str],
+    tile_list: list[str],
+    days_gap: int,
+    nb_tiles: int,
+    nb_files: int | None = None,
+    latent_space_size: int = 2,
+    nb_lines: int = 1024,
+) -> None:
+    """
+    Entry point
+
+    :args: input_path : full path to the input data
+    :args: export_path : full path to the export
+    :args: model_checkpoint_path : full path to the model checkpoint
+    :args: sensor : list of sensor data inputs
+    :args: tile_list : list of tiles
+    :args: days_gap : days between s1 and s2 acquisitions
+    :args: nb_tiles : number of tiles to process
+    :args: nb_files : number of files
+    :args: latent_space_size : latent space output size per sensor
+    :args: nb_lines : number of lines to read at a time (default 1024)
+
+    :return : None
+
+    """
+    # dataframe with input data
+    tile_df = inference_dataframe(
+        samples_dir=input_path,
+        input_tile_list=tile_list,
+        days_gap=days_gap,
+        nb_tiles=nb_tiles,
+        nb_rois=1,
+        nb_files=nb_files,
+    )
+
+    # instantiate the model and build the prediction function
+    model = get_mmdc_full_model(
+        checkpoint=model_checkpoint_path,
+    )
+
+    def pred_func(*tensors):
+        return predict_mmdc_model(model, sensor, *tensors)
+
+    #
+    mmdc_process = MMDCProcess(
+        count=latent_space_size, nb_lines=nb_lines, process=pred_func
+    )
+
+    # iterate over the dates in the time series
+    for tuile, df_row in tile_df.iterrows():
+        # get the input data
+        mmdc_input_data = GeoTiffDataset(
+            s2_filename=df_row.patch_s2,
+            s1_asc_filename=df_row.patchasc_s1,
+            s1_desc_filename=df_row.patchdesc_s1,
+            srtm_filename=df_row.srtm_filename,
+            wc_filename=df_row.worldclim_filename,
+        )
+        # predict tile
+        predict_tile(
+            input_data=mmdc_input_data,
+            export_path=export_path,
+            sensors=sensor,
+            process=mmdc_process,
+        )
+
+    logger.info("Export Finish !!!")
-- 
GitLab

From 9243fd75ae8c83014cc96d939857db38bbeeab6a Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:24:43 +0000
Subject: [PATCH 11/81] add iota2 API

---
 .../inference/mmdc_tile_inference_iota2.py | 98 +++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
new file mode 100644
index 00000000..f434b108
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales
+
+"""
+Inference API with Iota-2
+Inspired by:
+https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
+"""
+
+from typing import Any
+
+import torch
+from torchutils import patches
+
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1
+)
+
+
+def apply_mmdc_full_mode(
+    self,
+    checkpoint_path: str,
+    checkpoint_epoch: int = 100,
+    patch_size: int = 256,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+    # TODO Get the data in the same order as
+    # sentinel2.Sentinel2.GROUP_10M
+    # sensorsio.sentinel2.GROUP_20M
+    list_bands_s2 = [
+        self.get_interpolated_Sentinel2_B2(),
+        self.get_interpolated_Sentinel2_B3(),
+        self.get_interpolated_Sentinel2_B4(),
+        self.get_interpolated_Sentinel2_B8(),
+        self.get_interpolated_Sentinel2_B5(),
+        self.get_interpolated_Sentinel2_B6(),
+        self.get_interpolated_Sentinel2_B7(),
+        self.get_interpolated_Sentinel2_B8A(),
+        self.get_interpolated_Sentinel2_B11(),
+        self.get_interpolated_Sentinel2_B12(),
+    ]
+
+    # TODO Manage S1 ASC and S1 DESC ?
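+    # A possible sketch (assuming the per-orbit getters used below are
+    # the full set iota2 exposes): keep two separate stacks,
+    #
+    #     list_bands_s1_asc = [
+    #         self.get_interpolated_Sentinel1_ASC_vh(),
+    #         self.get_interpolated_Sentinel1_ASC_vv(),
+    #     ]
+    #     list_bands_s1_desc = [
+    #         self.get_interpolated_Sentinel1_DES_vh(),
+    #         self.get_interpolated_Sentinel1_DES_vv(),
+    #     ]
+    #
+    # and run the prediction once per orbit instead of concatenating them.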
+    list_bands_s1 = [
+        self.get_interpolated_Sentinel1_ASC_vh(),
+        self.get_interpolated_Sentinel1_ASC_vv(),
+        self.get_interpolated_Sentinel1_ASC_vh()
+        / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
+        self.get_interpolated_Sentinel1_DES_vh(),
+        self.get_interpolated_Sentinel1_DES_vv(),
+        self.get_interpolated_Sentinel1_DES_vh()
+        / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
+    ]
+
+    # TODO Masks construction
+
+
+    with torch.no_grad():
+        # Permute dimensions to fit patchify
+        # Shape before permutation is C,H,W,D. D being the dates
+        # Shape after permutation is D,C,H,W
+        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+        bands_s1 = apply_log_to_s1(bands_s1)
+
+
+        # TODO Apply patchify
+
+
+        # Get the model
+        mmdc_full_model = get_mmdc_full_model(
+            os.path.join(checkpoint_path, checkpoint_filename)
+        )
+
+        # apply model
+        latent_variable = mmdc_full_model.predict(
+            s2_x,
+            s2_m,
+            s2_angles_x,
+            s1_x,
+            s1_vm,
+            s1_asc_angles_x,
+            s1_desc_angles_x,
+            worldclim_x,
+            srtm_x,
+        )
+
+        # TODO unpatchify
+
+        # TODO crop padding
+
+        # TODO Depending on the configuration return a unique latent variable
+        # or a stack of latent variables
+        coef = None  # TODO: placeholder, still to be implemented
+        labels = None  # TODO: placeholder, still to be implemented
+        return coef, labels
-- 
GitLab

From 496a5f0479c6ed5c26fb7b438a295394eeb30522 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:25:32 +0000
Subject: [PATCH 12/81] ignore iota2 thirdparty

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 676c9c84..281e6999 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ src/models/*.ipynb
 *.swp
 *~
 thirdparties
+iota2_thirdparties
 .coverage
 src/MMDC_SingleDate.egg-info/
 .projectile
-- 
GitLab

From fab87779bdbd8d22a71430ad7e59920e5f0cd570 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 09:26:38 +0000
Subject: [PATCH 13/81] add script to configure iota2

---
 create-iota2-env.sh | 104 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 create-iota2-env.sh

diff --git a/create-iota2-env.sh b/create-iota2-env.sh
new file mode 100644
index 00000000..f7035081
--- /dev/null
+++ b/create-iota2-env.sh
@@ -0,0 +1,104 @@
+#!/usr/bin/env bash
+
+export python_version="3.9"
+export name="mmdc-iota2"
+if ! [ -z "$1" ]
+then
+    export name=$1
+fi
+
+source ~/set_proxy_iota2.sh
+if [ -z "$https_proxy" ]
+then
+    echo "Please set https_proxy environment variable before running this script"
+    exit 1
+fi
+
+export target=/work/scratch/$USER/virtualenv/$name
+
+if ! [ -z "$2" ]
+then
+    export target="$2/$name"
+fi
+
+echo "Installing $name in $target ..."
+
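+# Refuse to run with an empty target path: $target is derived from user
+# input above and is deleted recursively just below
+if [ -z "$target" ]
+then
+    echo "Empty install target, aborting"
+    exit 1
+fi
+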
+if [ -d "$target" ]; then
+    echo "Cleaning previous conda env"
+    rm -rf $target
+fi
+
+# Create blank virtualenv
+module purge
+module load conda
+module load gcc
+conda activate
+conda create --yes --prefix $target python==${python_version} pip
+
+# Enter virtualenv
+conda deactivate
+conda activate $target
+
+# install mamba
+conda install mamba
+
+which python
+python --version
+
+# Install iota2
+#mamba install iota2_develop=257da617 -c iota2
+mamba install iota2 -c iota2
+
+# clone the latest version of iota2
+rm -rf thirdparties/iota2
+git clone https://framagit.org/iota2-project/iota2.git thirdparties/iota2
+cd thirdparties/iota2
+git checkout issue#600/tile_exogenous_features
+git pull origin issue#600/tile_exogenous_features
+cd ../../
+pwd
+#which Iota2.py
+which Iota2.py
+
+# create a backup
+cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/
+mv iota2 iota2_backup
+
+# create a symbolic link
+ln -s ~/src/MMDC/mmdc-singledate/thirdparties/iota2/iota2/ iota2
+
+cd ~/src/MMDC/mmdc-singledate
+# install missing dependencies
+pip install -r requirements-mmdc-iota2.txt
+
+# test install
+Iota2.py -h
+
+# # install MMDC dependencies
+# conda install --yes pytorch=1.12.1 torchvision -c pytorch -c nvidia
+
+# conda deactivate
+# conda activate $target
+
+# # Requirements
+# pip install -r requirements-mmdc-sgld.txt
+
+module unload git
+# Install sensorsio
+rm -rf thirdparties/sensorsio
+git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git thirdparties/sensorsio
+pip install thirdparties/sensorsio
+
+# Install torchutils
+rm -rf thirdparties/torchutils
+git clone https://src.koda.cnrs.fr/mmdc/torchutils.git thirdparties/torchutils
+pip install thirdparties/torchutils
+
+# Install the current project in edit mode
+pip install -e .[testing]
+
+# Activate pre-commit hooks
+pre-commit install
+
+# End
+conda deactivate
-- 
GitLab

From a914b8110c4cbcc6d01d6724e61a22278fd01513 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Feb 2023 14:08:57 +0000
Subject: [PATCH 14/81] Add iota conf

---
 ...ternalfeatures_with_userfeatures_50x50.cfg | 81 +++++++++++++++++++
 configs/iota2/i2_grid.cfg                     | 19 +++++
 jobs/iota2_aux_test.pbs                       | 20 +++++
 jobs/iota2_external_feature_test.pbs          | 18 +++++
 4 files changed, 138 insertions(+)
 create mode 100644 configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
 create mode 100644 configs/iota2/i2_grid.cfg
 create mode 100644 jobs/iota2_aux_test.pbs
 create mode 100644 jobs/iota2_external_feature_test.pbs

diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
new file mode 100644
index 00000000..3d074b7b
--- /dev/null
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -0,0 +1,81 @@
+chain :
+{
+    output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/results/external_user_features'
+    remove_output_path : False
+    check_inputs : False
+
+    nomenclature_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/nomenclature_grosse.txt'
+    color_table : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/colorFile_grosse.txt'
+
+    list_tile : 'T31TCJ'
+    ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp'
+    data_field : 'groscode'
+    s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink'
+    # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output'
+    # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50'
+    # s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg'
+
+    srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+    worldclim_path:'/datalake/static_aux/worldclim-2.0'
+    grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+    tile_field : 'Name'
+
+    spatial_resolution : 10
+    first_step : 'init'
+    last_step : 'validation'
+    proj : 'EPSG:2154'
+}
+userFeat:
+{
+    arbo:"/*"
+    patterns:"mnt2,slope"
+}
+python_data_managing :
+{
+    padding_size_x : 1
+    padding_size_y : 1
+    chunk_size_mode:"split_number"
+    number_of_chunks:2
+    data_mode_access: "both"
+}
+
+# builders:
+# {
+#     builders_class_name : ["i2_features_to_grid"]
+# }
+
+external_features :
+{
+    module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
+    functions : [['test_mnt', {"arg1":1}]]
+    concat_mode : True
+    external_features_flag:True
+}
+arg_classification:
+{
+    enable_probability_map:True
+    generate_final_probability_map:True
+}
+arg_train :
+{
+
+    runs : 1
+    random_seed : 1
+    classifier:"sharkrf"
+
+    sample_selection :
+    {
+        sampler : 'random'
+        strategy : 'percent'
+        'strategy.percent.p' : 1.0
+        ram : 4000
+    }
+}
+
+task_retry_limits :
+{
+    allowed_retry : 0
+    maximum_ram : 180.0
+    maximum_cpu : 40
+}
diff --git a/configs/iota2/i2_grid.cfg b/configs/iota2/i2_grid.cfg
new file mode 100644
index 00000000..e13f8788
--- /dev/null
+++ b/configs/iota2/i2_grid.cfg
@@ -0,0 +1,19 @@
+chain :
+{
+    output_path : '/work/scratch/vinascj/MMDC/iota2/grid_test'
+    spatial_resolution : 100 # output target resolution
+    first_step : 'tiler' # do not change
+    last_step : 'tiler' # do not change
+
+    proj : 'EPSG:2154' # output target crs
+    grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid
+    tile_field : 'Name'
+    list_tile : '31TCK 31TCJ' # list of tiles in grid to build
+    srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt'
+    worldclim_path:'/datalake/static_aux/worldclim-2.0'
+}
+
+builders:
+{
+builders_class_name : ["i2_features_to_grid"]
+}
diff --git a/jobs/iota2_aux_test.pbs b/jobs/iota2_aux_test.pbs
new file mode 100644
index 00000000..3b3eed36
--- /dev/null
+++ b/jobs/iota2_aux_test.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+         -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_grid.cfg
+         #-nb_parallel_tasks 1 \
+         #-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_external_feature_test.pbs b/jobs/iota2_external_feature_test.pbs
new file mode 100644
index 00000000..4283144b
--- /dev/null
+++ b/jobs/iota2_external_feature_test.pbs
@@ -0,0 +1,18 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+Iota2.py -scheduler_type debug \
+         -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
+         -nb_parallel_tasks 1 \
+         # -only_summary
-- 
GitLab

From 4e51bd8c9d5524902715afd691c921d0c6e2b4ea Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:17:32 +0000
Subject: [PATCH 15/81] update Iota2 API

---
 .../inference/mmdc_tile_inference_iota2.py | 34 +++++++++++++++----
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
index f434b108..1110eccf 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py
@@ -13,7 +13,9 @@ import torch
 from torchutils import patches
 
 from mmdc_singledate.datamodules.components.datamodule_utils import (
-    apply_log_to_s1
+    apply_log_to_s1,
+    srtm_height_aspect,
+    join_even_odd_s2_angles
 )
 
 
@@ -26,9 +28,13 @@ def apply_mmdc_full_mode(
     """
     Apply MMDC with Iota-2.py
     """
-    # TODO Get the data in the same order as
+    # How to manage the S1 acquisition dates to construct the
+    # validity mask in time?
+
+    # DONE Get the data in the same order as
     # sentinel2.Sentinel2.GROUP_10M
     # sensorsio.sentinel2.GROUP_20M
+    # This data corresponds to (Chunk_Size, Img_Size, nb_dates)
     list_bands_s2 = [
         self.get_interpolated_Sentinel2_B2(),
         self.get_interpolated_Sentinel2_B3(),
@@ -42,6 +48,9 @@ def apply_mmdc_full_mode(
         self.get_interpolated_Sentinel2_B12(),
     ]
 
+    # TODO Masks construction for S2
+    list_s2_mask = self.get_Sentinel2_binary_masks()
+
     # TODO Manage S1 ASC and S1 DESC ?
     list_bands_s1 = [
         self.get_interpolated_Sentinel1_ASC_vh(),
@@ -54,16 +63,29 @@ def apply_mmdc_full_mode(
         / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
     ]
 
-    # TODO Masks construction
-
-
     with torch.no_grad():
         # Permute dimensions to fit patchify
         # Shape before permutation is C,H,W,D. D being the dates
         # Shape after permutation is D,C,H,W
         bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
         bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
-        bands_s1 = apply_log_to_s1(bands_s1)
+        bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+        # TODO Masks construction for S1
+        # use the build_s1_image_and_masks function from the datamodule components?
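+        # A minimal sketch of that idea (assumption: invalid backscatter
+        # shows up as non-finite values before the nan_to_num call below):
+        #
+        #     s1_valmask = torch.isfinite(bands_s1).all(dim=1, keepdim=True)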
+        # Replace nan by 0
+        bands_s1 = bands_s1.nan_to_num()
+
+        # These dimensions are useful for unpatchify
+        # Keep the dimensions of height and width found in chunk
+        band_h, band_w = bands_s2.shape[-2:]
+
+        # Keep the number of patches of patch_size
+        # in rows and cols found in chunk
+        h, w = patches.patchify(bands_s2[0, ...],
+                                patch_size=patch_size,
+                                margin=patch_margin).shape[:2]
 
         # TODO Apply patchify
-- 
GitLab

From 6a256126f5f6cbc8dc2506d590e58fb4b75ba197 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:18:02 +0000
Subject: [PATCH 16/81] test external features

---
 ...ternalfeatures_with_userfeatures_50x50.cfg |  2 +-
 jobs/iota2_external_feature_test.pbs          | 27 +++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
index 3d074b7b..08f1b1b9 100644
--- a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
+++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg
@@ -39,7 +39,6 @@ python_data_managing :
     data_mode_access: "both"
 }
 
-
 # builders:
 # {
 #     builders_class_name : ["i2_features_to_grid"]
@@ -47,6 +46,7 @@ external_features :
 {
     module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py"
     functions : [['test_mnt', {"arg1":1}]]
+    # functions : [['test_s2_concat', {"arg1":1}]]
     concat_mode : True
     external_features_flag:True
 }
diff --git a/jobs/iota2_external_feature_test.pbs b/jobs/iota2_external_feature_test.pbs
index 4283144b..f8b6953c 100644
--- a/jobs/iota2_external_feature_test.pbs
+++ b/jobs/iota2_external_feature_test.pbs
@@ -10,9 +10,36 @@ module purge
 # load modules
 module load conda
 conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+# conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2-develop
 
 
 Iota2.py -scheduler_type debug \
          -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
          -nb_parallel_tasks 1 \
          # -only_summary
+
+
+
+# Functions available
+# ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
+# '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__',
+# '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__',
+# '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__',
+# '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__',
+# '__weakref__', 'all_dates', 'allow_nans', 'binary_masks', 'concat_mode',
+# 'dim_ts', 'enabled_gap', 'exogeneous_data_array', 'exogeneous_data_name',
+# 'external_functions', 'fill_missing_dates', 'func_param',
+# 'get_Sentinel2_binary_masks', 'get_filled_masks', 'get_filled_stack',
+# 'get_interpolated_Sentinel2_B11', 'get_interpolated_Sentinel2_B12',
+# 'get_interpolated_Sentinel2_B2', 'get_interpolated_Sentinel2_B3',
+# 'get_interpolated_Sentinel2_B4', 'get_interpolated_Sentinel2_B5',
+# 'get_interpolated_Sentinel2_B6', 'get_interpolated_Sentinel2_B7',
+# 'get_interpolated_Sentinel2_B8', 'get_interpolated_Sentinel2_B8A',
+# 'get_interpolated_Sentinel2_Brightness', 'get_interpolated_Sentinel2_NDVI',
+# 'get_interpolated_Sentinel2_NDWI', 'get_interpolated_dates', 'get_raw_Sentinel2_B11',
+# 'get_raw_Sentinel2_B12', 'get_raw_Sentinel2_B2', 'get_raw_Sentinel2_B3',
+# 'get_raw_Sentinel2_B4', 'get_raw_Sentinel2_B5', 'get_raw_Sentinel2_B6',
+# 'get_raw_Sentinel2_B7', 'get_raw_Sentinel2_B8', 'get_raw_Sentinel2_B8A',
+# 'get_raw_dates', 'interpolated_data', 'interpolated_dates', 'missing_masks_values',
+# 'missing_refl_values', 'out_data', 'process', 'raw_data', 'raw_dates', +# 'test_user_feature_with_fake_data'] -- GitLab From a63af60bdf4f8e8f0ca18d26741344823e4f95a4 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 24 Feb 2023 13:18:44 +0000 Subject: [PATCH 17/81] test --- test/test_mmdc_inference.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index 90d654be..0e981f85 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -13,7 +13,10 @@ from mmdc_singledate.inference.components.inference_utils import ( GeoTiffDataset, generate_chunks, ) -from mmdc_singledate.inference.mmdc_tile_inference import MMDCProcess, predict_tile +from mmdc_singledate.inference.mmdc_tile_inference import ( + MMDCProcess, + predict_single_date_tile, +) # dir dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" @@ -148,7 +151,7 @@ def test_predict_tile(): wc_filename=os.path.join(datapath, "wc_clim_1_T31TDJ_roi_0.tif"), ) - predict_tile( + predict_single_date_tile( input_data=input_data, export_path=export_path, process=process, -- GitLab From a744a9dc3dd8a5c79850978ba0354faee71706ea Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 24 Feb 2023 13:19:39 +0000 Subject: [PATCH 18/81] iota env install --- create-iota2-env.sh => create-iota2-env.sh | 39 +++++++++------------ 1 file changed, 17 insertions(+), 22 deletions(-) rename create-iota2-env.sh => create-iota2-env.sh (68%) diff --git a/create-iota2-env.sh b/create-iota2-env.sh similarity index 68% rename from create-iota2-env.sh rename to create-iota2-env.sh index f7035081..0af4a8c1 100644 --- a/create-iota2-env.sh +++ b/create-iota2-env.sh @@ -50,55 +50,50 @@ python --version mamba install iota2 -c iota2 # clone the lastest version of iota2 -rm -rf thirdparties/iota2 -git clone https://framagit.org/iota2-project/iota2.git thirdparties/iota2 -cd thirdparties/iota2 +rm -rf iota2_thirdparties/iota2 +git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2 +cd iota2_thirdparties/iota2 git checkout issue#600/tile_exogenous_features git pull origin issue#600/tile_exogenous_features cd ../../ pwd + #which Iota2.py which Iota2.py - # create a backup cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/ mv iota2 iota2_backup # create a symbolic link -ln -s ~/src/MMDC/mmdc-singledate/thirdparties/iota2/iota2/ iota2 +ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2 cd ~/src/MMDC/mmdc-singledate + + +#which Iota2.py +which Iota2.py + # install missing dependancies pip install -r requirements-mmdc-iota2.txt # test install Iota2.py -h -# # install MMDC dependencies -# conda install --yes pytorch=1.12.1 torchvision -c pytorch -c nvidia - -# conda deactivate -# conda activate $target - -# # Requirements -# pip install -r requirements-mmdc-sgld.txt - -module unload git # Install sensorsio -rm -rf thirdparties/sensorsio -git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git thirdparties/sensorsio -pip install thirdparties/sensorsio +rm -rf iota2_thirdparties/sensorsio +git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio +pip install iota2_thirdparties/sensorsio # Install torchutils -rm -rf thirdparties/torchutils -git clone https://src.koda.cnrs.fr/mmdc/torchutils.git thirdparties/torchutils -pip 
install thirdparties/torchutils
+rm -rf iota2_thirdparties/torchutils
+git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils
+pip install iota2_thirdparties/torchutils
 
 # Install the current project in edit mode
 pip install -e .[testing]
 
 # Activate pre-commit hooks
-pre-commit install
+# pre-commit install
 
 # End
 conda deactivate
--
GitLab


From 53de3468d660317820efedede3d63c0db0ab5939 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:22:47 +0000
Subject: [PATCH 19/81] add iota2 supplementary dependencies

---
 requirements-mmdc-iota2.txt | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 requirements-mmdc-iota2.txt

diff --git a/requirements-mmdc-iota2.txt b/requirements-mmdc-iota2.txt
new file mode 100644
index 00000000..61906437
--- /dev/null
+++ b/requirements-mmdc-iota2.txt
@@ -0,0 +1,3 @@
+config==0.5.1
+itk
+pydantic
--
GitLab


From 88bf3886111448b78fe83085d63044bc520da6f4 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:23:31 +0000
Subject: [PATCH 20/81] launch iota2

---
 jobs/iota2_aux_angles_test.pbs | 20 ++++++++++++++++++++
 jobs/iota2_aux_full.pbs        | 20 ++++++++++++++++++++
 2 files changed, 40 insertions(+)
 create mode 100644 jobs/iota2_aux_angles_test.pbs
 create mode 100644 jobs/iota2_aux_full.pbs

diff --git a/jobs/iota2_aux_angles_test.pbs b/jobs/iota2_aux_angles_test.pbs
new file mode 100644
index 00000000..6eaacfe2
--- /dev/null
+++ b/jobs/iota2_aux_angles_test.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+	-config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_angles.cfg
+	#-nb_parallel_tasks 1 \
+	#-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
diff --git a/jobs/iota2_aux_full.pbs b/jobs/iota2_aux_full.pbs
new file mode 100644
index 00000000..2e438ed2
--- /dev/null
+++ b/jobs/iota2_aux_full.pbs
@@ -0,0 +1,20 @@
+#!/bin/bash
+#PBS -N iota2-test
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=1:00:00
+
+# be sure no modules loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+	-config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_full.cfg
+	#-nb_parallel_tasks 1 \
+	#-only_summary
+
+# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \
--
GitLab

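The next patch settles the public name of the tile-level entry point, so it is worth seeing its call shape once. A minimal sketch only: every path below is a placeholder, the stand-in `process` callable is hypothetical (the real one wraps predict_mmdc_model), and the `patch_size` field of MMDCProcess is only introduced by the test-update patch a little further down.

    from pathlib import Path

    from mmdc_singledate.inference.components.inference_utils import GeoTiffDataset
    from mmdc_singledate.inference.mmdc_tile_inference import (
        MMDCProcess,
        predict_single_date_tile,
    )

    # all filenames below are placeholders
    input_data = GeoTiffDataset(
        s2_filename="/data/T31TCJ/s2_date.tif",
        s1_asc_filename="/data/T31TCJ/s1_asc_date.tif",
        s1_desc_filename="/data/T31TCJ/s1_desc_date.tif",
        srtm_filename="/data/T31TCJ/srtm.tif",
        wc_filename="/data/T31TCJ/wc_clim_1.tif",
    )
    process = MMDCProcess(
        count=4,         # output bands (mu/logvar per sensor)
        nb_lines=1024,   # raster lines read per window
        patch_size=256,
        # stand-in callable with the right arity; the real one runs the model
        process=lambda *tensors: tensors[0][:, :4, ...],
    )
    predict_single_date_tile(
        input_data=input_data,
        export_path=Path("/data/export/latent.tif"),
        sensors=["S2L2A"],
        process=process,
    )
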
From 94df5c1957df27622fe99ffdd7ab10f6217da143 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 24 Feb 2023 13:25:40 +0000
Subject: [PATCH 21/81] rename func

---
 src/mmdc_singledate/inference/mmdc_tile_inference.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py
index b4aec447..2511ae3d 100644
--- a/src/mmdc_singledate/inference/mmdc_tile_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -101,7 +101,7 @@ class MMDCProcess:
     ]
 
 
-def predict_tile(
+def predict_single_date_tile(
     input_data: GeoTiffDataset,
     export_path: Path,
     sensors: list[str],
@@ -263,7 +263,7 @@ def mmdc_tile_inference(
         wc_filename=df_row.worldclim_filename,
     )
     # predict tile
-    predict_tile(
+    predict_single_date_tile(
         input_data=mmdc_input_data,
         export_path=export_path,
         sensors=sensor,
--
GitLab


From 3592a56b200a78bf9ddcfaecc7caf48c7704f39d Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:18:05 +0000
Subject: [PATCH 22/81] update test inference

---
 test/test_mmdc_inference.py | 291 ++++++++++++++++++------------------
 1 file changed, 149 insertions(+), 142 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 0e981f85..53547ac6 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -2,12 +2,13 @@
 # Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
 
 import os
-
+from pathlib import Path
 import pytest
 import torch
 
 from mmdc_singledate.inference.components.inference_components import (  # predict_tile,
     get_mmdc_full_model,
+    predict_mmdc_model,
 )
 from mmdc_singledate.inference.components.inference_utils import (
     GeoTiffDataset,
@@ -15,9 +16,15 @@ from mmdc_singledate.inference.components.inference_utils import (
 )
 from mmdc_singledate.inference.mmdc_tile_inference import (
     MMDCProcess,
+    inference_dataframe,
+    mmdc_tile_inference,
     predict_single_date_tile,
 )
 
+from mmdc_singledate.models.components.model_dataclass import (
+    VAELatentSpace,
+)
+
 
 # dir
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
@@ -25,6 +32,7 @@ dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
 def test_GeoTiffDataset():
     """ """
     datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ"
+
     input_data = GeoTiffDataset(
         s2_filename=os.path.join(
             datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
@@ -40,8 +48,12 @@ def test_GeoTiffDataset():
         srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
         wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
     )
-    assert input_data
-    assert type(input_data.metadata) == dict
+    # print the GeoTiffDataset
+    print(input_data)
+    # Check Data Existence
+    assert input_data.s1_asc_availability == True
+    assert input_data.s1_desc_availability == True
+    assert input_data.s2_availabitity == True
 
 
 def test_generate_chunks():
@@ -58,6 +70,11 @@ def test_mmdc_full_model():
     """
     test instantiate network
     """
+    # set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(device)
+
+    # checkpoints
     checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
     checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
 
@@ -66,20 +83,22 @@ def test_mmdc_full_model():
     mmdc_full_model = get_mmdc_full_model(
         os.path.join(checkpoint_path, checkpoint_filename)
     )
-
-    print(mmdc_full_model)
+    # move to device
+    mmdc_full_model.to(device)
 
     assert mmdc_full_model.training == False
 
-    s2_x = torch.rand(1, 10, 256, 256)
-    s2_m = torch.ones(1, 10, 256, 256)
-    s2_angles_x = torch.rand(1, 6, 256, 256)
-    s1_x = torch.rand(1, 6, 256, 256)
-    s1_vm = torch.ones(1, 6, 256, 256)
-    s1_asc_angles_x = torch.rand(1, 3)
-    s1_desc_angles_x = torch.rand(1, 3)
-    worldclim_x = torch.rand(1, 103, 256, 256)
-    srtm_x = torch.rand(1, 4, 256, 256)
+    s2_x = torch.rand(1, 10, 256, 256).to(device)
+    s2_m = torch.ones(1, 10, 256, 256).to(device)
+    s2_angles_x = torch.rand(1, 6, 256, 256).to(device)
+    s1_x = torch.rand(1, 6, 256, 256).to(device)
+    s1_vm = torch.ones(1, 1, 256, 256).to(device)
+    s1_asc_angles_x = 
torch.rand(1, 3).to(device) + s1_desc_angles_x = torch.rand(1, 3).to(device) + worldclim_x = torch.rand(1, 103, 256, 256).to(device) + srtm_x = torch.rand(1, 4, 256, 256).to(device) + + print(srtm_x) prediction = mmdc_full_model.predict( s2_x, @@ -94,43 +113,121 @@ def test_mmdc_full_model(): ) latent_variable = prediction[0].latent.latent_s1.mu + print(latent_variable.shape) assert type(latent_variable) == torch.Tensor +sensors_test = [ + (["S2L2A", "S1FULL"]), + (["S2L2A", "S1ASC"]), + (["S2L2A", "S1DESC"]), + (["S2L2A"]), + (["S1FULL"]), + (["S1ASC"]), + (["S1DESC"]), + # ([ "S1FULL","S2L2A" ]), +] + + +@pytest.mark.parametrize("sensors", sensors_test) +def test_predict_mmdc_model(sensors): + """ """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # data input + s2_x = torch.rand(1, 10, 256, 256).to(device) + s2_m = torch.ones(1, 10, 256, 256).to(device) + s2_angles_x = torch.rand(1, 6, 256, 256).to(device) + s1_x = torch.rand(1, 6, 256, 256).to(device) + s1_vm = torch.ones(1, 1, 256, 256).to(device) + s1_asc_angles_x = torch.rand(1, 3).to(device) + s1_desc_angles_x = torch.rand(1, 3).to(device) + worldclim_x = torch.rand(1, 103, 256, 256).to(device) + srtm_x = torch.rand(1, 4, 256, 256).to(device) + + # model + # checkpoints + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" + checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + # move to device + mmdc_full_model.to(device) + + pred = predict_mmdc_model( + mmdc_full_model, + sensors, + s2_x, + s2_m, + s2_angles_x, + s1_x, + s1_vm, + s1_asc_angles_x, + s1_desc_angles_x, + worldclim_x, + srtm_x, + ) + print(pred.shape) + + assert type(pred) == torch.Tensor + # assert pred.shape[0] == 4 or 2 + + def dummy_process( s2_refl, + s2_mask, s2_ang, - s1_asc, - s1_desc, + s1_back, + s1_vm, s1_asc_lia, s1_desc_lia, - srtm_patch, wc_patch, + srtm_patch, ): """ Create a dummy function """ prediction = ( s2_refl[:, 0, ...] + * s2_mask[:, 0, ...] * s2_ang[:, 0, ...] - * s1_asc[:, 0, ...] - * s1_desc[:, 0, ...] + * s1_back[:, 0, ...] + * s1_vm[:, 0, ...] * srtm_patch[:, 0, ...] * wc_patch[:, 0, ...] 
) - return prediction + latent_example = VAELatentSpace( + mu=prediction, + logvar=prediction, + ) -@pytest.mark.skip(reason="Have to update to the new data format") -def test_predict_tile(): - """ """ - export_path = ( - "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent.tif" + latent_space_stack = torch.cat( + ( + torch.unsqueeze(latent_example.mu, 1), + torch.unsqueeze(latent_example.logvar, 1), + torch.unsqueeze(latent_example.mu, 1), + torch.unsqueeze(latent_example.logvar, 1), + ), + 1, ) + + return latent_space_stack + + +# TODO add cases with no data +@pytest.mark.parametrize("sensors", sensors_test) +def test_predict_single_date_tile(sensors): + """ """ + export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif" process = MMDCProcess( - count=1, + count=4, nb_lines=1024, + patch_size=256, process=dummy_process, ) @@ -147,130 +244,40 @@ def test_predict_tile(): datapath, "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif", ), - srtm_filename=os.path.join(datapath, "srtm_T31TDJ_roi_0.tif"), - wc_filename=os.path.join(datapath, "wc_clim_1_T31TDJ_roi_0.tif"), + srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"), + wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"), ) predict_single_date_tile( input_data=input_data, export_path=export_path, + sensors=sensors, # ["S2L2A"], process=process, ) - -# def test_read_s2_img_tile(): -# """ -# test read a window of s2 -# """ -# dataset_sample = os.path.join( -# dataset_dir, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif" -# ) - -# chunks = generate_chunks(width=10980, height=10980, nlines=1024) -# roi = [rio.windows.Window(*chunk) for chunk in chunks] - -# s2_data = read_s2_img_tile(s2_filename=dataset_sample, rois=roi) -# # get the firt element -# s2_data_window = next(s2_data) - -# assert s2_data_window.s2_reflectances.shape == torch.Size([10, 1024, 10980]) -# assert s2_data_window.s2_angles.shape == torch.Size([6, 1024, 10980]) -# assert s2_data_window.s2_mask.shape == torch.Size([1, 1024, 10980]) - - -# def test_read_s1_img_tile(): -# """ -# test read a window of s1 -# """ -# dataset_sample = os.path.join( -# dataset_dir, -# "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_41.98988334384891_15.429479982324139_roi_0.tif", -# ) - -# chunks = generate_chunks(width=10980, height=10980, nlines=1024) -# roi = [rio.windows.Window(*chunk) for chunk in chunks] - -# s1_data = read_s1_img_tile(s1_filename=dataset_sample, rois=roi) -# # get the firt element -# s1_data_window = next(s1_data) - -# assert s1_data_window.s1_backscatter.shape == torch.Size([3, 1024, 10980]) -# assert s1_data_window.s1_valmask.shape == torch.Size([3, 1024, 10980]) -# assert s1_data_window.s1_edgemask.shape == torch.Size([3, 1024, 10980]) -# assert s1_data_window.s1_lia_angles.shape == torch.Size([3]) - - -# def test_read_srtm_img_tile(): -# """ -# test read a window of srtm -# """ -# dataset_sample = os.path.join(dataset_dir, "srtm_T31TDJ_roi_0.tif") - -# chunks = generate_chunks(width=10980, height=10980, nlines=1024) -# roi = [rio.windows.Window(*chunk) for chunk in chunks] - -# srtm_data = read_srtm_img_tile(srtm_filename=dataset_sample, rois=roi) -# # get the firt element -# srtm_data_window = next(srtm_data) - -# assert srtm_data_window.shape == torch.Size([4, 1024, 10980]) + assert Path(export_path).exists() == True -# def test_read_worlclim_img_tile(): -# """ -# test 
read a window of srtm -# """ -# dataset_sample = os.path.join(dataset_dir, "wc_clim_1_T31TDJ_roi_0.tif") - -# chunks = generate_chunks(width=10980, height=10980, nlines=1024) -# roi = [rio.windows.Window(*chunk) for chunk in chunks] - -# wc_data = read_worldclim_img_tile(wc_filename=dataset_sample, rois=roi) -# # get the firt element -# wc_data_window = next(wc_data) - -# assert wc_data_window.shape == torch.Size([103, 1024, 10980]) - - -# def test_predict_tile(): -# """ """ - -# input_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TDJ/" -# export_path = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/" - -# mmdc_filenames = GeoTiffDataset( -# s2_filename=os.path.join( -# input_path, "SENTINEL2A_20180302-105023-464_L2A_T31TDJ_C_V2-2_roi_0.tif" -# ), -# s1_filename=os.path.join( -# input_path, -# "S1B_IW_GRDH_1SDV_20180304T173831_20180304T173856_009885_011E2F_CEA4_32.528381567125045_15.330754606990766_roi_0.tif", -# ), -# srtm_filename=os.path.join(input_path, "srtm_T31TDJ_roi_0.tif"), -# wc_filename=os.path.join(input_path, "wc_clim_1_T31TDJ_roi_0.tif"), -# ) - -# # get grid mesh -# rois, meta = get_rois_from_prediction_mesh( -# input_filename=mmdc_filenames.s2_filename, nlines=1024 -# ) - -# export_filename = "latent_variable.tif" -# exported_file = predict_tile( -# input_path=input_path, -# export_path=export_path, -# mmdc_filenames=mmdc_filenames, -# export_filename=export_filename, -# rois=rois, -# meta=meta, -# ) -# # assert existence of exported file -# assert Path(exported_file).exists - - -# # - +@pytest.mark.parametrize("sensors", sensors_test) +def test_mmdc_tile_inference(sensors): + """ + Test the inference code in a tile + """ + # feed parameters + input_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/") + export_path = Path( + "/work/CESBIO/projects/MAESTRIA/test_onetile/export/test_latent_tileinfer_{'_'.join(sensors)}.tif" + ) + model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt" + # apply prediction -# # asc = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_39.6607434587337_15.429479982324139_roi_0.tif" -# # desc = "S1B_IW_GRDH_1SDV_20180303T055142_20180303T055207_009863_011D75_B010_45.08497158404881_165.07686884776342_roi_4.tif -# " + mmdc_tile_inference( + input_path=input_path, + export_path=export_path, + model_checkpoint_path=model_checkpoint_path, + sensor=sensors, + tile_list="T31TCJ", + days_gap=15, + latent_space_size=4, + nb_lines=1024, + ) -- GitLab From 58ae3b86af6346f6fc3b72caff6d91e362f4b085 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:18:37 +0000 Subject: [PATCH 23/81] update tiler config --- .../externalfeatures_with_userfeatures_50x50.cfg | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg index 08f1b1b9..cdb197f1 100644 --- a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg +++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg @@ -11,9 +11,9 @@ chain : ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp' data_field : 'groscode' s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink' + user_feat_path: 
'/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50' # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output' - # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50' - # s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg' + s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg' srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' worldclim_path:'/datalake/static_aux/worldclim-2.0' @@ -28,7 +28,7 @@ chain : userFeat: { arbo:"/*" - patterns:"mnt2,slope" + patterns:"slope" } python_data_managing : { @@ -39,11 +39,6 @@ python_data_managing : data_mode_access: "both" } -# builders: -# { -# builders_class_name : ["i2_features_to_grid"] -# } - external_features : { module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py" -- GitLab From d220f2e66453a554a65cf0b6dbf7358bbd28f3dc Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 10:17:11 +0000 Subject: [PATCH 24/81] merge --- create-conda-env.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/create-conda-env.sh b/create-conda-env.sh index 36d9e095..61f3fc75 100755 --- a/create-conda-env.sh +++ b/create-conda-env.sh @@ -51,6 +51,8 @@ mamba install --yes "pytorch>=2.0.0.dev202302=py3.10_cuda11.7_cudnn8.5.0_0" "tor conda deactivate conda activate $target +# proxy +source ~/set_proxy_iota2.sh # Requirements pip install -r requirements-mmdc-sgld-test.txt -- GitLab From e1c27a0aca6793f3b05f9d71ac823b61509a5467 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:19:32 +0000 Subject: [PATCH 25/81] update iota conda env --- create-iota2-env.sh | 104 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 12 deletions(-) diff --git a/create-iota2-env.sh b/create-iota2-env.sh index 0af4a8c1..06e9ba1f 100644 --- a/create-iota2-env.sh +++ b/create-iota2-env.sh @@ -58,8 +58,6 @@ git pull origin issue#600/tile_exogenous_features cd ../../ pwd -#which Iota2.py -which Iota2.py # create a backup cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/ mv iota2 iota2_backup @@ -69,19 +67,14 @@ ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2 cd ~/src/MMDC/mmdc-singledate - -#which Iota2.py -which Iota2.py +conda install -c conda-forge pydantic # install missing dependancies pip install -r requirements-mmdc-iota2.txt -# test install -Iota2.py -h - # Install sensorsio rm -rf iota2_thirdparties/sensorsio -git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio +git clone -b handle_local_srtm https://framagit.org/iota2-project/sensorsio.git iota2_thirdparties/sensorsio pip install iota2_thirdparties/sensorsio # Install torchutils @@ -92,8 +85,95 @@ pip install iota2_thirdparties/torchutils # Install the current project in edit mode pip install -e .[testing] -# Activate pre-commit hooks -# pre-commit install - # End conda deactivate + +# #!/usr/bin/env bash + +# export python_version="3.9" +# export name="mmdc-iota2" +# if ! [ -z "$1" ] +# then +# export name=$1 +# fi + +# source ~/set_proxy_iota2.sh +# if [ -z "$https_proxy" ] +# then +# echo "Please set https_proxy environment variable before running this script" +# exit 1 +# fi + +# export target=/work/scratch/$USER/virtualenv/$name + +# if ! 
[ -z "$2" ] +# then +# export target="$2/$name" +# fi + +# echo "Installing $name in $target ..." + +# if [ -d "$target" ]; then +# echo "Cleaning previous conda env" +# rm -rf $target +# fi + +# # Create blank virtualenv +# module purge +# module load conda +# module load gcc +# conda activate +# conda create --yes --prefix $target python==${python_version} pip + +# # Enter virtualenv +# conda deactivate +# conda activate $target + +# # install mamba +# conda install mamba + +# which python +# python --version + +# # Install iota2 +# #mamba install iota2_develop=257da617 -c iota2 +# mamba install iota2 -c iota2 + +# clone the lastest version of iota2 +# rm -rf iota2_thirdparties/iota2 +# git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2 +# cd iota2_thirdparties/iota2 +# git checkout issue#600/tile_exogenous_features +# git pull origin issue#600/tile_exogenous_features +# cd ../../ +# pwd + +# # create a backup +# cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/ +# mv iota2 iota2_backup + +# # create a symbolic link +# ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2 + +# cd ~/src/MMDC/mmdc-singledate + +# conda install -c conda-forge pydantic + +# # install missing dependancies +# pip install -r requirements-mmdc-iota2.txt #--upgrade --no-cache-dir + +# # Install sensorsio +# rm -rf iota2_thirdparties/sensorsio +# git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio +# pip install iota2_thirdparties/sensorsio + +# # Install torchutils +# rm -rf iota2_thirdparties/torchutils +# git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils +# pip install iota2_thirdparties/torchutils + +# # Install the current project in edit mode +# pip install -e .[testing] + +# # End +# conda deactivate -- GitLab From 29ec037ee7f861d2b500b63ff434cb8efe3ee46b Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:20:34 +0000 Subject: [PATCH 26/81] linintg --- test/test_mmdc_inference.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index 53547ac6..b7b869b1 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -3,6 +3,7 @@ import os from pathlib import Path + import pytest import torch @@ -20,10 +21,7 @@ from mmdc_singledate.inference.mmdc_tile_inference import ( mmdc_tile_inference, predict_single_date_tile, ) - -from mmdc_singledate.models.components.model_dataclass import ( - VAELatentSpace, -) +from mmdc_singledate.models.components.model_dataclass import VAELatentSpace # dir dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" -- GitLab From 8c72b2f8e613d6a806e0bcf3c9a27e1c5d413a8e Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:25:17 +0000 Subject: [PATCH 27/81] update utils --- .../inference/components/inference_utils.py | 135 +++++++++++++----- 1 file changed, 100 insertions(+), 35 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py index 3ec181c5..fe66b640 100644 --- a/src/mmdc_singledate/inference/components/inference_utils.py +++ b/src/mmdc_singledate/inference/components/inference_utils.py @@ -39,59 +39,67 @@ class GeoTiffDataset: """ s2_filename: str + 
s2_availabitity: bool = field(init=False) s1_asc_filename: str + s1_asc_availability: bool = field(init=False) s1_desc_filename: str + s1_desc_availability: bool = field(init=False) srtm_filename: str wc_filename: str metadata: dict = field(init=False) def __post_init__(self) -> None: # check if the data exists - # if not Path(self.s2_filename).exists(): - # raise Exception(f"{self.s2_filename} do not exists!") - # if not Path(self.s1_asc_filename).exists(): - # raise Exception(f"{self.s1_asc_filename} do not exists!") - # if not Path(self.s1_desc_filename).exists(): - # raise Exception(f"{self.s1_desc_filename} do not exists!") if not Path(self.srtm_filename).exists(): raise Exception(f"{self.srtm_filename} do not exists!") if not Path(self.wc_filename).exists(): raise Exception(f"{self.wc_filename} do not exists!") - + if Path(self.s1_asc_filename).exists(): + self.s1_asc_availability = True + else: + self.s1_asc_availability = False + if Path(self.s1_desc_filename).exists(): + self.s1_desc_availability = True + else: + self.s1_desc_availability = False + if Path(self.s2_filename).exists(): + self.s2_availabitity = True + else: + self.s2_availabitity = False # check if the filenames # provide are in the same # footprint and spatial resolution - # rio.open(self.s2_filename) as s2, rio.open( + # rio.open( # self.s1_asc_filename # ) as s1_asc, rio.open(self.s1_desc_filename) as s1_desc, - + # rio.open(self.s2_filename) as s2, with rio.open(self.srtm_filename) as srtm, rio.open(self.wc_filename) as wc: # recover the metadata self.metadata = srtm.meta - # list to check + # create a python set to check crss = { + srtm.meta["crs"], + wc.meta["crs"], # s2.meta["crs"], # s1_asc.meta["crs"], # s1_desc.meta["crs"], - srtm.meta["crs"], - wc.meta["crs"], } heights = { + srtm.meta["height"], + wc.meta["height"], # s2.meta["height"], # s1_asc.meta["height"], # s1_desc.meta["height"], - srtm.meta["height"], - wc.meta["height"], } widths = { + srtm.meta["width"], + wc.meta["width"], # s2.meta["width"], # s1_asc.meta["width"], # s1_desc.meta["width"], - srtm.meta["width"], - wc.meta["width"], } # check crs @@ -126,7 +134,6 @@ class S1Components: s1_backscatter: torch.Tensor s1_valmask: torch.Tensor - s1_edgemask: torch.Tensor s1_lia_angles: torch.Tensor @@ -177,6 +184,7 @@ def generate_chunks( def read_img_tile( filename: str, rois: list[rio_window], + availabitity: bool, sensor_func: Callable[[torch.Tensor, Any], Any], ) -> Generator[Any, None, None]: """ @@ -186,23 +194,24 @@ def read_img_tile( """ # check if the filename exists # if exist proceed to yield the data - if Path(filename).exists: + if availabitity: with rio.open(filename) as raster: for roi in rois: # read the patch as tensor tensor = torch.tensor(raster.read(window=roi), requires_grad=False) # compute some transformation to the tensor and return dataclass - sensor_data = sensor_func(tensor, filename) + sensor_data = sensor_func(tensor, availabitity, filename) yield sensor_data # if not exists create a zeros tensor and yield else: for roi in rois: - null_data = torch.zeros(roi.width, roi.height) + null_data = torch.ones(roi.width, roi.height) yield null_data def read_s2_img_tile( s2_tensor: torch.Tensor, + availabitity: bool, *args: Any, **kwargs: Any, ) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: @@ -211,7 +220,7 @@ def read_s2_img_tile( contruct the masks and yield the patch of data """ - if s2_tensor.shape[0] == 20: + if availabitity: # extract masks cloud_mask = s2_tensor[11, ...].to(torch.uint8) cloud_mask[cloud_mask 
> 0] = 1 @@ -223,17 +232,39 @@ def read_s2_img_tile( ).unsqueeze(0) angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...]) image_s2 = s2_tensor[:10, ...] - else: - image_s2 = torch.zeros(10, s2_tensor.shape[0], s2_tensor.shape[1]) - angles_s2 = torch.zeros(6, s2_tensor.shape[0], s2_tensor.shape[1]) - mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1]) + image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1]) + angles_s2 = torch.ones(4, s2_tensor.shape[0], s2_tensor.shape[1]) + mask = torch.zeros(2, s2_tensor.shape[0], s2_tensor.shape[1]) return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) +# def read_s2_img_iota2( +# s2_tensor: torch.Tensor, +# *args: Any, +# **kwargs: Any, +# ) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: +# """ +# read a patch of sentinel 2 data +# contruct the masks and yield the patch +# of data +# """ +# # copy the masks +# # cloud_mask = +# # sat_mask = +# # edge_mask = +# mask = torch.concat((cloud_mask, sat_mask, edge_mask), axis=0) +# angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...]) +# image_s2 = s2_tensor[:10, ...] + +# return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) + + +# TODO get out the s1_lia computation from this function def read_s1_img_tile( s1_tensor: torch.Tensor, + availability: bool, s1_filename: str, *args: Any, **kwargs: Any, @@ -241,8 +272,11 @@ def read_s1_img_tile( """ read a patch of s1 construct the datasets associated and yield the result + + inputs : s1_tensor read from read_img_tile """ - if s1_tensor.shape[0] == 2: + # if s1_tensor.shape[0] == 2: + if availability: # get the vv,vh, vv/vh bands img_s1 = torch.cat( (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0)) @@ -250,24 +284,47 @@ def read_s1_img_tile( # get the local incidencce angle (lia) s1_lia = get_s1_acquisition_angles(s1_filename) # compute validity mask - s1_valmask = torch.ones(img_s1.shape) + s1_valmask = torch.ones(1, s1_tensor.shape[1], s1_tensor.shape[2]) # compute edge mask - s1_edgemask = img_s1.to(int) s1_backscatter = apply_log_to_s1(img_s1) else: - s1_backscatter = torch.zeros(3, s1_tensor.shape[0], s1_tensor.shape[1]) + s1_backscatter = torch.ones(3, s1_tensor.shape[0], s1_tensor.shape[1]) s1_valmask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1]) - s1_edgemask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1]) - s1_lia = torch.zeros(3) + s1_lia = torch.ones(3) return S1Components( s1_backscatter=s1_backscatter, s1_valmask=s1_valmask, - s1_edgemask=s1_edgemask, s1_lia_angles=s1_lia, ) +# def read_s1_img_iota2( +# s1_tensor: torch.Tensor, +# *args: Any, +# **kwargs: Any, +# ) -> S1Components: # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +# """ +# read a patch of s1 construct the datasets associated +# and yield the result +# """ +# # get the vv,vh, vv/vh bands +# img_s1 = torch.cat( +# (s1_tensor, (s1_tensor[1, ...] 
/ s1_tensor[0, ...]).unsqueeze(0)) +# ) +# s1_lia = get_s1_acquisition_angles(s1_filename) +# # compute validity mask +# s1_valmask = torch.ones(img_s1.shape) +# # compute edge mask +# s1_backscatter = apply_log_to_s1(img_s1) + +# return S1Components( +# s1_backscatter=s1_backscatter, +# s1_valmask=s1_valmask, +# s1_lia_angles=s1_lia, +# ) + + def read_srtm_img_tile( srtm_tensor: torch.Tensor, *args: Any, @@ -292,7 +349,9 @@ def read_worldclim_img_tile( return WorldClimComponents(wc_tensor) -def expand_worldclim_filenames(wc_filename: str) -> list[str]: +def expand_worldclim_filenames( + wc_filename: str, +) -> list[str]: """ given the firts worldclim filename expand the filenames to the others files @@ -308,7 +367,11 @@ def expand_worldclim_filenames(wc_filename: str) -> list[str]: def concat_worldclim_components( - wc_filename: str, rois: list[rio_window] + wc_filename: str, + rois: list[rio_window], + availabitity, + *args: Any, + **kwargs: Any, ) -> Generator[Any, None, None]: """ Compose function for apply the read_img_tile general function @@ -318,7 +381,9 @@ def concat_worldclim_components( wc_filenames = expand_worldclim_filenames(wc_filename) # read the tensors wc_tensor = [ - next(read_img_tile(wc_filename, rois, read_worldclim_img_tile)).worldclim + next( + read_img_tile(wc_filename, rois, availabitity, read_worldclim_img_tile) + ).worldclim for idx, wc_filename in enumerate(wc_filenames) ] yield WorldClimComponents(torch.cat(wc_tensor)) -- GitLab From 4d68a05dbafcf55a1d3f4be0311269bf05e6a984 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:28:14 +0000 Subject: [PATCH 28/81] model related stuff --- .../components/inference_components.py | 196 ++++++++++++------ 1 file changed, 131 insertions(+), 65 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index fe3212b2..8841e203 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -6,17 +6,19 @@ Functions components for get the latent spaces for a given tile """ -import logging # imports +import logging from pathlib import Path from typing import Literal import torch from torch import nn -from torchutils import patches -from mmdc_singledate.models.components.model_dataclass import VAELatentSpace +from mmdc_singledate.models.components.model_dataclass import ( + S1S2VAELatentSpace, + VAELatentSpace, +) from mmdc_singledate.models.mmdc_full_module import MMDCFullModule # define sensors variable for typing @@ -90,8 +92,11 @@ def get_mmdc_full_model( w_code_s1s2=1, ) + # mmdc_full_lightning = MMDCFullLitModule(mmdc_full_model) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(device) + print(f"device detected : {device}") + print(f"nb_gpu's detected : {torch.cuda.device_count()}") # load state_dict lightning_checkpoint = torch.load(checkpoint, map_location=device) @@ -104,11 +109,14 @@ def get_mmdc_full_model( # load the state dict mmdc_full_model.load_state_dict(checkpoint) + # mmdc_full_model.load_state_dict(lightning_checkpoint) + # mmdc_full_lightning.load_state_dict(lightning_checkpoint) # disble randomness, dropout, etc... 
mmdc_full_model.eval() + # mmdc_full_lightning.eval() - return mmdc_full_model + return mmdc_full_model # mmdc_full_lightning @torch.no_grad() @@ -116,13 +124,14 @@ def predict_mmdc_model( model: nn.Module, sensors: SENSORS, s2_ref, + s2_mask, s2_angles, - s1_asc, - s1_desc, + s1_back, + s1_vm, s1_asc_angles, s1_desc_angles, - srtm, worldclim, + srtm, ) -> VAELatentSpace: """ This function apply the predict method to a @@ -130,7 +139,7 @@ def predict_mmdc_model( by sensor :args: model : mmdc_full model intance - :args: sensor : Sensor acquisition + :args: sensors : Sensor acquisition :args: s2_ref : :args: s2_angles : :args: s1_back : @@ -141,69 +150,126 @@ def predict_mmdc_model( :return: VAELatentSpace """ - s1_back = torch.cat( - ( - s1_asc, - s1_desc, - (s1_asc / s1_desc), - ), - 1, - ) + # TODO check if the sensors are in the current supported list + # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"] + prediction = model.predict( s2_ref, + s2_mask, s2_angles, s1_back, + s1_vm, s1_asc_angles, s1_desc_angles, worldclim, srtm, ) - if "S2L2A" in sensors: - # get latent - latent_space = prediction[0].latent.latent_s2 - - if "S1FULL" in sensors: - latent_space = prediction[0].latent.latent_s1 - - if "S1ASC" in sensors: - latent_space = prediction[0].latent.latent_s1 - - if "S1DESC" in sensors: - latent_space = prediction[0].latent.latent_s1 - - return latent_space - - -def patchify_batch( - tensor: torch.Tensor, - patch_size: int, -) -> torch.tensor: - """ - reshape the geotiff data readed to the - shape expected from the network - - :param: tensor - :param: patch_size - """ - patch = patches.patchify(tensor, patch_size) - flatten_patch = patches.flatten2d(patch) - - return flatten_patch - - -def unpatchify_batch( - flatten_patch: torch.tensor, patch_shape: torch.Size, tensor_shape: torch.Size -) -> torch.tensor: - """ - Inverse operation of patchify batch - :param: # - :param: # - :param: # - :return: # - """ - unflatten_patch = patches.unflatten2d(flatten_patch, patch_shape[0], patch_shape[1]) - unpatch = patches.unpatchify(unflatten_patch) - unpatch_crop = unpatch[:, : tensor_shape[1], : tensor_shape[2]] - - return unpatch_crop + print("prediction.shape :", prediction[0].latent.latent_s2.mu.shape) + print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape) + print("prediction.shape :", prediction[0].latent.latent_s1.mu.shape) + print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape) + + # init the output latent spaces as + # empty dataclass + latent_space = S1S2VAELatentSpace(latent_s1=None, latent_s2=None) + + # fullfit with a matchcase + # match + match sensors: + # Cases + case ["S2L2A", "S1FULL"]: + logger.info("S1 full captured") + logger.info("S2 captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mu, + latent_space.latent_s2.logvar, + latent_space.latent_s1.mu, + latent_space.latent_s1.logvar, + ), + 0, + ) + + case ["S2L2A", "S1ASC"]: + logger.info("S1 asc captured") + logger.info("S2 captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mu, + latent_space.latent_s2.logvar, + latent_space.latent_s1.mu, + latent_space.latent_s1.logvar, + ), + 0, + ) + + case ["S2L2A", "S1DESC"]: + logger.info("S1 desc captured") + logger.info("S2 captured") + + 
latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mu, + latent_space.latent_s2.logvar, + latent_space.latent_s1.mu, + latent_space.latent_s1.logvar, + ), + 0, + ) + + case ["S2L2A"]: + logger.info("Only S2 captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mu, + latent_space.latent_s2.logvar, + ), + 0, + ) + + case ["S1FULL"]: + logger.info("Only S1 full captured") + + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), + 0, + ) + + case ["S1ASC"]: + logger.info("Only S1ASC captured") + + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), + 0, + ) + + case ["S1DESC"]: + logger.info("Only S1DESC captured") + + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), + 0, + ) + + return latent_space_stack # latent_space -- GitLab From 4dc55df9d189065899a23e4d5dfd6737e8a22858 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:32:02 +0000 Subject: [PATCH 29/81] update aux datasets --- .../inference/mmdc_tile_inference.py | 263 +++++++++++------- 1 file changed, 160 insertions(+), 103 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index 2511ae3d..17939afc 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -7,21 +7,20 @@ Infereces API with Rasterio # imports import logging -import os from collections.abc import Callable from dataclasses import dataclass from pathlib import Path +import numpy as np import rasterio as rio import torch +from torchutils import patches from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df -from .components.inference_components import ( +from .components.inference_components import ( # patchify_batch,; unpatchify_batch, get_mmdc_full_model, - patchify_batch, predict_mmdc_model, - unpatchify_batch, ) from .components.inference_utils import ( GeoTiffDataset, @@ -43,39 +42,34 @@ logger = logging.getLogger(__name__) def inference_dataframe( samples_dir: str, - input_tile_list: list[str], + input_tile_list: list[str] | None, days_gap: int, - nb_tiles: int, - nb_rois: int, - nb_files: int = None, ): """ Read the input directory and create a dataframe with the occurrences for be infered """ - # read the metadata and contruct time serie - ( - tile_df, - asc_orbit_without_acquisitions, - desc_orbit_without_acquisitions, - ) = prepare_data_df( - samples_dir=samples_dir, - input_tile_list=input_tile_list, # or pass, - days_gap=days_gap, - nb_files=nb_files, - nb_tiles=nb_tiles, - nb_rois=1, + + # Create the Tile time serie + + tile_time_serie_df = prepare_data_df( + samples_dir=samples_dir, input_tile_list=input_tile_list, days_gap=days_gap + ) + + tile_time_serie_df["patch_s2_availability"] = tile_time_serie_df["patch_s2"].apply( + lambda x: Path(x).exists() ) - # manage the abscense of S1ASC or S1DESC - if asc_orbit_without_acquisitions: - tile_df["patchasc_s1"] = None + 
tile_time_serie_df["patchasc_s1_availability"] = tile_time_serie_df[ + "patchasc_s1" + ].apply(lambda x: Path(x).exists()) - if desc_orbit_without_acquisitions: - tile_df["patchdesc_s1"] = None + tile_time_serie_df["patchdesc_s1_availability"] = tile_time_serie_df[ + "patchasc_s1" + ].apply(lambda x: Path(x).exists()) - return tile_df + return tile_time_serie_df # functions and classes @@ -87,6 +81,7 @@ class MMDCProcess: count: int nb_lines: int + patch_size: int process: Callable[ [ torch.tensor, @@ -110,39 +105,59 @@ def predict_single_date_tile( """ Predict a tile of data """ - # retrieve the metadata - meta = input_data.metadata.copy() - # get nb outputs - meta.update({"count": process.count}) - # calculate rois - chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines) - # get the windows from the chunks - rois = [rio.windows.Window(*chunk) for chunk in chunks] - logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ") - - # init the dataset - logger.info("Reading S2 data") - s2_data = read_img_tile(input_data.s2_filename, rois, read_s2_img_tile) - logger.info("Reading S1 ASC data") - s1_asc_data = read_img_tile(input_data.s1_asc_filename, rois, read_s1_img_tile) - logger.info("Reading S1 DESC data") - s1_desc_data = read_img_tile(input_data.s1_desc_filename, rois, read_s1_img_tile) - logger.info("Reading SRTM data") - srtm_data = read_img_tile(input_data.srtm_filename, rois, read_srtm_img_tile) - logger.info("Reading WorldClim data") - worldclim_data = concat_worldclim_components(input_data.wc_filename, rois) - - logger.info("Export Init") - - # separate the latent spaces by sensor - sensors = [s for s in sensors if s in ["S2L2A", "S1FULL", "S1ASC", "S1DESC"]] - # built export_filename - export_names = ["mmdc_latent_" + s.casefold() + ".tif" for s in sensors] - for idx, export_name in enumerate(export_names): - # export latent spaces - with rio.open( - os.path.join(export_path, export_name), "w", **meta - ) as prediction: + # check that at least one dataset exists + if ( + Path(input_data.s2_filename).exists + or Path(input_data.s1_asc_filename) + or Path(input_data.s1_desc_filename) + ): + # retrieve the metadata + meta = input_data.metadata.copy() + # get nb outputs + meta.update({"count": process.count}) + # calculate rois + chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines) + # get the windows from the chunks + rois = [rio.windows.Window(*chunk) for chunk in chunks] + logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ") + + # init the dataset + logger.info("Reading S2 data") + s2_data = read_img_tile( + filename=input_data.s2_filename, + rois=rois, + availabitity=input_data.s2_availabitity, + sensor_func=read_s2_img_tile, + ) + logger.info("Reading S1 ASC data") + s1_asc_data = read_img_tile( + filename=input_data.s1_asc_filename, + rois=rois, + availabitity=input_data.s1_asc_availability, + sensor_func=read_s1_img_tile, + ) + logger.info("Reading S1 DESC data") + s1_desc_data = read_img_tile( + filename=input_data.s1_desc_filename, + rois=rois, + availabitity=input_data.s1_desc_availability, + sensor_func=read_s1_img_tile, + ) + logger.info("Reading SRTM data") + srtm_data = read_img_tile( + filename=input_data.srtm_filename, + rois=rois, + availabitity=True, + sensor_func=read_srtm_img_tile, + ) + logger.info("Reading WorldClim data") + worldclim_data = concat_worldclim_components( + wc_filename=input_data.wc_filename, rois=rois, availabitity=True + ) + + logger.info("Export Init") + + with rio.open(export_path, 
"w", **meta) as prediction: # iterate over the windows for roi, s2, s1_asc, s1_desc, srtm, wc in zip( rois, @@ -153,53 +168,100 @@ def predict_single_date_tile( worldclim_data, ): print(" original size : ", s2.s2_reflectances.shape) + print(" original size : ", s2.s2_angles.shape) + print(" original size : ", s2.s2_mask.shape) + print(" original size : ", s1_asc.s1_backscatter.shape) + print(" original size : ", s1_asc.s1_valmask.shape) + print(" original size : ", s1_asc.s1_lia_angles.shape) + print(" original size : ", s1_desc.s1_backscatter.shape) + print(" original size : ", s1_desc.s1_valmask.shape) + print(" original size : ", s1_desc.s1_lia_angles.shape) + print(" original size : ", srtm.srtm.shape) + print(" original size : ", wc.worldclim.shape) + + # Concat S1 Data + s1_backscatter = torch.cat( + (s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0 + ) + # The validity mask of S1 is unique + # Multiply the masks + s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask) + + # keep the sizes for recover the original size + s2_mask_patch_size = patches.patchify( + s2.s2_mask, process.patch_size + ).shape + s2_s2_mask_shape = s2.s2_mask.shape # [1, 1024, 10980] + + print("s2_mask_patch_size :", s2_mask_patch_size) + print("s2_s2_mask_shape:", s2_s2_mask_shape) # reshape the data - s2_refl_patch = patchify_batch(s2.s2_reflectances, 256) - s2_ang_patch = patchify_batch(s2.s2_angles, 256) - s1_asc_patch = patchify_batch(s1_asc.s1_backscatter, 256) - s1_asc_lia_patch = s1_asc.s1_lia_angles - s1_desc_patch = patchify_batch(s1_desc.s1_backscatter, 256) - s1_desc_lia_patch = s1_desc.s1_lia_angles - srtm_patch = patchify_batch(srtm.srtm, 256) - wc_patch = patchify_batch(wc.worldclim, 256) - print( - s2_refl_patch.shape, - s2_ang_patch.shape, - s1_asc_patch.shape, - s1_desc_patch.shape, - s1_asc_lia_patch.shape, - s1_desc_lia_patch.shape, - srtm_patch.shape, - wc_patch.shape, + s2_refl_patch = patches.flatten2d( + patches.patchify(s2.s2_reflectances, process.patch_size) + ) + s2_ang_patch = patches.flatten2d( + patches.patchify(s2.s2_angles, process.patch_size) + ) + s2_mask_patch = patches.flatten2d( + patches.patchify(s2.s2_mask, process.patch_size) + ) + s1_patch = patches.flatten2d( + patches.patchify(s1_backscatter, process.patch_size) ) + s1_valmask_patch = patches.flatten2d( + patches.patchify(s1_valmask, process.patch_size) + ) + srtm_patch = patches.flatten2d( + patches.patchify(srtm.srtm, process.patch_size) + ) + wc_patch = patches.flatten2d( + patches.patchify(wc.worldclim, process.patch_size) + ) + # patchify not necessary + + print("s2_patches", s2_refl_patch.shape) + print("s2_angles_patches", s2_ang_patch.shape) + print("s2_mask_patch", s2_mask_patch.shape) + print("s1_patch", s1_patch.shape) + print("s1_valmask_patch", s1_valmask_patch.shape) + print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape) + print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape) + print("srtm patches", srtm_patch.shape) + print("wc patches", wc_patch.shape) # apply predict function + # should return a s1s2vaelatentspace + # object pred_vaelatentspace = process.process( s2_refl_patch, + s2_mask_patch, s2_ang_patch, - s1_asc_patch, - s1_desc_patch, - s1_asc_lia_patch, - s1_desc_lia_patch, - srtm_patch, + s1_patch, + s1_valmask_patch, + s1_asc.s1_lia_angles, + s1_desc.s1_lia_angles, wc_patch, + srtm_patch, ) - pred_tensor = torch.cat( - ( - pred_vaelatentspace.mu, - pred_vaelatentspace.logvar, - ), - 1, - ) - print(pred_tensor.shape) + print(type(pred_vaelatentspace)) + print("latent space sizes 
: ", pred_vaelatentspace.shape) + + # unpatchify + + pred_vaelatentspace_unpatchify = patches.unpatchify( + patches.unflatten2d( + pred_vaelatentspace, + s2_mask_patch_size[0], + s2_mask_patch_size[1], + ) + )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] + + print("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + prediction.write( - unpatchify_batch( - flatten_patch=pred_tensor, - patch_shape=s2_refl_patch.shape, - tensor_shape=s2.s2_reflectances.shape, - ), + np.array(pred_vaelatentspace_unpatchify[0, ...]), window=roi, - indexes=1, + indexes=4, ) logger.info(("Export tile", f"filename :{export_path}")) @@ -211,9 +273,7 @@ def mmdc_tile_inference( sensor: list[str], tile_list: list[str], days_gap: int, - nb_tiles: int, - nb_files: int = None, - latent_space_size: int = 2, + latent_space_size: int = 4, nb_lines: int = 1024, ) -> None: """ @@ -224,8 +284,6 @@ def mmdc_tile_inference( :args: sensor: list of sensors data input :args: tile_list: list of tiles, :args: days_gap : days between s1 and s2 acquisitions, - :args: nb_tiles : number of tiles to process, - :args: nb_files : number files, :args: latent_space_size : latent space output par sensor, :args: nb_lines : number of lines to read every at time 1024, @@ -237,9 +295,6 @@ def mmdc_tile_inference( samples_dir=input_path, input_tile_list=tile_list, days_gap=days_gap, - nb_tiles=nb_tiles, - nb_rois=1, - nb_files=nb_files, ) # instance the model and get the pred func @@ -258,7 +313,9 @@ def mmdc_tile_inference( mmdc_input_data = GeoTiffDataset( s2_filename=df_row.patch_s2, s1_asc_filename=df_row.patchasc_s1, + s1_asc_availibility=df_row.s1_asc_availibility, s1_desc_filename=df_row.patchdesc_s1, + s1_desc_availibitity=df_row.s1_desc_availibitity, srtm_filename=df_row.srtm_filename, wc_filename=df_row.worldclim_filename, ) -- GitLab From ec972af09fd3cab23b6d747960231de11f3a3993 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:33:35 +0000 Subject: [PATCH 30/81] add tiler test config --- configs/iota2/iota2_grid_full.cfg | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 configs/iota2/iota2_grid_full.cfg diff --git a/configs/iota2/iota2_grid_full.cfg b/configs/iota2/iota2_grid_full.cfg new file mode 100644 index 00000000..0a529fbb --- /dev/null +++ b/configs/iota2/iota2_grid_full.cfg @@ -0,0 +1,21 @@ +chain : +{ + output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' + spatial_resolution : 50 + first_step : 'tiler' + last_step : 'tiler' + + proj : 'EPSG:2154' + rasters_grid_path : '/work/OT/theia/oso/arthur/TMP/raster_grid' + tile_field : 'Name' + list_tile : 'T31TCJ' + srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' + worldclim_path:'/datalake/static_aux/worldclim-2.0' + s2_path:'/work/OT/theia/oso/arthur/TMP/test_s2_angles' + s1_dir:'/work/OT/theia/oso/arthur/s1_emma/all_tiles' +} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} -- GitLab From 61172ee549b58edcf9b9b247f5a98f2aa173c32e Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 9 Mar 2023 17:40:30 +0000 Subject: [PATCH 31/81] add Iota2 API --- .../inference/mmdc_iota2_inference.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/mmdc_singledate/inference/mmdc_iota2_inference.py diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py new file mode 100644 index 00000000..4a40ff74 --- /dev/null +++ 
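The window-to-batch bookkeeping used by the tile inference above, and needed again by the Iota-2 hook added next, always follows the same torchutils round trip. A minimal sketch of that pattern, assuming a [C, H, W] chunk and the torchutils.patches behaviour exactly as it is used elsewhere in this series (the shapes in the comments are inferred from that usage):

    import torch
    from torchutils import patches

    chunk = torch.rand(10, 1024, 1024)       # [C, H, W] window read from the raster
    patched = patches.patchify(chunk, 256)   # grid of 256x256 patches
    rows, cols = patched.shape[:2]           # patch-grid size, kept for the way back
    flat = patches.flatten2d(patched)        # [rows * cols, C, 256, 256] batch for the model
    # ... run the network on `flat` here ...
    restored = patches.unpatchify(patches.unflatten2d(flat, rows, cols))
    restored = restored[:, : chunk.shape[1], : chunk.shape[2]]  # crop padding back off
    assert restored.shape == chunk.shape
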
From 61172ee549b58086f1ab0817d4e35f28ec4c0b05 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 9 Mar 2023 17:40:30 +0000
Subject: [PATCH 31/81] add Iota2 API

---
 .../inference/mmdc_iota2_inference.py         | 118 ++++++++++++++++++
 1 file changed, 118 insertions(+)
 create mode 100644 src/mmdc_singledate/inference/mmdc_iota2_inference.py

diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
new file mode 100644
index 00000000..4a40ff74
--- /dev/null
+++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales
+
+"""
+Inference API with Iota-2
+Inspired by:
+https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
+"""
+
+import torch
+from torchutils import patches
+
+# from mmdc_singledate.datamodules.components.datamodule_utils import (
+#     apply_log_to_s1,
+#     join_even_odd_s2_angles,
+#     srtm_height_aspect,
+# )
+
+
+def apply_mmdc_full_mode(
+    self,
+    checkpoint_path: str,
+    checkpoint_epoch: int = 100,
+    patch_size: int = 256,
+):
+    """
+    Apply MMDC with Iota-2.py
+    """
+    # How to manage the S1 acquisition dates to construct the validity
+    # mask over time?
+
+    # DONE Get the data in the same order as
+    # sentinel2.Sentinel2.GROUP_10M +
+    # sensorsio.sentinel2.GROUP_20M
+    # This data corresponds to (Chunk_Size, Img_Size, nb_dates)
+    list_bands_s2 = [
+        self.get_interpolated_Sentinel2_B2(),
+        self.get_interpolated_Sentinel2_B3(),
+        self.get_interpolated_Sentinel2_B4(),
+        self.get_interpolated_Sentinel2_B8(),
+        self.get_interpolated_Sentinel2_B5(),
+        self.get_interpolated_Sentinel2_B6(),
+        self.get_interpolated_Sentinel2_B7(),
+        self.get_interpolated_Sentinel2_B8A(),
+        self.get_interpolated_Sentinel2_B11(),
+        self.get_interpolated_Sentinel2_B12(),
+    ]
+
+    # TODO Mask construction for S2
+    list_s2_mask = self.get_Sentinel2_binary_masks()
+
+    # TODO Manage S1 ASC and S1 DESC?
+    list_bands_s1 = [
+        self.get_interpolated_Sentinel1_ASC_vh(),
+        self.get_interpolated_Sentinel1_ASC_vv(),
+        self.get_interpolated_Sentinel1_ASC_vh()
+        / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
+        self.get_interpolated_Sentinel1_DES_vh(),
+        self.get_interpolated_Sentinel1_DES_vv(),
+        self.get_interpolated_Sentinel1_DES_vh()
+        / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
+    ]
+
+    # TODO Read SRTM data
+    # TODO Read WorldClim data
+
+    with torch.no_grad():
+        # Permute dimensions to fit patchify
+        # Shape before permutation is C,H,W,D. D being the dates
+        # Shape after permutation is D,C,H,W
+        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+        bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
+        bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+        bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+        # TODO Mask construction for S1
+        # build_s1_image_and_masks function from datamodules/components/datamodule_components?
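+        # A possible sketch for this TODO (hypothetical, not wired in):
+        # assuming invalid S1 pixels are NaN before the nan_to_num below,
+        # a validity mask could be derived directly from the raw values:
+        # s1_valmask = (~torch.isnan(bands_s1).any(dim=1, keepdim=True)).float()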
+ + # Replace nan by 0 + bands_s1 = bands_s1.nan_to_num() + # These dimensions are useful for unpatchify + # Keep the dimensions of height and width found in chunk + band_h, band_w = bands_s2.shape[-2:] + # Keep the number of patches of patch_size + # in rows and cols found in chunk + h, w = patches.patchify( + bands_s2[0, ...], patch_size=patch_size, margin=patch_margin + ).shape[:2] + + # TODO Apply patchify + + # Get the model + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + + # apply model + # latent_variable = mmdc_full_model.predict( + # s2_x, + # s2_m, + # s2_angles_x, + # s1_x, + # s1_vm, + # s1_asc_angles_x, + # s1_desc_angles_x, + # worldclim_x, + # srtm_x, + # ) + + # TODO unpatchify + + # TODO crop padding + + # TODO Depending of the configuration return a unique latent variable + # or a stack of laten variables + # coef = pass + # labels = pass + # return coef, labels -- GitLab From 40025c415226024a44c27aa56a4289a87e9809e7 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 10 Mar 2023 17:12:50 +0000 Subject: [PATCH 32/81] testing linting --- src/bin/inference_mmdc_singledate.py | 155 ++++++++++++++++++ src/bin/inference_mmdc_timeserie.py | 133 +++++++++++++++ .../components/inference_components.py | 85 +++++++--- .../inference/mmdc_tile_inference.py | 25 ++- 4 files changed, 372 insertions(+), 26 deletions(-) create mode 100644 src/bin/inference_mmdc_singledate.py create mode 100644 src/bin/inference_mmdc_timeserie.py diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py new file mode 100644 index 00000000..37e19ad6 --- /dev/null +++ b/src/bin/inference_mmdc_singledate.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Inference code for MMDC +""" + + +# imports +import argparse +import logging +import os + +from mmdc_singledate.inference.components.inference_components import ( + GeoTiffDataset, + get_mmdc_full_model, + predict_mmdc_model, +) +from mmdc_singledate.inference.mmdc_tile_inference import ( + MMDCProcess, + predict_single_date_tile, +) + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate argument parser for CLI + """ + arg_parser = argparse.ArgumentParser( + os.path.basename(__file__), + description="Compute the inference a selected number of latent spaces " + " and exported as GTiff for a given tile", + ) + + arg_parser.add_argument( + "--loglevel", + default="INFO", + choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"), + help="Logger level (default: INFO. 
Should be one of " + "(DEBUG, INFO, WARNING, ERROR, CRITICAL)", + ) + + arg_parser.add_argument( + "--s2_filename", type=str, help="s2 input file", required=False + ) + + arg_parser.add_argument( + "--s1_asc_filename", type=str, help="s1 asc input file", required=False + ) + + arg_parser.add_argument( + "--s1_desc_filename", type=str, help="s1 desc input file", required=False + ) + + arg_parser.add_argument( + "--srtm_filename", type=str, help="srtm input file", required=True + ) + + arg_parser.add_argument( + "--worldclim_filename", type=str, help="worldclim input file", required=True + ) + + arg_parser.add_argument( + "--export_path", type=str, help="export path ", required=True + ) + + arg_parser.add_argument( + "--model_checkpoint_path", + type=str, + help="model checkpoint path ", + required=True, + ) + + arg_parser.add_argument( + "--latent_space_size", + type=int, + help="Number of channels in the output file", + required=False, + ) + + arg_parser.add_argument( + "--nb_lines", + type=str, + help="Number of Line to read in each window reading", + required=False, + ) + + arg_parser.add_argument( + "--sensors", + dest="sensors", + nargs="+", + help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)", + required=True, + ) + + return arg_parser + + +def main(): + """ + Entry Point + """ + # Parser arguments + parser = get_parser() + args = parser.parse_args() + + # Configure logging + numeric_level = getattr(logging, args.loglevel.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError("Invalid log level: %s" % args.loglevel) + + logging.basicConfig( + level=numeric_level, + datefmt="%y-%m-%d %H:%M:%S", + format="%(asctime)s :: %(levelname)s :: %(message)s", + ) + + # model + model = get_mmdc_full_model( + checkpoint=args.model_checkpoint_path, + ) + + # prediction function + pred_mmdc = predict_mmdc_model(model=model, sensor=args.sensor) + + # process class + mmdc_process = MMDCProcess( + count=args.latent_space_size, + nb_lines=args.nb_lines, + patch_size=args.patch_size, + process=pred_mmdc, + ) + + # create export filename + export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/latent_singledate_{'_'.join(args.sensors)}.tif" + + # GeoTiff + geotiff_dataclass = GeoTiffDataset( + s2_filename=args.s2_filename, + s1_asc_filename=args.s1_asc_filename, + s1_desc_filename=args.s1_desc_filename, + srtm_filename=args.srtm_filename, + wc_filename=args.wc_filename, + ) + + predict_single_date_tile( + input_data=geotiff_dataclass, + export_path=export_path, + sensors=args.sensors, + process=mmdc_process, + ) + + +if __name__ == "__main__": + main() diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py new file mode 100644 index 00000000..62bd41f9 --- /dev/null +++ b/src/bin/inference_mmdc_timeserie.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Inference code for MMDC +""" + + +# imports +import argparse +import logging +import os +from pathlib import Path + +from mmdc_singledate.inference.mmdc_tile_inference import mmdc_tile_inference + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate argument parser for CLI + """ + arg_parser = argparse.ArgumentParser( + os.path.basename(__file__), + description="Compute the inference a selected number of latent spaces " + " and exported as GTiff for a given tile", + ) + + arg_parser.add_argument( + "--loglevel", + default="INFO", + choices=("DEBUG", "INFO", 
"WARNING", "ERROR", "CRITICAL"), + help="Logger level (default: INFO. Should be one of " + "(DEBUG, INFO, WARNING, ERROR, CRITICAL)", + ) + + arg_parser.add_argument( + "--input_path", type=str, help="data input folder", required=True + ) + + arg_parser.add_argument( + "--export_path", type=str, help="export path ", required=True + ) + + arg_parser.add_argument( + "--days_gap", + type=str, + help="difference between S1 and S2 acquisitions in days", + required=True, + ) + + arg_parser.add_argument( + "--model_checkpoint_path", + type=str, + help="model checkpoint path ", + required=True, + ) + + arg_parser.add_argument( + "--tile_list", type=str, help="List of tiles to produce", required=False + ) + + arg_parser.add_argument( + "--patch_size", + type=str, + help="Sizes of the patches to make the inference", + required=False, + ) + + arg_parser.add_argument( + "--latent_space_size", + type=int, + help="Number of channels in the output file", + required=False, + ) + + arg_parser.add_argument( + "--nb_lines", + type=str, + help="Number of Line to read in each window reading", + required=False, + ) + + arg_parser.add_argument( + "-sensors", + dest="sensors", + nargs="+", + help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)", + required=True, + ) + + return arg_parser + + +def main(): + """ + Entry Point + """ + # Parser arguments + parser = get_parser() + args = parser.parse_args() + + # Configure logging + numeric_level = getattr(logging, args.loglevel.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError("Invalid log level: %s" % args.loglevel) + + logging.basicConfig( + level=numeric_level, + datefmt="%y-%m-%d %H:%M:%S", + format="%(asctime)s :: %(levelname)s :: %(message)s", + ) + + if not os.path.isdir(args.export_path): + logging.info( + f"The directory is not present. Creating a new one.. 
" f"{args.export_path}" + ) + Path(args.export_path).mkdir() + + # entry point + mmdc_tile_inference( + input_path=Path(args.input_path), + export_path=Path(args.export_path), + model_checkpoint_path=Path(args.model_checkpoint_path), + sensor=args.sensors, + tile_list=args.tile_list, + days_gap=args.days_gap, + latent_space_size=args.laten_space_size, + nb_lines=args.nb_lines, + ) + + +if __name__ == "__main__": + main() diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index 8841e203..934aa525 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -15,9 +15,8 @@ from typing import Literal import torch from torch import nn -from mmdc_singledate.models.components.model_dataclass import ( +from mmdc_singledate.models.components.model_dataclass import ( # VAELatentSpace, S1S2VAELatentSpace, - VAELatentSpace, ) from mmdc_singledate.models.mmdc_full_module import MMDCFullModule @@ -132,7 +131,7 @@ def predict_mmdc_model( s1_desc_angles, worldclim, srtm, -) -> VAELatentSpace: +) -> torch.Tensor: """ This function apply the predict method to a batch of mmdc data and get the latent space @@ -153,16 +152,19 @@ def predict_mmdc_model( # TODO check if the sensors are in the current supported list # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device) + prediction = model.predict( - s2_ref, - s2_mask, - s2_angles, - s1_back, - s1_vm, - s1_asc_angles, - s1_desc_angles, - worldclim, - srtm, + s2_ref.to(device), + s2_mask.to(device), + s2_angles.to(device), + s1_back.to(device), + s1_vm.to(device), + s1_asc_angles.to(device), + s1_desc_angles.to(device), + worldclim.to(device), + srtm.to(device), ) print("prediction.shape :", prediction[0].latent.latent_s2.mu.shape) @@ -192,9 +194,20 @@ def predict_mmdc_model( latent_space.latent_s1.mu, latent_space.latent_s1.logvar, ), - 0, + 1, ) + print( + "latent_space.latent_s2.mu :", + latent_space.latent_s2.mu, + ) + print( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar, + ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) + print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) + case ["S2L2A", "S1ASC"]: logger.info("S1 asc captured") logger.info("S2 captured") @@ -209,8 +222,18 @@ def predict_mmdc_model( latent_space.latent_s1.mu, latent_space.latent_s1.logvar, ), - 0, + 1, + ) + print( + "latent_space.latent_s2.mu :", + latent_space.latent_s2.mu, ) + print( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar, + ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) + print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S2L2A", "S1DESC"]: logger.info("S1 desc captured") @@ -226,8 +249,18 @@ def predict_mmdc_model( latent_space.latent_s1.mu, latent_space.latent_s1.logvar, ), - 0, + 1, + ) + print( + "latent_space.latent_s2.mu :", + latent_space.latent_s2.mu, + ) + print( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar, ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) + print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S2L2A"]: logger.info("Only S2 captured") @@ -239,7 +272,15 @@ def predict_mmdc_model( latent_space.latent_s2.mu, latent_space.latent_s2.logvar, ), - 0, + 1, + ) + print( + 
"latent_space.latent_s2.mu :", + latent_space.latent_s2.mu, + ) + print( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar, ) case ["S1FULL"]: @@ -249,8 +290,10 @@ def predict_mmdc_model( latent_space_stack = torch.cat( (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), - 0, + 1, ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) + print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S1ASC"]: logger.info("Only S1ASC captured") @@ -259,8 +302,10 @@ def predict_mmdc_model( latent_space_stack = torch.cat( (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), - 0, + 1, ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) + print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S1DESC"]: logger.info("Only S1DESC captured") @@ -269,7 +314,9 @@ def predict_mmdc_model( latent_space_stack = torch.cat( (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), - 0, + 1, ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) + print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) return latent_space_stack # latent_space diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index 17939afc..ad7408be 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -82,6 +82,7 @@ class MMDCProcess: count: int nb_lines: int patch_size: int + model : torch.nn process: Callable[ [ torch.tensor, @@ -217,15 +218,21 @@ def predict_single_date_tile( wc_patch = patches.flatten2d( patches.patchify(wc.worldclim, process.patch_size) ) - # patchify not necessary + # Expand the angles to fit the sizes + s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1) + # torch.flatten(start_dim=0, end_dim=1) + s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1) + # torch.flatten(start_dim=0, end_dim=1) print("s2_patches", s2_refl_patch.shape) print("s2_angles_patches", s2_ang_patch.shape) print("s2_mask_patch", s2_mask_patch.shape) print("s1_patch", s1_patch.shape) print("s1_valmask_patch", s1_valmask_patch.shape) - print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape) - print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape) + print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape, s1asc_lia_patch.shape) + print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape, s1desc_lia_patch.shape) print("srtm patches", srtm_patch.shape) print("wc patches", wc_patch.shape) @@ -233,13 +240,15 @@ def predict_single_date_tile( # should return a s1s2vaelatentspace # object pred_vaelatentspace = process.process( + process.model, + sensors, s2_refl_patch, s2_mask_patch, s2_ang_patch, s1_patch, s1_valmask_patch, - s1_asc.s1_lia_angles, - s1_desc.s1_lia_angles, + s1asc_lia_patch, + s1desc_lia_patch, wc_patch, srtm_patch, ) @@ -247,7 +256,6 @@ def predict_single_date_tile( print("latent space sizes : ", pred_vaelatentspace.shape) # unpatchify - pred_vaelatentspace_unpatchify = patches.unpatchify( patches.unflatten2d( pred_vaelatentspace, @@ -257,11 +265,14 @@ def predict_single_date_tile( )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] print("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + print("process.count : ", process.count) + # check the pred and the ask are the same + assert process.count == pred_vaelatentspace_unpatchify.shape[0] 
prediction.write(
                np.array(pred_vaelatentspace_unpatchify[0, ...]),
                window=roi,
-                indexes=4,
+                indexes=process.count,
            )

    logger.info(("Export tile", f"filename :{export_path}"))
-- GitLab


From 6969fe2933c14bd7fda993c6c39633ed4a4162ee Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Sat, 11 Mar 2023 20:29:21 +0000
Subject: [PATCH 33/81] update for manage no data cases

---
 .../inference/components/inference_utils.py   | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py
index fe66b640..9d54bfe0 100644
--- a/src/mmdc_singledate/inference/components/inference_utils.py
+++ b/src/mmdc_singledate/inference/components/inference_utils.py
@@ -50,19 +50,19 @@ class GeoTiffDataset:
 
     def __post_init__(self) -> None:
         # check if the data exists
-        if not Path(self.srtm_filename).exists():
+        if not Path(self.srtm_filename).is_file():
             raise Exception(f"{self.srtm_filename} does not exist!")
-        if not Path(self.wc_filename).exists():
+        if not Path(self.wc_filename).is_file():
             raise Exception(f"{self.wc_filename} does not exist!")
-        if Path(self.s1_asc_filename).exists():
+        if Path(self.s1_asc_filename).is_file():
             self.s1_asc_availability = True
         else:
             self.s1_asc_availability = False
-        if Path(self.s1_desc_filename).exists():
+        if Path(self.s1_desc_filename).is_file():
             self.s1_desc_availability = True
         else:
             self.s1_desc_availability = False
-        if Path(self.s2_filename).exists():
+        if Path(self.s2_filename).is_file():
             self.s2_availabitity = True
         else:
             self.s2_availabitity = False
@@ -204,9 +204,11 @@ def read_img_tile(
             yield sensor_data
     # if the file does not exist, create a zeros tensor and yield it
     else:
+        print("Creating fake data")
         for roi in rois:
             null_data = torch.ones(roi.width, roi.height)
-            yield null_data
+            sensor_data = sensor_func(null_data, availabitity, filename)
+            yield sensor_data
 
 
 def read_s2_img_tile(
@@ -234,8 +236,9 @@ def read_s2_img_tile(
         image_s2 = s2_tensor[:10, ...]
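+    # NOTE (editorial): the else branch below fabricates neutral S2 inputs
+    # when the product is missing; the channel counts must mirror the real
+    # data (10 reflectance bands, 6 angle bands, 1 validity-mask band), which
+    # is what the 4 -> 6 and 2 -> 1 changes align.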
else: image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1]) - angles_s2 = torch.ones(4, s2_tensor.shape[0], s2_tensor.shape[1]) - mask = torch.zeros(2, s2_tensor.shape[0], s2_tensor.shape[1]) + angles_s2 = torch.ones(6, s2_tensor.shape[0], s2_tensor.shape[1]) + mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1]) + print("passing zero data") return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) @@ -288,8 +291,8 @@ def read_s1_img_tile( # compute edge mask s1_backscatter = apply_log_to_s1(img_s1) else: - s1_backscatter = torch.ones(3, s1_tensor.shape[0], s1_tensor.shape[1]) - s1_valmask = torch.zeros(1, s1_tensor.shape[0], s1_tensor.shape[1]) + s1_backscatter = torch.ones(3, s1_tensor.shape[1], s1_tensor.shape[0]) + s1_valmask = torch.zeros(1, s1_tensor.shape[1], s1_tensor.shape[0]) s1_lia = torch.ones(3) return S1Components( -- GitLab From 6969fe2933c14bd7fda993c6c39633ed4a4162ee Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Sat, 11 Mar 2023 20:29:55 +0000 Subject: [PATCH 34/81] update for manage no data cases --- .../inference/mmdc_tile_inference.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index ad7408be..88cb3bfb 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -82,7 +82,7 @@ class MMDCProcess: count: int nb_lines: int patch_size: int - model : torch.nn + model: torch.nn process: Callable[ [ torch.tensor, @@ -155,7 +155,12 @@ def predict_single_date_tile( worldclim_data = concat_worldclim_components( wc_filename=input_data.wc_filename, rois=rois, availabitity=True ) - + print("input_data.s2_availabitity=", input_data.s2_availabitity) + print("s2_data=", s2_data) + print("s1_asc_data=", s1_asc_data) + print("s1_desc_data=", s1_desc_data) + print("srtm_data=", srtm_data) + print("worldclim_data=", worldclim_data) logger.info("Export Init") with rio.open(export_path, "w", **meta) as prediction: @@ -220,10 +225,12 @@ def predict_single_date_tile( ) # Expand the angles to fit the sizes s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat( - s2_mask_patch_size[0] * s2_mask_patch_size[1], 1) + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) # torch.flatten(start_dim=0, end_dim=1) s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat( - s2_mask_patch_size[0] * s2_mask_patch_size[1], 1) + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) # torch.flatten(start_dim=0, end_dim=1) print("s2_patches", s2_refl_patch.shape) @@ -231,8 +238,16 @@ def predict_single_date_tile( print("s2_mask_patch", s2_mask_patch.shape) print("s1_patch", s1_patch.shape) print("s1_valmask_patch", s1_valmask_patch.shape) - print("s1_lia_asc patches", s1_asc.s1_lia_angles.shape, s1asc_lia_patch.shape) - print("s1_lia_desc patches", s1_desc.s1_lia_angles.shape, s1desc_lia_patch.shape) + print( + "s1_lia_asc patches", + s1_asc.s1_lia_angles.shape, + s1asc_lia_patch.shape, + ) + print( + "s1_lia_desc patches", + s1_desc.s1_lia_angles.shape, + s1desc_lia_patch.shape, + ) print("srtm patches", srtm_patch.shape) print("wc patches", wc_patch.shape) -- GitLab From 1a42b395c34be876a32454c5cd4228e01e5ba9fe Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Sat, 11 Mar 2023 20:30:43 +0000 Subject: [PATCH 35/81] update for manage no data cases --- 
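Editorial note: the band-count arithmetic asserted in the updated tests below
(nb_bands = len(sensors) * 8, raster counts of 8 or 16) follows from stacking
each sensor's latent mu and logvar along the channel axis (the dim=0 -> dim=1
change introduced earlier in this series). A minimal, self-contained sketch of
that arithmetic, with illustrative sizes only (latent_dim=4 matches the
checkpoints used here, the other sizes are arbitrary):

    import torch

    latent_dim, n, h, w = 4, 2, 64, 64
    mu = torch.rand(n, latent_dim, h, w)
    logvar = torch.rand(n, latent_dim, h, w)
    # one sensor: mu and logvar concatenated channel-wise -> 2 * latent_dim bands
    s2_stack = torch.cat((mu, logvar), 1)
    assert s2_stack.shape[1] == 8
    # two sensors (e.g. S2L2A + S1FULL): two such stacks -> 16 bands
    full_stack = torch.cat((s2_stack, s2_stack), 1)
    assert full_stack.shape[1] == 16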
test/test_mmdc_inference.py | 201 ++++++++++++++++++++++++++++++++++-- 1 file changed, 190 insertions(+), 11 deletions(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index b7b869b1..3f4265ba 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -5,6 +5,7 @@ import os from pathlib import Path import pytest +import rasterio as rio import torch from mmdc_singledate.inference.components.inference_components import ( # predict_tile, @@ -76,8 +77,6 @@ def test_mmdc_full_model(): checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" - # lightning_dict = torch.load("epoch_008.ckpt", map_location=torch.device("cpu")) - mmdc_full_model = get_mmdc_full_model( os.path.join(checkpoint_path, checkpoint_filename) ) @@ -218,17 +217,61 @@ def dummy_process( # TODO add cases with no data +# @pytest.mark.skip(reason="Test with dummy process ") +# @pytest.mark.parametrize("sensors", sensors_test) +# def test_predict_single_date_tile_with_dummy_process(sensors): +# """ """ +# export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif" +# nb_bands = 4 # len(sensors) * 2 +# print("nb_bands :", nb_bands) +# process = MMDCProcess( +# count=nb_bands, #4, +# nb_lines=1024, +# patch_size=256, +# process=dummy_process, +# ) + +# datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" +# input_data = GeoTiffDataset( +# s2_filename=os.path.join( +# datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" +# ), +# s1_asc_filename=os.path.join( +# datapath, +# "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif", +# ), +# s1_desc_filename=os.path.join( +# datapath, +# "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif", +# ), +# srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"), +# wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"), +# ) + +# predict_single_date_tile( +# input_data=input_data, +# export_path=export_path, +# sensors=sensors, # ["S2L2A"], +# process=process, +# ) + +# assert Path(export_path).exists() == True + +# with rio.open(export_path) as predited_raster: +# predicted_metadata = predited_raster.meta.copy() + +# assert predicted_metadata["count"] == nb_bands + + +# TODO add cases with no data +# @pytest.mark.skip(reason="Dissable for faster testing") @pytest.mark.parametrize("sensors", sensors_test) def test_predict_single_date_tile(sensors): """ """ - export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif" - process = MMDCProcess( - count=4, - nb_lines=1024, - patch_size=256, - process=dummy_process, - ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif" datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" input_data = GeoTiffDataset( s2_filename=os.path.join( @@ -246,16 +289,152 @@ def test_predict_single_date_tile(sensors): wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"), ) + nb_bands = len(sensors) * 8 + print("nb_bands :", nb_bands) + + # checkpoints + checkpoint_path = 
"/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" + checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + # move to device + mmdc_full_model.to(device) + assert mmdc_full_model.training == False + + # init the process + process = MMDCProcess( + count=nb_bands, # 4, + nb_lines=256, + patch_size=128, + model=mmdc_full_model, + process=predict_mmdc_model, + ) + + # entry point pred function predict_single_date_tile( input_data=input_data, export_path=export_path, - sensors=sensors, # ["S2L2A"], + sensors=sensors, process=process, ) assert Path(export_path).exists() == True + with rio.open(export_path) as predited_raster: + predicted_metadata = predited_raster.meta.copy() + assert predicted_metadata["count"] == nb_bands + + +s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" +s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif" +s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif" +srtm_filename_variable = "srtm_T31TCJ_roi_0.tif" +wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif" + +datasets = [ + ( + s2_filename_variable, + s1_asc_filename_variable, + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, + ), + ( + "", + s1_asc_filename_variable, + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, + ), + ( + s2_filename_variable, + "", + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, + ), + ( + s2_filename_variable, + s1_asc_filename_variable, + "", + srtm_filename_variable, + wc_filename_variable, + ), +] + + +@pytest.mark.parametrize( + "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable", + datasets, +) +def test_predict_single_date_tile_no_data( + s2_filename_variable, + s1_asc_filename_variable, + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, +): + """ """ + sensors = ["S2L2A", "S1ASC"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif" + datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" + input_data = GeoTiffDataset( + s2_filename=os.path.join(datapath, s2_filename_variable), + s1_asc_filename=os.path.join( + datapath, + s1_asc_filename_variable, + ), + s1_desc_filename=os.path.join( + datapath, + s1_desc_filename_variable, + ), + srtm_filename=os.path.join(datapath, srtm_filename_variable), + wc_filename=os.path.join(datapath, wc_filename_variable), + ) + + nb_bands = len(sensors) * 8 + print("nb_bands :", nb_bands) + + # checkpoints + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" + checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + # move to device + mmdc_full_model.to(device) + assert mmdc_full_model.training == False + + # init the process + process = MMDCProcess( + count=nb_bands, # 4, + nb_lines=256, + patch_size=128, + model=mmdc_full_model, + process=predict_mmdc_model, + 
) + + # entry point pred function + predict_single_date_tile( + input_data=input_data, + export_path=export_path, + sensors=sensors, + process=process, + ) + + assert Path(export_path).exists() == True + + with rio.open(export_path) as predited_raster: + predicted_metadata = predited_raster.meta.copy() + assert predicted_metadata["count"] == nb_bands + +@pytest.mark.skip(reason="Check Time Serie generation") @pytest.mark.parametrize("sensors", sensors_test) def test_mmdc_tile_inference(sensors): """ @@ -277,5 +456,5 @@ def test_mmdc_tile_inference(sensors): tile_list="T31TCJ", days_gap=15, latent_space_size=4, - nb_lines=1024, + nb_lines=256, ) -- GitLab From a9a493db915be57d56735434550cb80a6323f75a Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 15 Mar 2023 16:06:46 +0000 Subject: [PATCH 36/81] test time serie --- src/mmdc_singledate/inference/components/inference_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py index 9d54bfe0..7963eb1a 100644 --- a/src/mmdc_singledate/inference/components/inference_utils.py +++ b/src/mmdc_singledate/inference/components/inference_utils.py @@ -39,14 +39,14 @@ class GeoTiffDataset: """ s2_filename: str - s2_availabitity: bool = field(init=False) s1_asc_filename: str - s1_asc_availability: bool = field(init=False) s1_desc_filename: str - s1_desc_availability: bool = field(init=False) srtm_filename: str wc_filename: str metadata: dict = field(init=False) + s2_availabitity: bool | None = field(default=None, init=True) + s1_asc_availability: bool | None = field(default=None, init=True) + s1_desc_availability: bool | None = field(default=None, init=True) def __post_init__(self) -> None: # check if the data exists -- GitLab From 5d47d8437d9c45cdc81b65dfd5b15e9aea2a5247 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 15 Mar 2023 16:08:23 +0000 Subject: [PATCH 37/81] test time serie --- .../inference/mmdc_tile_inference.py | 129 ++++++++++++++---- 1 file changed, 104 insertions(+), 25 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index 88cb3bfb..6780d5d8 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from pathlib import Path import numpy as np +import pandas as pd import rasterio as rio import torch from torchutils import patches @@ -292,14 +293,64 @@ def predict_single_date_tile( logger.info(("Export tile", f"filename :{export_path}")) +def estimate_nb_latent_spaces( + patch_s2_availability: bool, + patchasc_s1_availability: bool, + patchdesc_s1_availability: bool, +) -> int: + """ + Given the availability of the input data + estimate the number of latent variables + """ + + if (patch_s2_availability and patchasc_s1_availability) or ( + patch_s2_availability and patchdesc_s1_availability + ): + nb_latent_spaces = 8 * 2 + + else: + nb_latent_spaces = 4 * 2 + + return nb_latent_spaces + + +def determinate_sensor_list( + patch_s2_availability: bool, + patchasc_s1_availability: bool, + patchdesc_s1_availability: bool, +) -> list: + """ + Convert availabitiy dataframe to list of sensors + """ + sensors_availability = [ + patch_s2_availability, + patchasc_s1_availability, + patchdesc_s1_availability, + ] + 
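+    # NOTE (editorial): the match below maps the (S2, S1 asc, S1 desc)
+    # availability triple to a sensor list; the all-False combination is not
+    # covered, so a date with no usable product at all would leave `sensors`
+    # unbound and fail at the return statement.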
match sensors_availability: + case [True, True, True]: + sensors = ["S2L2A", "S1FULL"] + case [True, True, False]: + sensors = ["S2L2A", "S1ASC"] + case [True, False, True]: + sensors = ["S2L2A", "S1DESC"] + case [False, True, True]: + sensors = ["S1FULL"] + case [True, False, False]: + sensors = ["S2L2A"] + case [False, True, False]: + sensors = ["S1ASC"] + case [False, False, True]: + sensors = ["S1DESC"] + + return sensors + + def mmdc_tile_inference( - input_path: Path, + inference_dataframe: pd.DataFrame, export_path: Path, model_checkpoint_path: Path, - sensor: list[str], - tile_list: list[str], - days_gap: int, - latent_space_size: int = 4, + patch_size: int = 256, nb_lines: int = 1024, ) -> None: """ @@ -316,39 +367,67 @@ def mmdc_tile_inference( :return : None """ - # dataframe with input data - tile_df = inference_dataframe( - samples_dir=input_path, - input_tile_list=tile_list, - days_gap=days_gap, - ) + # device + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # # dataframe with input data + # tile_df = inference_dataframe( + # samples_dir=input_path, + # input_tile_list=tile_list, + # days_gap=days_gap, + # ) + # # just test the df + # tile_df = tile_df.head() # instance the model and get the pred func model = get_mmdc_full_model( checkpoint=model_checkpoint_path, ) - pred_func = predict_mmdc_model(model=model, sensor=sensor) - # - mmdc_process = MMDCProcess( - count=latent_space_size, nb_lines=nb_lines, process=pred_func - ) + model.to(device) # iterate over the dates in the time serie - for tuile, df_row in tile_df.iterrows(): + for tuile, df_row in inference_dataframe.iterrows(): + # estimate nb latent spaces + latent_space_size = estimate_nb_latent_spaces( + df_row["patch_s2_availability"], + df_row["patchasc_s1_availability"], + df_row["patchdesc_s1_availability"], + ) + sensor = determinate_sensor_list( + df_row["patch_s2_availability"], + df_row["patchasc_s1_availability"], + df_row["patchdesc_s1_availability"], + ) + # export path + date_ = df_row["date"].strftime("%Y-%m-%d") + export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif" + # + print(tuile, df_row) + print(latent_space_size) + # define process + mmdc_process = MMDCProcess( + count=latent_space_size, + nb_lines=nb_lines, + patch_size=patch_size, + model=model, + process=predict_mmdc_model, + ) # get the input data mmdc_input_data = GeoTiffDataset( - s2_filename=df_row.patch_s2, - s1_asc_filename=df_row.patchasc_s1, - s1_asc_availibility=df_row.s1_asc_availibility, - s1_desc_filename=df_row.patchdesc_s1, - s1_desc_availibitity=df_row.s1_desc_availibitity, - srtm_filename=df_row.srtm_filename, - wc_filename=df_row.worldclim_filename, + s2_filename=df_row["patch_s2"], + s1_asc_filename=df_row["patchasc_s1"], + s1_desc_filename=df_row["patchdesc_s1"], + srtm_filename=df_row["srtm_filename"], + wc_filename=df_row["worldclim_filename"], + s2_availabitity=df_row["patch_s2_availability"], + s1_asc_availability=df_row["patchasc_s1_availability"], + s1_desc_availability=df_row["patchdesc_s1_availability"], ) # predict tile predict_single_date_tile( input_data=mmdc_input_data, - export_path=export_path, + export_path=export_file, sensors=sensor, process=mmdc_process, ) -- GitLab From d2ad4c82ed209e762861c2a746ca187827166067 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 15 Mar 2023 16:12:38 +0000 Subject: [PATCH 38/81] update test time serie --- test/test_mmdc_inference.py | 154 +++++++++++++++++++++++++++++++++--- 1 
file changed, 141 insertions(+), 13 deletions(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index 3f4265ba..4f7b1f0a 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -2,8 +2,10 @@ # Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales import os +from datetime import datetime from pathlib import Path +import pandas as pd import pytest import rasterio as rio import torch @@ -28,6 +30,7 @@ from mmdc_singledate.models.components.model_dataclass import VAELatentSpace dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" +# @pytest.mark.skip(reason="Check Time Serie generation") def test_GeoTiffDataset(): """ """ datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ" @@ -55,6 +58,7 @@ def test_GeoTiffDataset(): assert input_data.s2_availabitity == True +# @pytest.mark.skip(reason="Check Time Serie generation") def test_generate_chunks(): """ test chunk functionality @@ -65,6 +69,7 @@ def test_generate_chunks(): assert chunks[0] == (0, 0, 10980, 1024) +# @pytest.mark.skip(reason="Check Time Serie generation") def test_mmdc_full_model(): """ test instantiate network @@ -127,6 +132,7 @@ sensors_test = [ ] +# @pytest.mark.skip(reason="Check Time Serie generation") @pytest.mark.parametrize("sensors", sensors_test) def test_predict_mmdc_model(sensors): """ """ @@ -369,6 +375,7 @@ datasets = [ "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable", datasets, ) +# @pytest.mark.skip(reason="Check Time Serie generation") def test_predict_single_date_tile_no_data( s2_filename_variable, s1_asc_filename_variable, @@ -434,27 +441,148 @@ def test_predict_single_date_tile_no_data( assert predicted_metadata["count"] == nb_bands -@pytest.mark.skip(reason="Check Time Serie generation") -@pytest.mark.parametrize("sensors", sensors_test) -def test_mmdc_tile_inference(sensors): +# @pytest.mark.skip(reason="Check Time Serie generation") +def test_mmdc_tile_inference(): """ Test the inference code in a tile """ # feed parameters - input_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/") - export_path = Path( - "/work/CESBIO/projects/MAESTRIA/test_onetile/export/test_latent_tileinfer_{'_'.join(sensors)}.tif" - ) + input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/") + export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/") model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt" - # apply prediction + input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"] + + # dataframe with input data + tile_df = pd.DataFrame( + { + "date": [ + datetime.strptime("2018-01-11", "%Y-%m-%d"), + datetime.strptime("2018-01-14", "%Y-%m-%d"), + datetime.strptime("2018-01-19", "%Y-%m-%d"), + datetime.strptime("2018-01-21", "%Y-%m-%d"), + datetime.strptime("2018-01-24", "%Y-%m-%d"), + ], + "patch_s2": [ + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180111-100351-463_L2A_T33TUL_D_V1-4_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180114-101347-463_L2A_T33TUL_D_V1-4_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2A_20180119-101331-457_L2A_T33TUL_D_V1-4_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180121-100542-770_L2A_T33TUL_D_V1-4_roi_0.tif", + 
"/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180124-101352-753_L2A_T33TUL_D_V1-4_roi_0.tif", + ], + "patchasc_s1": [ + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180108T165831_20180108T165856_020066_022322_3008_36.29396013100586_14.593286735562288_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180114T165747_20180114T165812_009170_0106A2_AFD8_36.27898252884488_14.533603849909781_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180120T165830_20180120T165855_020241_0228B0_95C7_36.29306272148375_14.591766979064303_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan", + ], + "patchdesc_s1": [ + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180113T051001_20180113T051026_009148_0105F4_218B_45.70533932879138_164.47774901324672_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180124T051848_20180124T051913_020292_022A63_177C_35.81756575888757_165.03806323648553_roi_0.tif", + ], + "Tuiles": [ + "T33TUL", + "T33TUL", + "T33TUL", + "T33TUL", + "T33TUL", + ], + "patch": [ + 0, + 0, + 0, + 0, + 0, + ], + "roi": [ + 0, + 0, + 0, + 0, + 0, + ], + "worldclim_filename": [ + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif", + ], + "srtm_filename": [ + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif", + "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif", + ], + "patch_s2_availability": [ + True, + True, + True, + True, + True, + ], + "patchasc_s1_availability": [ + True, + True, + False, + True, + False, + ], + "patchdesc_s1_availability": [ + True, + True, + False, + True, + False, + ], + } + ) + + # inference_dataframe( + # samples_dir=input_path, + # input_tile_list=input_tile_list, + # days_gap=5, + # ) + # # just test the df + print( + tile_df.loc[ + :, + [ + "patch_s2_availability", + "patchasc_s1_availability", + "patchdesc_s1_availability", + ], + ] + ) + + # apply prediction mmdc_tile_inference( - input_path=input_path, + inference_dataframe=tile_df, export_path=export_path, model_checkpoint_path=model_checkpoint_path, - sensor=sensors, - tile_list="T31TCJ", - days_gap=15, - latent_space_size=4, + patch_size=128, nb_lines=256, ) + + assert Path(export_path / "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif").exists() + 
assert Path(export_path / "latent_tile_infer_2018-01-14_S2L2A_S1FULL.tif").exists() + assert Path(export_path / "latent_tile_infer_2018-01-19_S2L2A.tif").exists() + assert Path(export_path / "latent_tile_infer_2018-01-21_S2L2A_S1FULL.tif").exists() + assert Path(export_path / "latent_tile_infer_2018-01-24_S2L2A.tif").exists() + + with rio.open( + os.path.join(export_path, "latent_tile_infer_2018-01-19_S2L2A.tif") + ) as raster: + metadata = raster.meta.copy() + assert metadata["count"] == 8 + + with rio.open( + os.path.join(export_path, "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif") + ) as raster: + metadata = raster.meta.copy() + assert metadata["count"] == 16 -- GitLab From abcb4cfff5e51c05ea85057f711b159f349e01ec Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 15 Mar 2023 16:27:00 +0000 Subject: [PATCH 39/81] add jobs examples --- jobs/inference_single_date.pbs | 28 ++++++++++++++++++++++++++++ jobs/inference_time_serie.pbs | 25 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 jobs/inference_single_date.pbs create mode 100644 jobs/inference_time_serie.pbs diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs new file mode 100644 index 00000000..5a9ad983 --- /dev/null +++ b/jobs/inference_single_date.pbs @@ -0,0 +1,28 @@ +#!/bin/bash +#PBS -N infer_sgdt +#PBS -q qgpgpu +#PBS -l select=1:ncpus=8:mem=92G:ngpus=1 +#PBS -l walltime=1:00:00 + +# be sure no modules loaded +module purge + +export SRCDIR=${HOME}/src/MMDC/mmdc-singledate +export MMDC_INIT=${SRCDIR}/mmdc_init.sh +export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs +export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/ + +cd ${WORKING_DIR} +source ${MMDC_INIT} + +python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \ + --s2_filename ${DATA_DIR}/T31TCJ/SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif \ + --s1_asc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif \ + --s1_desc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif \ + --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \ + --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \ + --export_path ${DATA_DIR}/export/ \ + --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \ + --latent_space_size \ + --nb_lines 256 \ + --sensors S2L2A S1FULL \ diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs new file mode 100644 index 00000000..c366fec5 --- /dev/null +++ b/jobs/inference_time_serie.pbs @@ -0,0 +1,25 @@ +#!/bin/bash +#PBS -N infer_sgdt +#PBS -q qgpgpu +#PBS -l select=1:ncpus=8:mem=92G:ngpus=1 +#PBS -l walltime=2:00:00 + +# be sure no modules loaded +module purge + +export SRCDIR=${HOME}/src/MMDC/mmdc-singledate +export MMDC_INIT=${SRCDIR}/mmdc_init.sh +export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs +export DATA_DIR=/work/CESBIO/projects/MAESTRIA/training_dataset2/ + +cd ${WORKING_DIR} +source ${MMDC_INIT} + +python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \ + --input_path ${DATA_DIR} \ + --export_path ${DATA_DIR}/inference/ \ + --tile_list T33TUL \ + --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \ + --days_gap 1 \ + 
--patch_size 128 \
    --nb_lines 256 \
-- GitLab


From 26a0cc6a568ac245e3a739f4bd5c3df6599e2c8c Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:23:15 +0000
Subject: [PATCH 40/81] minor changes

---
 src/bin/inference_mmdc_singledate.py | 30 ++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
index 37e19ad6..7b8db7d5 100644
--- a/src/bin/inference_mmdc_singledate.py
+++ b/src/bin/inference_mmdc_singledate.py
@@ -11,11 +11,13 @@ import argparse
 import logging
 import os
 
+import torch
+
 from mmdc_singledate.inference.components.inference_components import (
-    GeoTiffDataset,
     get_mmdc_full_model,
     predict_mmdc_model,
 )
+from mmdc_singledate.inference.components.inference_utils import GeoTiffDataset
 from mmdc_singledate.inference.mmdc_tile_inference import (
     MMDCProcess,
     predict_single_date_tile,
@@ -80,11 +82,18 @@ def get_parser() -> argparse.ArgumentParser:
 
     arg_parser.add_argument(
         "--nb_lines",
-        type=str,
+        type=int,
         help="Number of Line to read in each window reading",
         required=False,
     )
 
+    arg_parser.add_argument(
+        "--patch_size",
+        type=int,
+        help="Size of the patches used for the inference",
+        required=True,
+    )
+
     arg_parser.add_argument(
         "--sensors",
         dest="sensors",
@@ -115,24 +124,25 @@ def main():
         format="%(asctime)s :: %(levelname)s :: %(message)s",
     )
 
+    # device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     # model
-    model = get_mmdc_full_model(
+    mmdc_full_model = get_mmdc_full_model(
         checkpoint=args.model_checkpoint_path,
     )
-
-    # prediction function
-    pred_mmdc = predict_mmdc_model(model=model, sensor=args.sensor)
-
+    mmdc_full_model.to(device)
     # process class
     mmdc_process = MMDCProcess(
         count=args.latent_space_size,
         nb_lines=args.nb_lines,
         patch_size=args.patch_size,
-        process=pred_mmdc,
+        process=predict_mmdc_model,
+        model=mmdc_full_model,
     )
 
     # create export filename
-    export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/latent_singledate_{'_'.join(args.sensors)}.tif"
+    export_path = f"{args.export_path}/latent_singledate_{'_'.join(args.sensors)}.tif"
 
     # GeoTiff
     geotiff_dataclass = GeoTiffDataset(
         s2_filename=args.s2_filename,
         s1_asc_filename=args.s1_asc_filename,
         s1_desc_filename=args.s1_desc_filename,
         srtm_filename=args.srtm_filename,
-        wc_filename=args.wc_filename,
+        wc_filename=args.worldclim_filename,
     )
 
     predict_single_date_tile(
-- GitLab


From 18b6dbf0e17b2fc9c323223f9caeb41d88342942 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:24:09 +0000
Subject: [PATCH 41/81] minor changes

---
 src/bin/inference_mmdc_timeserie.py | 46 ++++++++++++++---------------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py
index 62bd41f9..56626f63 100644
--- a/src/bin/inference_mmdc_timeserie.py
+++ b/src/bin/inference_mmdc_timeserie.py
@@ -12,7 +12,10 @@ import logging
 import os
 from pathlib import Path
 
-from mmdc_singledate.inference.mmdc_tile_inference import mmdc_tile_inference
+from mmdc_singledate.inference.mmdc_tile_inference import (
+    inference_dataframe,
+    mmdc_tile_inference,
+)
 
 
 def get_parser() -> argparse.ArgumentParser:
@@ -43,7 +46,7 @@ def get_parser() -> argparse.ArgumentParser:
 
     arg_parser.add_argument(
         "--days_gap",
-        type=str,
+        type=int,
         help="difference between S1 and S2 acquisitions in days",
required=True, ) @@ -56,38 +59,27 @@ def get_parser() -> argparse.ArgumentParser: ) arg_parser.add_argument( - "--tile_list", type=str, help="List of tiles to produce", required=False - ) - - arg_parser.add_argument( - "--patch_size", + "--tile_list", + nargs="+", type=str, - help="Sizes of the patches to make the inference", + help="List of tiles to produce", required=False, ) arg_parser.add_argument( - "--latent_space_size", + "--patch_size", type=int, - help="Number of channels in the output file", + help="Sizes of the patches to make the inference", required=False, ) arg_parser.add_argument( "--nb_lines", - type=str, + type=int, help="Number of Line to read in each window reading", required=False, ) - arg_parser.add_argument( - "-sensors", - dest="sensors", - nargs="+", - help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)", - required=True, - ) - return arg_parser @@ -116,18 +108,24 @@ def main(): ) Path(args.export_path).mkdir() + # calculate time serie + tile_df = inference_dataframe( + samples_dir=args.input_path, + input_tile_list=args.tile_list, + days_gap=args.days_gap, + ) + logging.info(f"nb of dates to infer : {tile_df.shape[0]}") # entry point mmdc_tile_inference( - input_path=Path(args.input_path), + inference_dataframe=tile_df, export_path=Path(args.export_path), model_checkpoint_path=Path(args.model_checkpoint_path), - sensor=args.sensors, - tile_list=args.tile_list, - days_gap=args.days_gap, - latent_space_size=args.laten_space_size, + patch_size=args.patch_size, nb_lines=args.nb_lines, ) + logging.info("All the data has exported succesfully") + if __name__ == "__main__": main() -- GitLab From 57e8378247ce6e1bd223cb0f8b0fb205806c5d06 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 08:25:02 +0000 Subject: [PATCH 42/81] minor changes --- .../components/inference_components.py | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index 934aa525..c7130a33 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -181,8 +181,7 @@ def predict_mmdc_model( match sensors: # Cases case ["S2L2A", "S1FULL"]: - logger.info("S1 full captured") - logger.info("S2 captured") + logger.info("S2 captured & S1 full captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space.latent_s1 = prediction[0].latent.latent_s1 @@ -199,18 +198,19 @@ def predict_mmdc_model( print( "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu, + latent_space.latent_s2.mu.shape, ) print( "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar, + latent_space.latent_s2.logvar.shape, + ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) + print( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) - print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S2L2A", "S1ASC"]: - logger.info("S1 asc captured") - logger.info("S2 captured") + logger.info("S2 captured & S1 asc captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space.latent_s1 = prediction[0].latent.latent_s1 @@ -226,18 +226,19 @@ def predict_mmdc_model( ) print( "latent_space.latent_s2.mu :", - 
latent_space.latent_s2.mu, + latent_space.latent_s2.mu.shape, ) print( "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar, + latent_space.latent_s2.logvar.shape, + ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) + print( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) - print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S2L2A", "S1DESC"]: - logger.info("S1 desc captured") - logger.info("S2 captured") + logger.info("S2 captured & S1 desc captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space.latent_s1 = prediction[0].latent.latent_s1 @@ -253,14 +254,16 @@ def predict_mmdc_model( ) print( "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu, + latent_space.latent_s2.mu.shape, ) print( "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar, + latent_space.latent_s2.logvar.shape, + ) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) + print( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) - print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) case ["S2L2A"]: logger.info("Only S2 captured") @@ -276,11 +279,11 @@ def predict_mmdc_model( ) print( "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu, + latent_space.latent_s2.mu.shape, ) print( "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar, + latent_space.latent_s2.logvar.shape, ) case ["S1FULL"]: @@ -292,8 +295,10 @@ def predict_mmdc_model( (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), 1, ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) - print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) + print( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) case ["S1ASC"]: logger.info("Only S1ASC captured") @@ -304,8 +309,10 @@ def predict_mmdc_model( (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), 1, ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) - print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) + print( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) case ["S1DESC"]: logger.info("Only S1DESC captured") @@ -316,7 +323,9 @@ def predict_mmdc_model( (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), 1, ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu) - print("latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar) + print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) + print( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) return latent_space_stack # latent_space -- GitLab From 430297c8cb42835a25be2ab083ebbbf0cb05eed4 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 08:25:28 +0000 Subject: [PATCH 43/81] minor changes --- src/mmdc_singledate/inference/mmdc_tile_inference.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index 6780d5d8..123b4318 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ 
b/src/mmdc_singledate/inference/mmdc_tile_inference.py
@@ -371,14 +371,9 @@ def mmdc_tile_inference(
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # # dataframe with input data
-    # tile_df = inference_dataframe(
-    #     samples_dir=input_path,
-    #     input_tile_list=tile_list,
-    #     days_gap=days_gap,
-    # )
-    # # just test the df
-    # tile_df = tile_df.head()
+    # TODO Uncomment for testing purposes
+    # inference_dataframe = inference_dataframe.head()
+    # print(inference_dataframe.shape)
 
     # instance the model and get the pred func
     model = get_mmdc_full_model(
-- GitLab


From 477d7a83c55ef2aa5c0b2376b37dbddb782188c7 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:26:02 +0000
Subject: [PATCH 44/81] update

---
 jobs/inference_single_date.pbs | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs
index 5a9ad983..989f6ec6 100644
--- a/jobs/inference_single_date.pbs
+++ b/jobs/inference_single_date.pbs
@@ -11,8 +11,13 @@ export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
 export MMDC_INIT=${SRCDIR}/mmdc_init.sh
 export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
 export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/
+export SCRATCH_DIR=/work/scratch/${USER}/
 
 cd ${WORKING_DIR}
+
+mkdir ${SCRATCH_DIR}/MMDC/inference/
+mkdir ${SCRATCH_DIR}/MMDC/inference/singledate
+
 source ${MMDC_INIT}
 
 python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
@@ -21,8 +26,9 @@ python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \
     --s1_desc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif \
     --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \
     --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \
-    --export_path ${DATA_DIR}/export/ \
+    --export_path ${SCRATCH_DIR}/MMDC/inference/singledate \
     --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
-    --latent_space_size \
+    --latent_space_size 16 \
+    --patch_size 128 \
     --nb_lines 256 \
     --sensors S2L2A S1FULL \
-- GitLab


From d9925c5facdd2b772423072996709c01a083adaf Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Thu, 16 Mar 2023 08:26:13 +0000
Subject: [PATCH 45/81] update

---
 jobs/inference_time_serie.pbs | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs
index c366fec5..582bf198 100644
--- a/jobs/inference_time_serie.pbs
+++ b/jobs/inference_time_serie.pbs
@@ -11,15 +11,21 @@ export SRCDIR=${HOME}/src/MMDC/mmdc-singledate
 export MMDC_INIT=${SRCDIR}/mmdc_init.sh
 export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs
 export DATA_DIR=/work/CESBIO/projects/MAESTRIA/training_dataset2/
+export SCRATCH_DIR=/work/scratch/${USER}/
+
+
+mkdir ${SCRATCH_DIR}/MMDC/inference/
+mkdir ${SCRATCH_DIR}/MMDC/inference/time_serie
+
 
 cd ${WORKING_DIR}
 source ${MMDC_INIT}
 
 python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \
     --input_path ${DATA_DIR} \
-    --export_path ${DATA_DIR}/inference/ \
-    --tile_list T33TUL \
+    --export_path ${SCRATCH_DIR}/MMDC/inference/time_serie \
+    --tile_list ${DATA_DIR}/T33TUL \
     --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \
-    --days_gap 1 \
+    --days_gap 0 \
--patch_size 128 \ --nb_lines 256 \ -- GitLab From 5c5e9e3c3bde0961050a3bc2bd07a96f6f369089 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 08:39:13 +0000 Subject: [PATCH 46/81] add tiler config --- .../externalfeatures_with_userfeatures.cfg | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 configs/iota2/externalfeatures_with_userfeatures.cfg diff --git a/configs/iota2/externalfeatures_with_userfeatures.cfg b/configs/iota2/externalfeatures_with_userfeatures.cfg new file mode 100644 index 00000000..b6eb0fab --- /dev/null +++ b/configs/iota2/externalfeatures_with_userfeatures.cfg @@ -0,0 +1,24 @@ +# iota2 configuration file to launch iota2 as a tiler of features. +# paths with 'XXX' must be replaced by local user paths. + + +chain : +{ + output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' + + first_step : 'tiler' + last_step : 'tiler' + + proj : 'EPSG:2154' + # rasters_grid_path : '/XXX/grid_raster' + grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid + list_tile : 'T31TCJ' # T31TCK T31TDJ' + features_path : '/XXX/non_tiled_features' + + spatial_resolution : 10 # mandatory but not used in this case. The output spatial resolution will be the one found in 'rasters_grid_path' directory. +} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} -- GitLab From eaa797a0c35fe62102ce9324e651e78e5d1f7a7d Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 08:39:24 +0000 Subject: [PATCH 47/81] add tiler config --- configs/iota2/i2_classification.cfg | 61 +++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 configs/iota2/i2_classification.cfg diff --git a/configs/iota2/i2_classification.cfg b/configs/iota2/i2_classification.cfg new file mode 100644 index 00000000..71c1b728 --- /dev/null +++ b/configs/iota2/i2_classification.cfg @@ -0,0 +1,61 @@ +chain : +{ + output_path : '/work/scratch/sowm/IOTA2_TEST_S2/IOTA2_Outputs/Results_classif' + remove_output_path : True + nomenclature_path : '/work/scratch/sowm/IOTA2_TEST_S2/nomenclature23.txt' + list_tile : 'T31TCJ T31TDJ' + s2_path : '/work/scratch/sowm/IOTA2_TEST_S2/sensor_data' + s1_path : '/work/scratch/sowm/test-data-cube/iota2/s1_full.cfg' + ground_truth : '/work/scratch/sowm/IOTA2_TEST_S2/vector_data/reference_data.shp' + data_field : 'code' + spatial_resolution : 10 + color_table : '/work/scratch/sowm/IOTA2_TEST_S2/colorFile.txt' + proj : 'EPSG:2154' + first_step: "init" + last_step: "validation" +} +sensors_data_interpolation: +{ + write_outputs: False +} + +external_features: +{ + + module: "/home/qt/sowm/scratch/test-data-cube/iota2/external_iota2_code.py" + functions: [["apply_convae", + { "checkpoint_path": "/home/qt/sowm/scratch/results/Convae_split_roi_psz_128_ep_150_bs_16_lr_1e-05_s1_[64, 128, 256, 512, 1024]_[8]_s2_[64, 128, 256, 512, 1024]_[8]_21916652.admin01", + "checkpoint_epoch": 100, + "patch_size": 74 + }]] + concat_mode: True +} + +python_data_managing: +{ + chunk_size_mode: "split_number" + number_of_chunks: 20 + padding_size_y: 27 + +} + +arg_train : +{ + classifier : 'sharkrf' + otb_classifier_options : {"classifier.sharkrf.ntrees" : 100 } + sample_selection: {"sampler": "random", + "strategy": "percent", + "strategy.percent.p": 0.1} + random_seed: 42 + +} +arg_classification : +{ + classif_mode : 'separate' +} 
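+# editorial note (inferred from the key names below): per-task scheduling
+# limits -- no automatic retries, with upper bounds on RAM and CPU per task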
+task_retry_limits : +{ + allowed_retry : 0 + maximum_ram : 180.0 + maximum_cpu : 40 +} -- GitLab From 20a0bc26aedd5b284bba058f7709d22d0c77ff38 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 08:39:55 +0000 Subject: [PATCH 48/81] add tiler config --- configs/iota2/externalfeatures_with_userfeatures_50x50.cfg | 5 +++-- configs/iota2/iota2_grid_full.cfg | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg index cdb197f1..71bc61d1 100644 --- a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg +++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg @@ -11,7 +11,8 @@ chain : ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp' data_field : 'groscode' s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink' - user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50' + # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50' + user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output' s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg' @@ -28,7 +29,7 @@ chain : userFeat: { arbo:"/*" - patterns:"slope" + patterns:"aspect,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending,slope" } python_data_managing : { diff --git a/configs/iota2/iota2_grid_full.cfg b/configs/iota2/iota2_grid_full.cfg index 0a529fbb..d5379d49 100644 --- a/configs/iota2/iota2_grid_full.cfg +++ b/configs/iota2/iota2_grid_full.cfg @@ -1,7 +1,7 @@ chain : { output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' - spatial_resolution : 50 + spatial_resolution : 10 first_step : 'tiler' last_step : 'tiler' -- GitLab From 1f3b81a67deba19e10fed32c7f0398297f8eb56a Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Thu, 16 Mar 2023 14:32:01 +0000 Subject: [PATCH 49/81] add new full module init strategy --- .../components/inference_components.py | 192 +++++++++++++----- 1 file changed, 136 insertions(+), 56 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index c7130a33..ae4a7498 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -5,20 +5,28 @@ Functions components for get the latent spaces for a given tile """ - - # imports import logging from pathlib import Path from typing import Literal import torch +from hydra import compose, initialize from torch import nn -from mmdc_singledate.models.components.model_dataclass import ( # VAELatentSpace, +from mmdc_singledate.datamodules.types import MMDCDataChannels +from mmdc_singledate.models.mmdc_full_module import MMDCFullModule +from mmdc_singledate.models.types import ( + ConvnetParams, + MLPConfig, + MMDCDataUse, + ModularEmbeddingConfig, + MultiVAEConfig, + MultiVAELossWeights, S1S2VAELatentSpace, + TranslationLossWeights, + UnetParams, ) -from 
mmdc_singledate.models.mmdc_full_module import MMDCFullModule # define sensors variable for typing SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"] @@ -44,55 +52,130 @@ def get_mmdc_full_model( logger.info("No checkpoint path was give. ") return 1 else: + # parse config + with initialize(config_path="../../../../configs/model/"): + cfg = compose( + config_name="mmdc_full.yaml", return_hydra_config=True, overrides=[] + ) + + mmdc_full_config = MultiVAEConfig( + data_sizes=MMDCDataChannels( + sen1=cfg.model.config.data_sizes.sen1, + s1_angles=cfg.model.config.data_sizes.s1_angles, + sen2=cfg.model.config.data_sizes.sen2, + s2_angles=cfg.model.config.data_sizes.s2_angles, + srtm=cfg.model.config.data_sizes.srtm, + w_c=cfg.model.config.data_sizes.w_c, + ), + embeddings=ModularEmbeddingConfig( + s1_angles=MLPConfig( + hidden_layers=cfg.model.config.embeddings.s1_angles.hidden_layers, + out_channels=cfg.model.config.embeddings.s1_angles.out_channels, + ), + s1_srtm=UnetParams( + out_channels=cfg.model.config.embeddings.s1_srtm.out_channels, + encoder_sizes=cfg.model.config.embeddings.s1_srtm.encoder_sizes, + kernel_size=cfg.model.config.embeddings.s1_srtm.kernel_size, + tail_layers=cfg.model.config.embeddings.s1_srtm.tail_layers, + ), + s2_angles=ConvnetParams( + out_channels=cfg.model.config.embeddings.s2_angles.out_channels, + sizes=cfg.model.config.embeddings.s2_angles.sizes, + kernel_sizes=cfg.model.config.embeddings.s2_angles.kernel_sizes, + ), + s2_srtm=UnetParams( + out_channels=cfg.model.config.embeddings.s2_srtm.out_channels, + encoder_sizes=cfg.model.config.embeddings.s2_srtm.encoder_sizes, + kernel_size=cfg.model.config.embeddings.s2_srtm.kernel_size, + tail_layers=cfg.model.config.embeddings.s2_srtm.tail_layers, + ), + w_c=ConvnetParams( + out_channels=cfg.model.config.embeddings.w_c.out_channels, + sizes=cfg.model.config.embeddings.w_c.sizes, + kernel_sizes=cfg.model.config.embeddings.w_c.kernel_sizes, + ), + ), + s1_encoder=UnetParams( + out_channels=cfg.model.config.s1_encoder.out_channels, + encoder_sizes=cfg.model.config.s1_encoder.encoder_sizes, + kernel_size=cfg.model.config.s1_encoder.kernel_size, + tail_layers=cfg.model.config.s1_encoder.tail_layers, + ), + s2_encoder=UnetParams( + out_channels=cfg.model.config.s2_encoder.out_channels, + encoder_sizes=cfg.model.config.s2_encoder.encoder_sizes, + kernel_size=cfg.model.config.s2_encoder.kernel_size, + tail_layers=cfg.model.config.s2_encoder.tail_layers, + ), + s1_decoder=ConvnetParams( + out_channels=cfg.model.config.s1_decoder.out_channels, + sizes=cfg.model.config.s1_decoder.sizes, + kernel_sizes=cfg.model.config.s1_decoder.kernel_sizes, + ), + s2_decoder=ConvnetParams( + out_channels=cfg.model.config.s2_decoder.out_channels, + sizes=cfg.model.config.s2_decoder.sizes, + kernel_sizes=cfg.model.config.s2_decoder.kernel_sizes, + ), + s1_enc_use=MMDCDataUse( + s1_angles=cfg.model.config.s1_enc_use.s1_angles, + s2_angles=cfg.model.config.s1_enc_use.s2_angles, + srtm=cfg.model.config.s1_enc_use.srtm, + w_c=cfg.model.config.s1_enc_use.w_c, + ), + s2_enc_use=MMDCDataUse( + s1_angles=cfg.model.config.s2_enc_use.s1_angles, + s2_angles=cfg.model.config.s2_enc_use.s2_angles, + srtm=cfg.model.config.s2_enc_use.srtm, + w_c=cfg.model.config.s2_enc_use.w_c, + ), + s1_dec_use=MMDCDataUse( + s1_angles=cfg.model.config.s1_dec_use.s1_angles, + s2_angles=cfg.model.config.s1_dec_use.s2_angles, + srtm=cfg.model.config.s1_dec_use.srtm, + w_c=cfg.model.config.s1_dec_use.w_c, + ), + s2_dec_use=MMDCDataUse( + 
s1_angles=cfg.model.config.s2_dec_use.s1_angles, + s2_angles=cfg.model.config.s2_dec_use.s2_angles, + srtm=cfg.model.config.s2_dec_use.srtm, + w_c=cfg.model.config.s2_dec_use.w_c, + ), + loss_weights=MultiVAELossWeights( + sen1=TranslationLossWeights( + nll=cfg.model.config.loss_weights.sen1.nll, + lpips=cfg.model.config.loss_weights.sen1.lpips, + sam=cfg.model.config.loss_weights.sen1.sam, + gradients=cfg.model.config.loss_weights.sen1.gradients, + ), + sen2=TranslationLossWeights( + nll=cfg.model.config.loss_weights.sen2.nll, + lpips=cfg.model.config.loss_weights.sen2.lpips, + sam=cfg.model.config.loss_weights.sen2.sam, + gradients=cfg.model.config.loss_weights.sen2.gradients, + ), + s1_s2=TranslationLossWeights( + nll=cfg.model.config.loss_weights.s1_s2.nll, + lpips=cfg.model.config.loss_weights.s1_s2.lpips, + sam=cfg.model.config.loss_weights.s1_s2.sam, + gradients=cfg.model.config.loss_weights.s1_s2.gradients, + ), + s2_s1=TranslationLossWeights( + nll=cfg.model.config.loss_weights.s2_s1.nll, + lpips=cfg.model.config.loss_weights.s2_s1.lpips, + sam=cfg.model.config.loss_weights.s2_s1.sam, + gradients=cfg.model.config.loss_weights.s2_s1.gradients, + ), + forward=cfg.model.config.loss_weights.forward, + cross=cfg.model.config.loss_weights.forward, + latent=cfg.model.config.loss_weights.latent, + ), + ) + # Init the model mmdc_full_model = MMDCFullModule( - s2_angles_conv_in_channels=6, - s2_angles_conv_out_channels=3, - s2_angles_conv_encoder_sizes=[16, 8], - s2_angles_conv_kernel_size=1, - # - srtm_s2_angles_conv_in_channels=10, # 4 srtm + 6 angles - srtm_s2_angles_conv_out_channels=3, - srtm_s2_angles_unet_encoder_sizes=[32, 64, 128], - srtm_s2_angles_kernel_size=3, - # - s1_angles_mlp_in_size=6, - s1_angles_mlp_hidden_layers=[9, 12, 9], - s1_angles_mlp_out_size=3, - # - srtm_s1_encoder_sizes=[32, 64, 128], - srtm_kernel_size=3, - srtm_s1_in_channels=4, - srtm_s1_out_channels=3, - # - wc_enc_sizes=[64, 32, 16], - wc_kernel_size=1, - wc_in_channels=103, - wc_out_channels=4, - # - s1_input_size=6, - s1_encoder_sizes=[64, 128, 256, 512, 1024], - s1_enc_kernel_sizes=[3], - s1_decoder_sizes=[32, 16, 8], - s1_dec_kernel_sizes=[3, 3, 3, 3], - code_sizes=[0, 4, 0], - s2_input_size=10, - s2_encoder_sizes=[64, 128, 256, 512, 1024], - s2_enc_kernel_sizes=[3], - s2_decoder_sizes=[64, 32, 16], - s2_dec_kernel_sizes=[3, 3, 3, 3], - s1_ang_in_decoder=True, - s2_ang_in_decoder=True, - srtm_in_decoder=True, - wc_in_decoder=True, - w_d1_e1_s1=1, - w_d2_e2_s2=1, - w_d1_e2_s2=1, - w_d2_e1_s1=1, - w_code_s1s2=1, + mmdc_full_config, ) - # mmdc_full_lightning = MMDCFullLitModule(mmdc_full_model) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"device detected : {device}") print(f"nb_gpu's detected : {torch.cuda.device_count()}") @@ -100,22 +183,19 @@ def get_mmdc_full_model( lightning_checkpoint = torch.load(checkpoint, map_location=device) # delete "model" from the loaded checkpoint - checkpoint = { + lightning_checkpoint_cleaning = { key.split("model.")[-1]: item for key, item in lightning_checkpoint["state_dict"].items() if key.startswith("model.") } # load the state dict - mmdc_full_model.load_state_dict(checkpoint) - # mmdc_full_model.load_state_dict(lightning_checkpoint) - # mmdc_full_lightning.load_state_dict(lightning_checkpoint) + mmdc_full_model.load_state_dict(lightning_checkpoint_cleaning) # disble randomness, dropout, etc... 
mmdc_full_model.eval()
-    # mmdc_full_lightning.eval()

-    return mmdc_full_model  # mmdc_full_lightning
+    return mmdc_full_model


 @torch.no_grad()
--
GitLab


From 7f54f3ae6091e9b44aec02ceebfabf7a97659ce9 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:29:08 +0000
Subject: [PATCH 50/81] update to main

---
 src/bin/inference_mmdc_singledate.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
index 7b8db7d5..6d8e179f 100644
--- a/src/bin/inference_mmdc_singledate.py
+++ b/src/bin/inference_mmdc_singledate.py
@@ -22,6 +22,7 @@ from mmdc_singledate.inference.mmdc_tile_inference import (
     MMDCProcess,
     predict_single_date_tile,
 )
+from mmdc_singledate.inference.utils import get_scales


 def get_parser() -> argparse.ArgumentParser:
@@ -98,7 +99,10 @@ def get_parser() -> argparse.ArgumentParser:
         "--sensors",
         dest="sensors",
         nargs="+",
-        help="List of sensors S2L2A | S1FULL | S1ASC | S1DESC (Is sensible to the order S2L2A always firts)",
+        help=(
+            "List of sensors S2L2A | S1FULL | S1ASC | S1DESC"
+            " (is sensitive to the order; S2L2A always comes first)"
+        ),
         required=True,
     )

@@ -131,6 +135,8 @@ def main():
     mmdc_full_model = get_mmdc_full_model(
         checkpoint=args.model_checkpoint_path,
     )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
     mmdc_full_model.to(device)
     # process class
     mmdc_process = MMDCProcess(
--
GitLab


From 8e7aa3466060a647a68ee2e3b796643e11541eb3 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Fri, 17 Mar 2023 16:30:32 +0000
Subject: [PATCH 51/81] update model config

---
 .../components/inference_components.py        | 357 ++++++++++--------
 1 file changed, 200 insertions(+), 157 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index ae4a7498..e3a9f52b 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -11,7 +11,7 @@ from pathlib import Path
 from typing import Literal

 import torch
-from hydra import compose, initialize
+import yaml
 from torch import nn

 from mmdc_singledate.datamodules.types import MMDCDataChannels
@@ -39,6 +39,17 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)


+def parse_from_yaml(path_to_yaml):
+    """
+    Read a yaml file and return a dict
+    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
+    """
+    with open(path_to_yaml) as f:
+        config = yaml.safe_load(f)
+
+    return config
+
+
 @torch.no_grad()
 def get_mmdc_full_model(
     checkpoint: str,
@@ -52,123 +63,155 @@ def get_mmdc_full_model(
         logger.info("No checkpoint path was give. 
") return 1 else: - # parse config - with initialize(config_path="../../../../configs/model/"): - cfg = compose( - config_name="mmdc_full.yaml", return_hydra_config=True, overrides=[] - ) + cfg = parse_from_yaml("configs/model/mmdc_full.yaml") mmdc_full_config = MultiVAEConfig( data_sizes=MMDCDataChannels( - sen1=cfg.model.config.data_sizes.sen1, - s1_angles=cfg.model.config.data_sizes.s1_angles, - sen2=cfg.model.config.data_sizes.sen2, - s2_angles=cfg.model.config.data_sizes.s2_angles, - srtm=cfg.model.config.data_sizes.srtm, - w_c=cfg.model.config.data_sizes.w_c, + sen1=cfg["model"]["config"]["data_sizes"]["sen1"], + s1_angles=cfg["model"]["config"]["data_sizes"]["s1_angles"], + sen2=cfg["model"]["config"]["data_sizes"]["sen2"], + s2_angles=cfg["model"]["config"]["data_sizes"]["s2_angles"], + srtm=cfg["model"]["config"]["data_sizes"]["srtm"], + w_c=cfg["model"]["config"]["data_sizes"]["w_c"], ), embeddings=ModularEmbeddingConfig( s1_angles=MLPConfig( - hidden_layers=cfg.model.config.embeddings.s1_angles.hidden_layers, - out_channels=cfg.model.config.embeddings.s1_angles.out_channels, + hidden_layers=cfg["model"]["config"]["embeddings"]["s1_angles"][ + "hidden_layers" + ], + out_channels=cfg["model"]["config"]["embeddings"]["s1_angles"][ + "out_channels" + ], ), s1_srtm=UnetParams( - out_channels=cfg.model.config.embeddings.s1_srtm.out_channels, - encoder_sizes=cfg.model.config.embeddings.s1_srtm.encoder_sizes, - kernel_size=cfg.model.config.embeddings.s1_srtm.kernel_size, - tail_layers=cfg.model.config.embeddings.s1_srtm.tail_layers, + out_channels=cfg["model"]["config"]["embeddings"]["s1_srtm"][ + "out_channels" + ], + encoder_sizes=cfg["model"]["config"]["embeddings"]["s1_srtm"][ + "encoder_sizes" + ], + kernel_size=cfg["model"]["config"]["embeddings"]["s1_srtm"][ + "kernel_size" + ], + tail_layers=cfg["model"]["config"]["embeddings"]["s1_srtm"][ + "tail_layers" + ], ), s2_angles=ConvnetParams( - out_channels=cfg.model.config.embeddings.s2_angles.out_channels, - sizes=cfg.model.config.embeddings.s2_angles.sizes, - kernel_sizes=cfg.model.config.embeddings.s2_angles.kernel_sizes, + out_channels=cfg["model"]["config"]["embeddings"]["s2_angles"][ + "out_channels" + ], + sizes=cfg["model"]["config"]["embeddings"]["s2_angles"]["sizes"], + kernel_sizes=cfg["model"]["config"]["embeddings"]["s2_angles"][ + "kernel_sizes" + ], ), s2_srtm=UnetParams( - out_channels=cfg.model.config.embeddings.s2_srtm.out_channels, - encoder_sizes=cfg.model.config.embeddings.s2_srtm.encoder_sizes, - kernel_size=cfg.model.config.embeddings.s2_srtm.kernel_size, - tail_layers=cfg.model.config.embeddings.s2_srtm.tail_layers, + out_channels=cfg["model"]["config"]["embeddings"]["s2_srtm"][ + "out_channels" + ], + encoder_sizes=cfg["model"]["config"]["embeddings"]["s2_srtm"][ + "encoder_sizes" + ], + kernel_size=cfg["model"]["config"]["embeddings"]["s2_srtm"][ + "kernel_size" + ], + tail_layers=cfg["model"]["config"]["embeddings"]["s2_srtm"][ + "tail_layers" + ], ), w_c=ConvnetParams( - out_channels=cfg.model.config.embeddings.w_c.out_channels, - sizes=cfg.model.config.embeddings.w_c.sizes, - kernel_sizes=cfg.model.config.embeddings.w_c.kernel_sizes, + out_channels=cfg["model"]["config"]["embeddings"]["w_c"][ + "out_channels" + ], + sizes=cfg["model"]["config"]["embeddings"]["w_c"]["sizes"], + kernel_sizes=cfg["model"]["config"]["embeddings"]["w_c"][ + "kernel_sizes" + ], ), ), s1_encoder=UnetParams( - out_channels=cfg.model.config.s1_encoder.out_channels, - encoder_sizes=cfg.model.config.s1_encoder.encoder_sizes, - 
kernel_size=cfg.model.config.s1_encoder.kernel_size, - tail_layers=cfg.model.config.s1_encoder.tail_layers, + out_channels=cfg["model"]["config"]["s1_encoder"]["out_channels"], + encoder_sizes=cfg["model"]["config"]["s1_encoder"]["encoder_sizes"], + kernel_size=cfg["model"]["config"]["s1_encoder"]["kernel_size"], + tail_layers=cfg["model"]["config"]["s1_encoder"]["tail_layers"], ), s2_encoder=UnetParams( - out_channels=cfg.model.config.s2_encoder.out_channels, - encoder_sizes=cfg.model.config.s2_encoder.encoder_sizes, - kernel_size=cfg.model.config.s2_encoder.kernel_size, - tail_layers=cfg.model.config.s2_encoder.tail_layers, + out_channels=cfg["model"]["config"]["s2_encoder"]["out_channels"], + encoder_sizes=cfg["model"]["config"]["s2_encoder"]["encoder_sizes"], + kernel_size=cfg["model"]["config"]["s2_encoder"]["kernel_size"], + tail_layers=cfg["model"]["config"]["s2_encoder"]["tail_layers"], ), s1_decoder=ConvnetParams( - out_channels=cfg.model.config.s1_decoder.out_channels, - sizes=cfg.model.config.s1_decoder.sizes, - kernel_sizes=cfg.model.config.s1_decoder.kernel_sizes, + out_channels=cfg["model"]["config"]["s1_decoder"]["out_channels"], + sizes=cfg["model"]["config"]["s1_decoder"]["sizes"], + kernel_sizes=cfg["model"]["config"]["s1_decoder"]["kernel_sizes"], ), s2_decoder=ConvnetParams( - out_channels=cfg.model.config.s2_decoder.out_channels, - sizes=cfg.model.config.s2_decoder.sizes, - kernel_sizes=cfg.model.config.s2_decoder.kernel_sizes, + out_channels=cfg["model"]["config"]["s2_decoder"]["out_channels"], + sizes=cfg["model"]["config"]["s2_decoder"]["sizes"], + kernel_sizes=cfg["model"]["config"]["s2_decoder"]["kernel_sizes"], ), s1_enc_use=MMDCDataUse( - s1_angles=cfg.model.config.s1_enc_use.s1_angles, - s2_angles=cfg.model.config.s1_enc_use.s2_angles, - srtm=cfg.model.config.s1_enc_use.srtm, - w_c=cfg.model.config.s1_enc_use.w_c, + s1_angles=cfg["model"]["config"]["s1_enc_use"]["s1_angles"], + s2_angles=cfg["model"]["config"]["s1_enc_use"]["s2_angles"], + srtm=cfg["model"]["config"]["s1_enc_use"]["srtm"], + w_c=cfg["model"]["config"]["s1_enc_use"]["w_c"], ), s2_enc_use=MMDCDataUse( - s1_angles=cfg.model.config.s2_enc_use.s1_angles, - s2_angles=cfg.model.config.s2_enc_use.s2_angles, - srtm=cfg.model.config.s2_enc_use.srtm, - w_c=cfg.model.config.s2_enc_use.w_c, + s1_angles=cfg["model"]["config"]["s2_enc_use"]["s1_angles"], + s2_angles=cfg["model"]["config"]["s2_enc_use"]["s2_angles"], + srtm=cfg["model"]["config"]["s2_enc_use"]["srtm"], + w_c=cfg["model"]["config"]["s2_enc_use"]["w_c"], ), s1_dec_use=MMDCDataUse( - s1_angles=cfg.model.config.s1_dec_use.s1_angles, - s2_angles=cfg.model.config.s1_dec_use.s2_angles, - srtm=cfg.model.config.s1_dec_use.srtm, - w_c=cfg.model.config.s1_dec_use.w_c, + s1_angles=cfg["model"]["config"]["s1_dec_use"]["s1_angles"], + s2_angles=cfg["model"]["config"]["s1_dec_use"]["s2_angles"], + srtm=cfg["model"]["config"]["s1_dec_use"]["srtm"], + w_c=cfg["model"]["config"]["s1_dec_use"]["w_c"], ), s2_dec_use=MMDCDataUse( - s1_angles=cfg.model.config.s2_dec_use.s1_angles, - s2_angles=cfg.model.config.s2_dec_use.s2_angles, - srtm=cfg.model.config.s2_dec_use.srtm, - w_c=cfg.model.config.s2_dec_use.w_c, + s1_angles=cfg["model"]["config"]["s2_dec_use"]["s1_angles"], + s2_angles=cfg["model"]["config"]["s2_dec_use"]["s2_angles"], + srtm=cfg["model"]["config"]["s2_dec_use"]["srtm"], + w_c=cfg["model"]["config"]["s2_dec_use"]["w_c"], ), loss_weights=MultiVAELossWeights( sen1=TranslationLossWeights( - nll=cfg.model.config.loss_weights.sen1.nll, - 
lpips=cfg.model.config.loss_weights.sen1.lpips, - sam=cfg.model.config.loss_weights.sen1.sam, - gradients=cfg.model.config.loss_weights.sen1.gradients, + nll=cfg["model"]["config"]["loss_weights"]["sen1"]["nll"], + lpips=cfg["model"]["config"]["loss_weights"]["sen1"]["lpips"], + sam=cfg["model"]["config"]["loss_weights"]["sen1"]["sam"], + gradients=cfg["model"]["config"]["loss_weights"]["sen1"][ + "gradients" + ], ), sen2=TranslationLossWeights( - nll=cfg.model.config.loss_weights.sen2.nll, - lpips=cfg.model.config.loss_weights.sen2.lpips, - sam=cfg.model.config.loss_weights.sen2.sam, - gradients=cfg.model.config.loss_weights.sen2.gradients, + nll=cfg["model"]["config"]["loss_weights"]["sen2"]["nll"], + lpips=cfg["model"]["config"]["loss_weights"]["sen2"]["lpips"], + sam=cfg["model"]["config"]["loss_weights"]["sen2"]["sam"], + gradients=cfg["model"]["config"]["loss_weights"]["sen2"][ + "gradients" + ], ), s1_s2=TranslationLossWeights( - nll=cfg.model.config.loss_weights.s1_s2.nll, - lpips=cfg.model.config.loss_weights.s1_s2.lpips, - sam=cfg.model.config.loss_weights.s1_s2.sam, - gradients=cfg.model.config.loss_weights.s1_s2.gradients, + nll=cfg["model"]["config"]["loss_weights"]["s1_s2"]["nll"], + lpips=cfg["model"]["config"]["loss_weights"]["s1_s2"]["lpips"], + sam=cfg["model"]["config"]["loss_weights"]["s1_s2"]["sam"], + gradients=cfg["model"]["config"]["loss_weights"]["s1_s2"][ + "gradients" + ], ), s2_s1=TranslationLossWeights( - nll=cfg.model.config.loss_weights.s2_s1.nll, - lpips=cfg.model.config.loss_weights.s2_s1.lpips, - sam=cfg.model.config.loss_weights.s2_s1.sam, - gradients=cfg.model.config.loss_weights.s2_s1.gradients, + nll=cfg["model"]["config"]["loss_weights"]["s2_s1"]["nll"], + lpips=cfg["model"]["config"]["loss_weights"]["s2_s1"]["lpips"], + sam=cfg["model"]["config"]["loss_weights"]["s2_s1"]["sam"], + gradients=cfg["model"]["config"]["loss_weights"]["s2_s1"][ + "gradients" + ], ), - forward=cfg.model.config.loss_weights.forward, - cross=cfg.model.config.loss_weights.forward, - latent=cfg.model.config.loss_weights.latent, + forward=cfg["model"]["config"]["loss_weights"]["forward"], + cross=cfg["model"]["config"]["loss_weights"]["forward"], + latent=cfg["model"]["config"]["loss_weights"]["latent"], ), ) # Init the model @@ -177,8 +220,8 @@ def get_mmdc_full_model( ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"device detected : {device}") - print(f"nb_gpu's detected : {torch.cuda.device_count()}") + # print(f"device detected : {device}") + # print(f"nb_gpu's detected : {torch.cuda.device_count()}") # load state_dict lightning_checkpoint = torch.load(checkpoint, map_location=device) @@ -233,7 +276,7 @@ def predict_mmdc_model( # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(device) + # print(device) prediction = model.predict( s2_ref.to(device), @@ -247,10 +290,10 @@ def predict_mmdc_model( srtm.to(device), ) - print("prediction.shape :", prediction[0].latent.latent_s2.mu.shape) - print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape) - print("prediction.shape :", prediction[0].latent.latent_s1.mu.shape) - print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape) + # print("prediction.shape :", prediction[0].latent.latent_s2.mean.shape) + # print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape) + # print("prediction.shape :", prediction[0].latent.latent_s1.mean.shape) + # print("prediction.shape :", 
prediction[0].latent.latent_s1.logvar.shape) # init the output latent spaces as # empty dataclass @@ -261,151 +304,151 @@ def predict_mmdc_model( match sensors: # Cases case ["S2L2A", "S1FULL"]: - logger.info("S2 captured & S1 full captured") + # logger.info("S2 captured & S1 full captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space.latent_s1 = prediction[0].latent.latent_s1 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mu, + latent_space.latent_s2.mean, latent_space.latent_s2.logvar, - latent_space.latent_s1.mu, + latent_space.latent_s1.mean, latent_space.latent_s1.logvar, ), 1, ) - print( - "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu.shape, - ) - print( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, - ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) - print( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + # print( + # "latent_space.latent_s2.mean :", + # latent_space.latent_s2.mean.shape, + # ) + # print( + # "latent_space.latent_s2.logvar :", + # latent_space.latent_s2.logvar.shape, + # ) + # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + # print( + # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + # ) case ["S2L2A", "S1ASC"]: - logger.info("S2 captured & S1 asc captured") + # logger.info("S2 captured & S1 asc captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space.latent_s1 = prediction[0].latent.latent_s1 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mu, + latent_space.latent_s2.mean, latent_space.latent_s2.logvar, - latent_space.latent_s1.mu, + latent_space.latent_s1.mean, latent_space.latent_s1.logvar, ), 1, ) - print( - "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu.shape, - ) - print( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, - ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) - print( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + # print( + # "latent_space.latent_s2.mean :", + # latent_space.latent_s2.mean.shape, + # ) + # print( + # "latent_space.latent_s2.logvar :", + # latent_space.latent_s2.logvar.shape, + # ) + # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + # print( + # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + # ) case ["S2L2A", "S1DESC"]: - logger.info("S2 captured & S1 desc captured") + # logger.info("S2 captured & S1 desc captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space.latent_s1 = prediction[0].latent.latent_s1 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mu, + latent_space.latent_s2.mean, latent_space.latent_s2.logvar, - latent_space.latent_s1.mu, + latent_space.latent_s1.mean, latent_space.latent_s1.logvar, ), 1, ) - print( - "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu.shape, - ) - print( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, - ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) - print( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + # print( + # "latent_space.latent_s2.mean :", + # latent_space.latent_s2.mean.shape, + # ) + # print( + # "latent_space.latent_s2.logvar :", + # latent_space.latent_s2.logvar.shape, + # ) + # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + # print( + # 
"latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + # ) case ["S2L2A"]: - logger.info("Only S2 captured") + # logger.info("Only S2 captured") latent_space.latent_s2 = prediction[0].latent.latent_s2 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mu, + latent_space.latent_s2.mean, latent_space.latent_s2.logvar, ), 1, ) - print( - "latent_space.latent_s2.mu :", - latent_space.latent_s2.mu.shape, - ) - print( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, - ) + # print( + # "latent_space.latent_s2.mean :", + # latent_space.latent_s2.mean.shape, + # ) + # print( + # "latent_space.latent_s2.logvar :", + # latent_space.latent_s2.logvar.shape, + # ) case ["S1FULL"]: - logger.info("Only S1 full captured") + # logger.info("Only S1 full captured") latent_space.latent_s1 = prediction[0].latent.latent_s1 latent_space_stack = torch.cat( - (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), + (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), 1, ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) - print( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + # print( + # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + # ) case ["S1ASC"]: - logger.info("Only S1ASC captured") + # logger.info("Only S1ASC captured") latent_space.latent_s1 = prediction[0].latent.latent_s1 latent_space_stack = torch.cat( - (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), + (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), 1, ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) - print( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + # print( + # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + # ) case ["S1DESC"]: - logger.info("Only S1DESC captured") + # logger.info("Only S1DESC captured") latent_space.latent_s1 = prediction[0].latent.latent_s1 latent_space_stack = torch.cat( - (latent_space.latent_s1.mu, latent_space.latent_s1.logvar), + (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), 1, ) - print("latent_space.latent_s1.mu :", latent_space.latent_s1.mu.shape) - print( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + # print( + # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + # ) return latent_space_stack # latent_space -- GitLab From a44235ecafdb5025c695d3adb021b0ba99447747 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Mar 2023 16:32:11 +0000 Subject: [PATCH 52/81] update paths --- test/test_mmdc_inference.py | 351 +++++++++++------------------------- 1 file changed, 104 insertions(+), 247 deletions(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index 4f7b1f0a..51be142f 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -10,7 +10,7 @@ import pytest import rasterio as rio import torch -from mmdc_singledate.inference.components.inference_components import ( # predict_tile, +from mmdc_singledate.inference.components.inference_components import ( get_mmdc_full_model, predict_mmdc_model, ) @@ -24,31 +24,77 @@ from mmdc_singledate.inference.mmdc_tile_inference import ( 
mmdc_tile_inference,
     predict_single_date_tile,
 )
-from mmdc_singledate.models.components.model_dataclass import VAELatentSpace
+from mmdc_singledate.models.types import VAELatentSpace

-# dir
+from .utils import get_scales, setup_data
+
+# useful variables
 dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/"
+s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
+s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif"
+s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif"
+srtm_filename_variable = "srtm_T31TCJ_roi_0.tif"
+wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif"
+
+
+sensors_test = [
+    (["S2L2A", "S1FULL"]),
+    (["S2L2A", "S1ASC"]),
+    (["S2L2A", "S1DESC"]),
+    (["S2L2A"]),
+    (["S1FULL"]),
+    (["S1ASC"]),
+    (["S1DESC"]),
+]
+
+
+datasets = [
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        "",
+        s1_asc_filename_variable,
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        "",
+        s1_desc_filename_variable,
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+    (
+        s2_filename_variable,
+        s1_asc_filename_variable,
+        "",
+        srtm_filename_variable,
+        wc_filename_variable,
+    ),
+]


 # @pytest.mark.skip(reason="Check Time Serie generation")
 def test_GeoTiffDataset():
     """ """
-    datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ"
-
     input_data = GeoTiffDataset(
-        s2_filename=os.path.join(
-            datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif"
-        ),
+        s2_filename=os.path.join(dataset_dir, s2_filename_variable),
         s1_asc_filename=os.path.join(
-            datapath,
-            "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif",
+            dataset_dir,
+            s1_asc_filename_variable,
         ),
         s1_desc_filename=os.path.join(
-            datapath,
-            "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif",
+            dataset_dir,
+            s1_desc_filename_variable,
         ),
-        srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"),
-        wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"),
+        srtm_filename=os.path.join(dataset_dir, srtm_filename_variable),
+        wc_filename=os.path.join(dataset_dir, wc_filename_variable),
     )
     # print GetTiffDataset
     print(input_data)
@@ -69,95 +115,33 @@ def test_generate_chunks():
     assert chunks[0] == (0, 0, 10980, 1024)


-# @pytest.mark.skip(reason="Check Time Serie generation")
-def test_mmdc_full_model():
-    """
-    test instantiate network
-    """
-    # set device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(device)
+@pytest.mark.skip(reason="Check Time Serie generation")
+@pytest.mark.parametrize("sensors", sensors_test)
+def test_predict_mmdc_model(sensors):
+    """ """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # checkpoints
-    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints"
-    checkpoint_filename = "epoch_008.ckpt"  # "last.ckpt"
+    checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints"
+    checkpoint_filename = "last.ckpt"

     mmdc_full_model = get_mmdc_full_model(
os.path.join(checkpoint_path, checkpoint_filename) ) - # move to device - mmdc_full_model.to(device) - - assert mmdc_full_model.training == False - - s2_x = torch.rand(1, 10, 256, 256).to(device) - s2_m = torch.ones(1, 10, 256, 256).to(device) - s2_angles_x = torch.rand(1, 6, 256, 256).to(device) - s1_x = torch.rand(1, 6, 256, 256).to(device) - s1_vm = torch.ones(1, 1, 256, 256).to(device) - s1_asc_angles_x = torch.rand(1, 3).to(device) - s1_desc_angles_x = torch.rand(1, 3).to(device) - worldclim_x = torch.rand(1, 103, 256, 256).to(device) - srtm_x = torch.rand(1, 4, 256, 256).to(device) - - print(srtm_x) - - prediction = mmdc_full_model.predict( + ( s2_x, s2_m, - s2_angles_x, + s2_a, s1_x, s1_vm, - s1_asc_angles_x, - s1_desc_angles_x, - worldclim_x, + s1_a_asc, + s1_a_desc, srtm_x, - ) - - latent_variable = prediction[0].latent.latent_s1.mu - print(latent_variable.shape) - - assert type(latent_variable) == torch.Tensor - - -sensors_test = [ - (["S2L2A", "S1FULL"]), - (["S2L2A", "S1ASC"]), - (["S2L2A", "S1DESC"]), - (["S2L2A"]), - (["S1FULL"]), - (["S1ASC"]), - (["S1DESC"]), - # ([ "S1FULL","S2L2A" ]), -] - - -# @pytest.mark.skip(reason="Check Time Serie generation") -@pytest.mark.parametrize("sensors", sensors_test) -def test_predict_mmdc_model(sensors): - """ """ - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # data input - s2_x = torch.rand(1, 10, 256, 256).to(device) - s2_m = torch.ones(1, 10, 256, 256).to(device) - s2_angles_x = torch.rand(1, 6, 256, 256).to(device) - s1_x = torch.rand(1, 6, 256, 256).to(device) - s1_vm = torch.ones(1, 1, 256, 256).to(device) - s1_asc_angles_x = torch.rand(1, 3).to(device) - s1_desc_angles_x = torch.rand(1, 3).to(device) - worldclim_x = torch.rand(1, 103, 256, 256).to(device) - srtm_x = torch.rand(1, 4, 256, 256).to(device) - - # model - # checkpoints - checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" - checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" - - mmdc_full_model = get_mmdc_full_model( - os.path.join(checkpoint_path, checkpoint_filename) - ) + wc_x, + device, + ) = setup_data() + scales = get_scales() + mmdc_full_model.set_scales(scales) # move to device mmdc_full_model.to(device) @@ -166,107 +150,17 @@ def test_predict_mmdc_model(sensors): sensors, s2_x, s2_m, - s2_angles_x, + s2_a, s1_x, s1_vm, - s1_asc_angles_x, - s1_desc_angles_x, - worldclim_x, + s1_a_asc, + s1_a_desc, + wc_x, srtm_x, ) print(pred.shape) assert type(pred) == torch.Tensor - # assert pred.shape[0] == 4 or 2 - - -def dummy_process( - s2_refl, - s2_mask, - s2_ang, - s1_back, - s1_vm, - s1_asc_lia, - s1_desc_lia, - wc_patch, - srtm_patch, -): - """ - Create a dummy function - """ - prediction = ( - s2_refl[:, 0, ...] - * s2_mask[:, 0, ...] - * s2_ang[:, 0, ...] - * s1_back[:, 0, ...] - * s1_vm[:, 0, ...] - * srtm_patch[:, 0, ...] - * wc_patch[:, 0, ...] 
- ) - - latent_example = VAELatentSpace( - mu=prediction, - logvar=prediction, - ) - - latent_space_stack = torch.cat( - ( - torch.unsqueeze(latent_example.mu, 1), - torch.unsqueeze(latent_example.logvar, 1), - torch.unsqueeze(latent_example.mu, 1), - torch.unsqueeze(latent_example.logvar, 1), - ), - 1, - ) - - return latent_space_stack - - -# TODO add cases with no data -# @pytest.mark.skip(reason="Test with dummy process ") -# @pytest.mark.parametrize("sensors", sensors_test) -# def test_predict_single_date_tile_with_dummy_process(sensors): -# """ """ -# export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_{'_'.join(sensors)}.tif" -# nb_bands = 4 # len(sensors) * 2 -# print("nb_bands :", nb_bands) -# process = MMDCProcess( -# count=nb_bands, #4, -# nb_lines=1024, -# patch_size=256, -# process=dummy_process, -# ) - -# datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" -# input_data = GeoTiffDataset( -# s2_filename=os.path.join( -# datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" -# ), -# s1_asc_filename=os.path.join( -# datapath, -# "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif", -# ), -# s1_desc_filename=os.path.join( -# datapath, -# "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif", -# ), -# srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"), -# wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"), -# ) - -# predict_single_date_tile( -# input_data=input_data, -# export_path=export_path, -# sensors=sensors, # ["S2L2A"], -# process=process, -# ) - -# assert Path(export_path).exists() == True - -# with rio.open(export_path) as predited_raster: -# predicted_metadata = predited_raster.meta.copy() - -# assert predicted_metadata["count"] == nb_bands # TODO add cases with no data @@ -278,33 +172,33 @@ def test_predict_single_date_tile(sensors): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif" - datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" input_data = GeoTiffDataset( - s2_filename=os.path.join( - datapath, "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" - ), + s2_filename=os.path.join(dataset_dir, s2_filename_variable), s1_asc_filename=os.path.join( - datapath, - "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif", + dataset_dir, + s1_asc_filename_variable, ), s1_desc_filename=os.path.join( - datapath, - "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif", + dataset_dir, + s1_desc_filename_variable, ), - srtm_filename=os.path.join(datapath, "srtm_T31TCJ_roi_0.tif"), - wc_filename=os.path.join(datapath, "wc_clim_1_T31TCJ_roi_0.tif"), + srtm_filename=os.path.join(dataset_dir, srtm_filename_variable), + wc_filename=os.path.join(dataset_dir, wc_filename_variable), ) nb_bands = len(sensors) * 8 print("nb_bands :", nb_bands) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checkpoints - checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" - checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" + checkpoint_path = 
"/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints" + checkpoint_filename = "last.ckpt" # "last.ckpt" mmdc_full_model = get_mmdc_full_model( os.path.join(checkpoint_path, checkpoint_filename) ) + scales = get_scales() + mmdc_full_model.set_scales(scales) # move to device mmdc_full_model.to(device) assert mmdc_full_model.training == False @@ -333,49 +227,10 @@ def test_predict_single_date_tile(sensors): assert predicted_metadata["count"] == nb_bands -s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" -s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif" -s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif" -srtm_filename_variable = "srtm_T31TCJ_roi_0.tif" -wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif" - -datasets = [ - ( - s2_filename_variable, - s1_asc_filename_variable, - s1_desc_filename_variable, - srtm_filename_variable, - wc_filename_variable, - ), - ( - "", - s1_asc_filename_variable, - s1_desc_filename_variable, - srtm_filename_variable, - wc_filename_variable, - ), - ( - s2_filename_variable, - "", - s1_desc_filename_variable, - srtm_filename_variable, - wc_filename_variable, - ), - ( - s2_filename_variable, - s1_asc_filename_variable, - "", - srtm_filename_variable, - wc_filename_variable, - ), -] - - @pytest.mark.parametrize( "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable", datasets, ) -# @pytest.mark.skip(reason="Check Time Serie generation") def test_predict_single_date_tile_no_data( s2_filename_variable, s1_asc_filename_variable, @@ -388,32 +243,35 @@ def test_predict_single_date_tile_no_data( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif" - datapath = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" + dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" input_data = GeoTiffDataset( - s2_filename=os.path.join(datapath, s2_filename_variable), + s2_filename=os.path.join(dataset_dir, s2_filename_variable), s1_asc_filename=os.path.join( - datapath, + dataset_dir, s1_asc_filename_variable, ), s1_desc_filename=os.path.join( - datapath, + dataset_dir, s1_desc_filename_variable, ), - srtm_filename=os.path.join(datapath, srtm_filename_variable), - wc_filename=os.path.join(datapath, wc_filename_variable), + srtm_filename=os.path.join(dataset_dir, srtm_filename_variable), + wc_filename=os.path.join(dataset_dir, wc_filename_variable), ) nb_bands = len(sensors) * 8 print("nb_bands :", nb_bands) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checkpoints - checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints" - checkpoint_filename = "epoch_008.ckpt" # "last.ckpt" + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints" + checkpoint_filename = "last.ckpt" # "last.ckpt" mmdc_full_model = get_mmdc_full_model( os.path.join(checkpoint_path, checkpoint_filename) ) # move to device + scales = get_scales() + mmdc_full_model.set_scales(scales) mmdc_full_model.to(device) assert 
mmdc_full_model.training == False @@ -449,7 +307,7 @@ def test_mmdc_tile_inference(): # feed parameters input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/") export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/") - model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt" + model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt" input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"] @@ -542,7 +400,6 @@ def test_mmdc_tile_inference(): ], } ) - # inference_dataframe( # samples_dir=input_path, # input_tile_list=input_tile_list, -- GitLab From bebe29b944d582aa423b75425c5d5455dd26202a Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Mar 2023 16:33:04 +0000 Subject: [PATCH 53/81] update to main --- .../inference/mmdc_tile_inference.py | 116 +++++++++--------- 1 file changed, 60 insertions(+), 56 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index 123b4318..ebc619f5 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -16,13 +16,11 @@ import pandas as pd import rasterio as rio import torch from torchutils import patches +from tqdm import tqdm from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df -from .components.inference_components import ( # patchify_batch,; unpatchify_batch, - get_mmdc_full_model, - predict_mmdc_model, -) +from .components.inference_components import get_mmdc_full_model, predict_mmdc_model from .components.inference_utils import ( GeoTiffDataset, concat_worldclim_components, @@ -32,6 +30,7 @@ from .components.inference_utils import ( read_s2_img_tile, read_srtm_img_tile, ) +from .utils import get_scales # Configure the logger NUMERIC_LEVEL = getattr(logging, "INFO", None) @@ -121,48 +120,48 @@ def predict_single_date_tile( chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines) # get the windows from the chunks rois = [rio.windows.Window(*chunk) for chunk in chunks] - logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ") + # logger.info(f"chunk size : ({rois[0].width}, {rois[0].height}) ") # init the dataset - logger.info("Reading S2 data") + # logger.info("Reading S2 data") s2_data = read_img_tile( filename=input_data.s2_filename, rois=rois, availabitity=input_data.s2_availabitity, sensor_func=read_s2_img_tile, ) - logger.info("Reading S1 ASC data") + # logger.info("Reading S1 ASC data") s1_asc_data = read_img_tile( filename=input_data.s1_asc_filename, rois=rois, availabitity=input_data.s1_asc_availability, sensor_func=read_s1_img_tile, ) - logger.info("Reading S1 DESC data") + # logger.info("Reading S1 DESC data") s1_desc_data = read_img_tile( filename=input_data.s1_desc_filename, rois=rois, availabitity=input_data.s1_desc_availability, sensor_func=read_s1_img_tile, ) - logger.info("Reading SRTM data") + # logger.info("Reading SRTM data") srtm_data = read_img_tile( filename=input_data.srtm_filename, rois=rois, availabitity=True, sensor_func=read_srtm_img_tile, ) - logger.info("Reading WorldClim data") + # logger.info("Reading WorldClim data") worldclim_data = concat_worldclim_components( wc_filename=input_data.wc_filename, rois=rois, availabitity=True ) - 
print("input_data.s2_availabitity=", input_data.s2_availabitity) - print("s2_data=", s2_data) - print("s1_asc_data=", s1_asc_data) - print("s1_desc_data=", s1_desc_data) - print("srtm_data=", srtm_data) - print("worldclim_data=", worldclim_data) - logger.info("Export Init") + # print("input_data.s2_availabitity=", input_data.s2_availabitity) + # print("s2_data=", s2_data) + # print("s1_asc_data=", s1_asc_data) + # print("s1_desc_data=", s1_desc_data) + # print("srtm_data=", srtm_data) + # print("worldclim_data=", worldclim_data) + # logger.info("Export Init") with rio.open(export_path, "w", **meta) as prediction: # iterate over the windows @@ -174,17 +173,17 @@ def predict_single_date_tile( srtm_data, worldclim_data, ): - print(" original size : ", s2.s2_reflectances.shape) - print(" original size : ", s2.s2_angles.shape) - print(" original size : ", s2.s2_mask.shape) - print(" original size : ", s1_asc.s1_backscatter.shape) - print(" original size : ", s1_asc.s1_valmask.shape) - print(" original size : ", s1_asc.s1_lia_angles.shape) - print(" original size : ", s1_desc.s1_backscatter.shape) - print(" original size : ", s1_desc.s1_valmask.shape) - print(" original size : ", s1_desc.s1_lia_angles.shape) - print(" original size : ", srtm.srtm.shape) - print(" original size : ", wc.worldclim.shape) + # print(" original size : ", s2.s2_reflectances.shape) + # print(" original size : ", s2.s2_angles.shape) + # print(" original size : ", s2.s2_mask.shape) + # print(" original size : ", s1_asc.s1_backscatter.shape) + # print(" original size : ", s1_asc.s1_valmask.shape) + # print(" original size : ", s1_asc.s1_lia_angles.shape) + # print(" original size : ", s1_desc.s1_backscatter.shape) + # print(" original size : ", s1_desc.s1_valmask.shape) + # print(" original size : ", s1_desc.s1_lia_angles.shape) + # print(" original size : ", srtm.srtm.shape) + # print(" original size : ", wc.worldclim.shape) # Concat S1 Data s1_backscatter = torch.cat( @@ -200,8 +199,8 @@ def predict_single_date_tile( ).shape s2_s2_mask_shape = s2.s2_mask.shape # [1, 1024, 10980] - print("s2_mask_patch_size :", s2_mask_patch_size) - print("s2_s2_mask_shape:", s2_s2_mask_shape) + # print("s2_mask_patch_size :", s2_mask_patch_size) + # print("s2_s2_mask_shape:", s2_s2_mask_shape) # reshape the data s2_refl_patch = patches.flatten2d( patches.patchify(s2.s2_reflectances, process.patch_size) @@ -234,23 +233,23 @@ def predict_single_date_tile( ) # torch.flatten(start_dim=0, end_dim=1) - print("s2_patches", s2_refl_patch.shape) - print("s2_angles_patches", s2_ang_patch.shape) - print("s2_mask_patch", s2_mask_patch.shape) - print("s1_patch", s1_patch.shape) - print("s1_valmask_patch", s1_valmask_patch.shape) - print( - "s1_lia_asc patches", - s1_asc.s1_lia_angles.shape, - s1asc_lia_patch.shape, - ) - print( - "s1_lia_desc patches", - s1_desc.s1_lia_angles.shape, - s1desc_lia_patch.shape, - ) - print("srtm patches", srtm_patch.shape) - print("wc patches", wc_patch.shape) + # print("s2_patches", s2_refl_patch.shape) + # print("s2_angles_patches", s2_ang_patch.shape) + # print("s2_mask_patch", s2_mask_patch.shape) + # print("s1_patch", s1_patch.shape) + # print("s1_valmask_patch", s1_valmask_patch.shape) + # print( + # "s1_lia_asc patches", + # s1_asc.s1_lia_angles.shape, + # s1asc_lia_patch.shape, + # ) + # print( + # "s1_lia_desc patches", + # s1_desc.s1_lia_angles.shape, + # s1desc_lia_patch.shape, + # ) + # print("srtm patches", srtm_patch.shape) + # print("wc patches", wc_patch.shape) # apply predict function # should return a 
s1s2vaelatentspace @@ -268,8 +267,8 @@ def predict_single_date_tile( wc_patch, srtm_patch, ) - print(type(pred_vaelatentspace)) - print("latent space sizes : ", pred_vaelatentspace.shape) + # print(type(pred_vaelatentspace)) + # print("latent space sizes : ", pred_vaelatentspace.shape) # unpatchify pred_vaelatentspace_unpatchify = patches.unpatchify( @@ -280,8 +279,8 @@ def predict_single_date_tile( ) )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] - print("pred_tensor :", pred_vaelatentspace_unpatchify.shape) - print("process.count : ", process.count) + # print("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + # print("process.count : ", process.count) # check the pred and the ask are the same assert process.count == pred_vaelatentspace_unpatchify.shape[0] @@ -290,7 +289,7 @@ def predict_single_date_tile( window=roi, indexes=process.count, ) - logger.info(("Export tile", f"filename :{export_path}")) + # logger.info(("Export tile", f"filename :{export_path}")) def estimate_nb_latent_spaces( @@ -379,10 +378,15 @@ def mmdc_tile_inference( model = get_mmdc_full_model( checkpoint=model_checkpoint_path, ) + scales = get_scales() + model.set_scales(scales) + # move to device model.to(device) # iterate over the dates in the time serie - for tuile, df_row in inference_dataframe.iterrows(): + for tuile, df_row in tqdm( + inference_dataframe.iterrows(), total=inference_dataframe.shape[0] + ): # estimate nb latent spaces latent_space_size = estimate_nb_latent_spaces( df_row["patch_s2_availability"], @@ -398,8 +402,8 @@ def mmdc_tile_inference( date_ = df_row["date"].strftime("%Y-%m-%d") export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif" # - print(tuile, df_row) - print(latent_space_size) + # print(tuile, df_row) + # print(latent_space_size) # define process mmdc_process = MMDCProcess( count=latent_space_size, @@ -427,4 +431,4 @@ def mmdc_tile_inference( process=mmdc_process, ) - logger.info("Export Finish !!!") + # logger.info("Export Finish !!!") -- GitLab From d6dd8459758931a14936f66aaad557553f9a47e4 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Mar 2023 16:34:00 +0000 Subject: [PATCH 54/81] make sure inference is a module --- src/mmdc_singledate/inference/__init__.py | 1 + .../inference/components/__init__.py | 1 + src/mmdc_singledate/inference/utils.py | 45 +++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 src/mmdc_singledate/inference/__init__.py create mode 100644 src/mmdc_singledate/inference/components/__init__.py create mode 100644 src/mmdc_singledate/inference/utils.py diff --git a/src/mmdc_singledate/inference/__init__.py b/src/mmdc_singledate/inference/__init__.py new file mode 100644 index 00000000..e5a0d9b4 --- /dev/null +++ b/src/mmdc_singledate/inference/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/src/mmdc_singledate/inference/components/__init__.py b/src/mmdc_singledate/inference/components/__init__.py new file mode 100644 index 00000000..e5a0d9b4 --- /dev/null +++ b/src/mmdc_singledate/inference/components/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py new file mode 100644 index 00000000..53adf910 --- /dev/null +++ b/src/mmdc_singledate/inference/utils.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 + +import torch + +from mmdc_singledate.datamodules.types import MMDCDataStats, MMDCShiftScales, ShiftScale + + +def get_scales(): + """ + Read Scales for 
Inference + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset_dir = "/work/CESBIO/projects/MAESTRIA/training_dataset2/tftwmwc3" + tile = "T31TEM" + patch_size = "256x256" + roi = 1 + stats = MMDCDataStats( + torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s2_{patch_size}_{roi}.pth"), + torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s1_{patch_size}_{roi}.pth"), + torch.load( + f"{dataset_dir}/{tile}/{tile}_stats_worldclim_{patch_size}_{roi}.pth" + ), + torch.load(f"{dataset_dir}/{tile}/{tile}_stats_srtm_{patch_size}_{roi}.pth"), + ) + scale_regul = torch.nn.Threshold(1e-10, 1.0) + shift_scale_s2 = ShiftScale( + stats.sen2.median.to(device), + scale_regul((stats.sen2.qmax - stats.sen2.qmin) / 2.0).to(device), + ) + shift_scale_s1 = ShiftScale( + stats.sen1.median.to(device), + scale_regul((stats.sen1.qmax - stats.sen1.qmin) / 2.0).to(device), + ) + shift_scale_wc = ShiftScale( + stats.worldclim.median.to(device), + scale_regul((stats.worldclim.qmax - stats.worldclim.qmin) / 2.0).to(device), + ) + shift_scale_srtm = ShiftScale( + stats.srtm.median.to(device), + scale_regul((stats.srtm.qmax - stats.srtm.qmin) / 2.0).to(device), + ) + + return MMDCShiftScales( + shift_scale_s2, shift_scale_s1, shift_scale_wc, shift_scale_srtm + ) -- GitLab From 028ce198d20f1aba7785dce1bb9b1f9a29b49a3a Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Mar 2023 16:34:35 +0000 Subject: [PATCH 55/81] iota2 init --- iota2_init.sh | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 iota2_init.sh diff --git a/iota2_init.sh b/iota2_init.sh new file mode 100644 index 00000000..4d1f386f --- /dev/null +++ b/iota2_init.sh @@ -0,0 +1,3 @@ +module purge +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 -- GitLab From e35a7c0fbe1ccaffb0a87fcdb610ae73bf46cd5c Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Fri, 17 Mar 2023 16:36:14 +0000 Subject: [PATCH 56/81] update to main --- src/mmdc_singledate/inference/components/inference_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py index 7963eb1a..1d79f1f0 100644 --- a/src/mmdc_singledate/inference/components/inference_utils.py +++ b/src/mmdc_singledate/inference/components/inference_utils.py @@ -204,7 +204,7 @@ def read_img_tile( yield sensor_data # if not exists create a zeros tensor and yield else: - print("Creating fake data") + # print("Creating fake data") for roi in rois: null_data = torch.ones(roi.width, roi.height) sensor_data = sensor_func(null_data, availabitity, filename) @@ -238,7 +238,7 @@ def read_s2_img_tile( image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1]) angles_s2 = torch.ones(6, s2_tensor.shape[0], s2_tensor.shape[1]) mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1]) - print("passing zero data") + # print("passing zero data") return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) -- GitLab From f999fd67b7f8a83dc01f89e7dc72707283d79a0a Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Mon, 20 Mar 2023 16:48:54 +0000 Subject: [PATCH 57/81] ignore iota2 thirdparty --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 281e6999..207ceac4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,4 @@ 
iota2_thirdparties .coverage src/MMDC_SingleDate.egg-info/ .projectile -iota2_thirdparties/ +iota2_thirdparties -- GitLab From bb0673f73fab4d11b810812fa45c9333b6921cc9 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Mon, 20 Mar 2023 16:52:32 +0000 Subject: [PATCH 58/81] update thirdaparty folders --- .gitignore | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 207ceac4..26aaaf09 100644 --- a/.gitignore +++ b/.gitignore @@ -9,9 +9,9 @@ _lightning_logs/* src/models/*.ipynb *.swp *~ -thirdparties +thirdparties/* iota2_thirdparties .coverage src/MMDC_SingleDate.egg-info/ .projectile -iota2_thirdparties +iota2_thirdparties/* -- GitLab From 52a70cbf1de8defc5a3e1bb08756f8ef5996bfaa Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Tue, 21 Mar 2023 15:59:13 +0000 Subject: [PATCH 59/81] update new parameters --- jobs/inference_single_date.pbs | 5 +++-- jobs/inference_time_serie.pbs | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs index 989f6ec6..1374be03 100644 --- a/jobs/inference_single_date.pbs +++ b/jobs/inference_single_date.pbs @@ -13,7 +13,7 @@ export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/ export SCRATCH_DIR=/work/scratch/${USER}/ -cd ${WORKING_DIR} +cd ${SRCDIR} mkdir ${SCRATCH_DIR}/MMDC/inference/ mkdir ${SCRATCH_DIR}/MMDC/inference/singledate @@ -27,7 +27,8 @@ python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \ --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \ --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \ --export_path ${SCRATCH_DIR}/MMDC/inference/singledate \ - --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \ + --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \ + --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \ --latent_space_size 16 \ --patch_size 128 \ --nb_lines 256 \ diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs index 582bf198..0aaf57e2 100644 --- a/jobs/inference_time_serie.pbs +++ b/jobs/inference_time_serie.pbs @@ -18,14 +18,15 @@ mkdir ${SCRATCH_DIR}/MMDC/inference/ mkdir ${SCRATCH_DIR}/MMDC/inference/time_serie -cd ${WORKING_DIR} +cd ${SRCDIR} source ${MMDC_INIT} python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \ --input_path ${DATA_DIR} \ --export_path ${SCRATCH_DIR}/MMDC/inference/time_serie \ - --tile_list ${DATA_DIR}/T33TUL \ - --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-01-06_13-00-34/checkpoints/epoch_008.ckpt \ + --tile_list T33TUL \ + --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \ + --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \ --days_gap 0 \ --patch_size 128 \ --nb_lines 256 \ -- GitLab From ecea1d9fb9069ccb2484fa7b5fe3206b65729c1e Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Tue, 21 Mar 2023 16:01:28 +0000 Subject: [PATCH 60/81] update new parameters --- src/bin/inference_mmdc_singledate.py | 9 +++++++++ 1 file changed, 9 
insertions(+)

diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py
index 6d8e179f..5df740b3 100644
--- a/src/bin/inference_mmdc_singledate.py
+++ b/src/bin/inference_mmdc_singledate.py
@@ -74,6 +74,13 @@ def get_parser() -> argparse.ArgumentParser:
         required=True,
     )
 
+    arg_parser.add_argument(
+        "--model_config_path",
+        type=str,
+        help="path to the model configuration file",
+        required=True,
+    )
+
     arg_parser.add_argument(
         "--latent_space_size",
         type=int,
@@ -134,6 +141,8 @@ def main():
     # model
     mmdc_full_model = get_mmdc_full_model(
         checkpoint=args.model_checkpoint_path,
+        mmdc_full_config_path=args.model_config_path,
+        inference_tile=True,
     )
     scales = get_scales()
     mmdc_full_model.set_scales(scales)
-- 
GitLab


From faff5329a1b41f303d07d77ebaa1781aec9153cd Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:09:46 +0000
Subject: [PATCH 61/81] add read config functions

---
 src/mmdc_singledate/inference/utils.py | 196 ++++++++++++++++++++++++-
 1 file changed, 195 insertions(+), 1 deletion(-)

diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py
index 53adf910..6b6a3c45 100644
--- a/src/mmdc_singledate/inference/utils.py
+++ b/src/mmdc_singledate/inference/utils.py
@@ -1,8 +1,28 @@
 #!/usr/bin/env python3
 
+from dataclasses import fields
+
+import pydantic
 import torch
+import yaml
+from config import Config
 
-from mmdc_singledate.datamodules.types import MMDCDataStats, MMDCShiftScales, ShiftScale
+from mmdc_singledate.datamodules.types import (
+    MMDCDataChannels,
+    MMDCDataStats,
+    MMDCShiftScales,
+    ShiftScale,
+)
+from mmdc_singledate.models.types import (  # S1S2VAELatentSpace,
+    ConvnetParams,
+    MLPConfig,
+    MMDCDataUse,
+    ModularEmbeddingConfig,
+    MultiVAEConfig,
+    MultiVAELossWeights,
+    TranslationLossWeights,
+    UnetParams,
+)
 
 
 def get_scales():
     """
     Read Scales for
@@ -43,3 +63,177 @@ def get_scales():
     return MMDCShiftScales(
         shift_scale_s2, shift_scale_s1, shift_scale_wc, shift_scale_srtm
     )
+
+
+# DONE add pydantic validation
+def parse_from_yaml(path_to_yaml: str):
+    """
+    Read a yaml file and return a dict
+    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
+    """
+    with open(path_to_yaml) as f:
+        config = yaml.safe_load(f)
+
+    return config
+
+
+def parse_multivaeconfig_to_pydantic(cfg: dict):
+    """
+    Convert a MultiVAEConfig to a pydantic version
+    to validate the values
+    :args: cfg : configuration as dict
+    """
+    # parse the fields of the dataclass to a dict
+    multivae_fields = {
+        _field.name: (_field.type, ...)
for _field in fields(MultiVAEConfig) + } + # create pydantic model from a dataclass + pydantic_multivaefields = pydantic.create_model("MultiVAEConfig", **multivae_fields) + # apply the convertion + multivaeconfig_pydantic = pydantic_multivaefields(**cfg["model"]["config"]) + + return multivaeconfig_pydantic + + +# DONE add pydantic validation +def parse_from_cfg(path_to_cfg: str): + """ + Read a cfg as config file validate the values + and return a dict + """ + with open(path_to_cfg, encoding="UTF-8") as i2_config: + cfg = Config(i2_config) + cfg_dict = cfg.as_dict() + + return cfg_dict + + +def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool): + """ " + Get the parameters for instantiate the mmdc_full module + based on the target inference mode + """ + + # read the yaml hydra config + if inference_tile: + cfg = parse_from_yaml(path_to_cfg) # "configs/model/mmdc_full.yaml") + cfg = parse_multivaeconfig_to_pydantic(cfg) + + else: + cfg = parse_from_cfg(path_to_cfg) + cfg = parse_multivaeconfig_to_pydantic(cfg) + + mmdc_full_config = MultiVAEConfig( + data_sizes=MMDCDataChannels( + sen1=cfg.data_sizes.sen1, + s1_angles=cfg.data_sizes.s1_angles, + sen2=cfg.data_sizes.sen2, + s2_angles=cfg.data_sizes.s2_angles, + srtm=cfg.data_sizes.srtm, + w_c=cfg.data_sizes.w_c, + ), + embeddings=ModularEmbeddingConfig( + s1_angles=MLPConfig( + hidden_layers=cfg.embeddings.s1_angles.hidden_layers, + out_channels=cfg.embeddings.s1_angles.out_channels, + ), + s1_srtm=UnetParams( + out_channels=cfg.embeddings.s1_srtm.out_channels, + encoder_sizes=cfg.embeddings.s1_srtm.encoder_sizes, + kernel_size=cfg.embeddings.s1_srtm.kernel_size, + tail_layers=cfg.embeddings.s1_srtm.tail_layers, + ), + s2_angles=ConvnetParams( + out_channels=cfg.embeddings.s2_angles.out_channels, + sizes=cfg.embeddings.s2_angles.sizes, + kernel_sizes=cfg.embeddings.s2_angles.kernel_sizes, + ), + s2_srtm=UnetParams( + out_channels=cfg.embeddings.s2_srtm.out_channels, + encoder_sizes=cfg.embeddings.s2_srtm.encoder_sizes, + kernel_size=cfg.embeddings.s2_srtm.kernel_size, + tail_layers=cfg.embeddings.s2_srtm.tail_layers, + ), + w_c=ConvnetParams( + out_channels=cfg.embeddings.w_c.out_channels, + sizes=cfg.embeddings.w_c.sizes, + kernel_sizes=cfg.embeddings.w_c.kernel_sizes, + ), + ), + s1_encoder=UnetParams( + out_channels=cfg.s1_encoder.out_channels, + encoder_sizes=cfg.s1_encoder.encoder_sizes, + kernel_size=cfg.s1_encoder.kernel_size, + tail_layers=cfg.s1_encoder.tail_layers, + ), + s2_encoder=UnetParams( + out_channels=cfg.s2_encoder.out_channels, + encoder_sizes=cfg.s2_encoder.encoder_sizes, + kernel_size=cfg.s2_encoder.kernel_size, + tail_layers=cfg.s2_encoder.tail_layers, + ), + s1_decoder=ConvnetParams( + out_channels=cfg.s1_decoder.out_channels, + sizes=cfg.s1_decoder.sizes, + kernel_sizes=cfg.s1_decoder.kernel_sizes, + ), + s2_decoder=ConvnetParams( + out_channels=cfg.s2_decoder.out_channels, + sizes=cfg.s2_decoder.sizes, + kernel_sizes=cfg.s2_decoder.kernel_sizes, + ), + s1_enc_use=MMDCDataUse( + s1_angles=cfg.s1_enc_use.s1_angles, + s2_angles=cfg.s1_enc_use.s2_angles, + srtm=cfg.s1_enc_use.srtm, + w_c=cfg.s1_enc_use.w_c, + ), + s2_enc_use=MMDCDataUse( + s1_angles=cfg.s2_enc_use.s1_angles, + s2_angles=cfg.s2_enc_use.s2_angles, + srtm=cfg.s2_enc_use.srtm, + w_c=cfg.s2_enc_use.w_c, + ), + s1_dec_use=MMDCDataUse( + s1_angles=cfg.s1_dec_use.s1_angles, + s2_angles=cfg.s1_dec_use.s2_angles, + srtm=cfg.s1_dec_use.srtm, + w_c=cfg.s1_dec_use.w_c, + ), + s2_dec_use=MMDCDataUse( + s1_angles=cfg.s2_dec_use.s1_angles, + 
s2_angles=cfg.s2_dec_use.s2_angles,
+            srtm=cfg.s2_dec_use.srtm,
+            w_c=cfg.s2_dec_use.w_c,
+        ),
+        loss_weights=MultiVAELossWeights(
+            sen1=TranslationLossWeights(
+                nll=cfg.loss_weights.sen1.nll,
+                lpips=cfg.loss_weights.sen1.lpips,
+                sam=cfg.loss_weights.sen1.sam,
+                gradients=cfg.loss_weights.sen1.gradients,
+            ),
+            sen2=TranslationLossWeights(
+                nll=cfg.loss_weights.sen2.nll,
+                lpips=cfg.loss_weights.sen2.lpips,
+                sam=cfg.loss_weights.sen2.sam,
+                gradients=cfg.loss_weights.sen2.gradients,
+            ),
+            s1_s2=TranslationLossWeights(
+                nll=cfg.loss_weights.s1_s2.nll,
+                lpips=cfg.loss_weights.s1_s2.lpips,
+                sam=cfg.loss_weights.s1_s2.sam,
+                gradients=cfg.loss_weights.s1_s2.gradients,
+            ),
+            s2_s1=TranslationLossWeights(
+                nll=cfg.loss_weights.s2_s1.nll,
+                lpips=cfg.loss_weights.s2_s1.lpips,
+                sam=cfg.loss_weights.s2_s1.sam,
+                gradients=cfg.loss_weights.s2_s1.gradients,
+            ),
+            forward=cfg.loss_weights.forward,
+            # NB: reuses the forward weight; probably meant cfg.loss_weights.cross
+            cross=cfg.loss_weights.forward,
+            latent=cfg.loss_weights.latent,
+        ),
+    )
+    return mmdc_full_config
-- 
GitLab
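
The helpers introduced above are meant to be chained: parse_from_yaml (Hydra-style
YAML) or parse_from_cfg (iota2 .cfg) produces a plain dict,
parse_multivaeconfig_to_pydantic validates it, and get_mmdc_full_config maps it
onto a MultiVAEConfig. A minimal sketch of that call chain (the YAML path is the
one hard-coded elsewhere in the repository and may differ per deployment):

    from mmdc_singledate.inference.utils import (
        get_mmdc_full_config,
        parse_from_yaml,
        parse_multivaeconfig_to_pydantic,
    )

    # Hydra-style YAML, as used for tile inference
    raw_cfg = parse_from_yaml("configs/model/mmdc_full.yaml")
    # pydantic re-checks every MultiVAEConfig field before instantiation
    validated = parse_multivaeconfig_to_pydantic(raw_cfg)

    # or end to end, yielding the MultiVAEConfig used to build the network
    model_cfg = get_mmdc_full_config(
        "configs/model/mmdc_full.yaml", inference_tile=True
    )
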
From 06690df09ef0f6d70047da21d7b93e60baebab99 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:17:08 +0000
Subject: [PATCH 62/81] update config parameters

---
 src/bin/inference_mmdc_timeserie.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py
index 56626f63..4db518f2 100644
--- a/src/bin/inference_mmdc_timeserie.py
+++ b/src/bin/inference_mmdc_timeserie.py
@@ -58,6 +58,13 @@ def get_parser() -> argparse.ArgumentParser:
         required=True,
     )
 
+    arg_parser.add_argument(
+        "--model_config_path",
+        type=str,
+        help="path to the model configuration file",
+        required=True,
+    )
+
     arg_parser.add_argument(
         "--tile_list",
         nargs="+",
@@ -120,6 +127,7 @@ def main():
         inference_dataframe=tile_df,
         export_path=Path(args.export_path),
         model_checkpoint_path=Path(args.model_checkpoint_path),
+        model_config_path=Path(args.model_config_path),
         patch_size=args.patch_size,
         nb_lines=args.nb_lines,
     )
-- 
GitLab


From 6e323282384b0d6f992ac2474ed04ebd00523c63 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:21:05 +0000
Subject: [PATCH 63/81] update config parameters

---
 test/test_mmdc_inference.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py
index 51be142f..fd15d093 100644
--- a/test/test_mmdc_inference.py
+++ b/test/test_mmdc_inference.py
@@ -195,7 +195,9 @@ def test_predict_single_date_tile(sensors):
     checkpoint_filename = "last.ckpt"  # "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
+        os.path.join(checkpoint_path, checkpoint_filename),
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json",
+        inference_tile=True,
     )
     scales = get_scales()
     mmdc_full_model.set_scales(scales)
@@ -267,7 +269,9 @@ def test_predict_single_date_tile_no_data(
     checkpoint_filename = "last.ckpt"  # "last.ckpt"
 
     mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
+        os.path.join(checkpoint_path, checkpoint_filename),
+        mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json",
+        inference_tile=True,
    )
     # move to device
     scales = get_scales()
@@ -308,7 +312,9 @@ def test_mmdc_tile_inference():
     input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/")
     export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/")
     model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt"
-
+    model_config_path = (
+        "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml"
+    )
     input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"]
 
     # dataframe with input data
@@ -422,6 +428,7 @@ def test_mmdc_tile_inference():
         inference_dataframe=tile_df,
         export_path=export_path,
         model_checkpoint_path=model_checkpoint_path,
+        model_config_path=model_config_path,
         patch_size=128,
         nb_lines=256,
     )
-- 
GitLab


From 2fb8131c18b3edeeecb4ff3c353bb3033329e735 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:22:12 +0000
Subject: [PATCH 64/81] add requirement

---
 requirements-mmdc-sgld.txt | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/requirements-mmdc-sgld.txt b/requirements-mmdc-sgld.txt
index 02defbac..c4b218d1 100644
--- a/requirements-mmdc-sgld.txt
+++ b/requirements-mmdc-sgld.txt
@@ -1,6 +1,7 @@
 affine
 auto-walrus
 black # code formatting
+config
 findpeaks
 flake8 # code analysis
 geopandas
@@ -18,6 +19,8 @@ numpy
 pandas
 pre-commit # hooks for applying linters on commit
 pudb # debugger
+pydantic # validates configurations in iota2
+pydantic-yaml # read YAML configs used to construct the networks
 pytest # tests
 python-dotenv # loading env variables from .env file
 python-lsp-server[all]
-- 
GitLab


From 1cfdcdf1e91de973e274aee6785ad48900545ae1 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 21 Mar 2023 16:32:51 +0000
Subject: [PATCH 65/81] put config logic in other function

---
 .../components/inference_components.py       | 527 +++++++-----------
 1 file changed, 195 insertions(+), 332 deletions(-)

diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py
index e3a9f52b..30fb377e 100644
--- a/src/mmdc_singledate/inference/components/inference_components.py
+++ b/src/mmdc_singledate/inference/components/inference_components.py
@@ -11,22 +11,12 @@ from pathlib import Path
 from typing import Literal
 
 import torch
-import yaml
 from torch import nn
 
-from mmdc_singledate.datamodules.types import MMDCDataChannels
+# from mmdc_singledate.datamodules.types import MMDCDataChannels
+from mmdc_singledate.inference.utils import get_mmdc_full_config
 from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
-from mmdc_singledate.models.types import (
-    ConvnetParams,
-    MLPConfig,
-    MMDCDataUse,
-    ModularEmbeddingConfig,
-    MultiVAEConfig,
-    MultiVAELossWeights,
-    S1S2VAELatentSpace,
-    TranslationLossWeights,
-    UnetParams,
-)
+from mmdc_singledate.models.types import S1S2VAELatentSpace
 
 # define sensors variable for typing
 SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"]
@@ -39,189 +29,32 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def parse_from_yaml(path_to_yaml):
-    """
-    Read a yaml file and return a dict
-    https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5
-    """
-    with open(path_to_yaml) as f:
-        config = yaml.safe_load(f)
-
-    return config
-
-
 @torch.no_grad()
 def get_mmdc_full_model(
     checkpoint: str,
+    mmdc_full_config_path: str,
+    inference_tile: bool,
 ) -> nn.Module:
     """
     Instantiate the network
     """
-
     if not Path(checkpoint).exists():
         logger.info("No checkpoint path was given. 
") return 1 else: - cfg = parse_from_yaml("configs/model/mmdc_full.yaml") - - mmdc_full_config = MultiVAEConfig( - data_sizes=MMDCDataChannels( - sen1=cfg["model"]["config"]["data_sizes"]["sen1"], - s1_angles=cfg["model"]["config"]["data_sizes"]["s1_angles"], - sen2=cfg["model"]["config"]["data_sizes"]["sen2"], - s2_angles=cfg["model"]["config"]["data_sizes"]["s2_angles"], - srtm=cfg["model"]["config"]["data_sizes"]["srtm"], - w_c=cfg["model"]["config"]["data_sizes"]["w_c"], - ), - embeddings=ModularEmbeddingConfig( - s1_angles=MLPConfig( - hidden_layers=cfg["model"]["config"]["embeddings"]["s1_angles"][ - "hidden_layers" - ], - out_channels=cfg["model"]["config"]["embeddings"]["s1_angles"][ - "out_channels" - ], - ), - s1_srtm=UnetParams( - out_channels=cfg["model"]["config"]["embeddings"]["s1_srtm"][ - "out_channels" - ], - encoder_sizes=cfg["model"]["config"]["embeddings"]["s1_srtm"][ - "encoder_sizes" - ], - kernel_size=cfg["model"]["config"]["embeddings"]["s1_srtm"][ - "kernel_size" - ], - tail_layers=cfg["model"]["config"]["embeddings"]["s1_srtm"][ - "tail_layers" - ], - ), - s2_angles=ConvnetParams( - out_channels=cfg["model"]["config"]["embeddings"]["s2_angles"][ - "out_channels" - ], - sizes=cfg["model"]["config"]["embeddings"]["s2_angles"]["sizes"], - kernel_sizes=cfg["model"]["config"]["embeddings"]["s2_angles"][ - "kernel_sizes" - ], - ), - s2_srtm=UnetParams( - out_channels=cfg["model"]["config"]["embeddings"]["s2_srtm"][ - "out_channels" - ], - encoder_sizes=cfg["model"]["config"]["embeddings"]["s2_srtm"][ - "encoder_sizes" - ], - kernel_size=cfg["model"]["config"]["embeddings"]["s2_srtm"][ - "kernel_size" - ], - tail_layers=cfg["model"]["config"]["embeddings"]["s2_srtm"][ - "tail_layers" - ], - ), - w_c=ConvnetParams( - out_channels=cfg["model"]["config"]["embeddings"]["w_c"][ - "out_channels" - ], - sizes=cfg["model"]["config"]["embeddings"]["w_c"]["sizes"], - kernel_sizes=cfg["model"]["config"]["embeddings"]["w_c"][ - "kernel_sizes" - ], - ), - ), - s1_encoder=UnetParams( - out_channels=cfg["model"]["config"]["s1_encoder"]["out_channels"], - encoder_sizes=cfg["model"]["config"]["s1_encoder"]["encoder_sizes"], - kernel_size=cfg["model"]["config"]["s1_encoder"]["kernel_size"], - tail_layers=cfg["model"]["config"]["s1_encoder"]["tail_layers"], - ), - s2_encoder=UnetParams( - out_channels=cfg["model"]["config"]["s2_encoder"]["out_channels"], - encoder_sizes=cfg["model"]["config"]["s2_encoder"]["encoder_sizes"], - kernel_size=cfg["model"]["config"]["s2_encoder"]["kernel_size"], - tail_layers=cfg["model"]["config"]["s2_encoder"]["tail_layers"], - ), - s1_decoder=ConvnetParams( - out_channels=cfg["model"]["config"]["s1_decoder"]["out_channels"], - sizes=cfg["model"]["config"]["s1_decoder"]["sizes"], - kernel_sizes=cfg["model"]["config"]["s1_decoder"]["kernel_sizes"], - ), - s2_decoder=ConvnetParams( - out_channels=cfg["model"]["config"]["s2_decoder"]["out_channels"], - sizes=cfg["model"]["config"]["s2_decoder"]["sizes"], - kernel_sizes=cfg["model"]["config"]["s2_decoder"]["kernel_sizes"], - ), - s1_enc_use=MMDCDataUse( - s1_angles=cfg["model"]["config"]["s1_enc_use"]["s1_angles"], - s2_angles=cfg["model"]["config"]["s1_enc_use"]["s2_angles"], - srtm=cfg["model"]["config"]["s1_enc_use"]["srtm"], - w_c=cfg["model"]["config"]["s1_enc_use"]["w_c"], - ), - s2_enc_use=MMDCDataUse( - s1_angles=cfg["model"]["config"]["s2_enc_use"]["s1_angles"], - s2_angles=cfg["model"]["config"]["s2_enc_use"]["s2_angles"], - srtm=cfg["model"]["config"]["s2_enc_use"]["srtm"], - 
w_c=cfg["model"]["config"]["s2_enc_use"]["w_c"], - ), - s1_dec_use=MMDCDataUse( - s1_angles=cfg["model"]["config"]["s1_dec_use"]["s1_angles"], - s2_angles=cfg["model"]["config"]["s1_dec_use"]["s2_angles"], - srtm=cfg["model"]["config"]["s1_dec_use"]["srtm"], - w_c=cfg["model"]["config"]["s1_dec_use"]["w_c"], - ), - s2_dec_use=MMDCDataUse( - s1_angles=cfg["model"]["config"]["s2_dec_use"]["s1_angles"], - s2_angles=cfg["model"]["config"]["s2_dec_use"]["s2_angles"], - srtm=cfg["model"]["config"]["s2_dec_use"]["srtm"], - w_c=cfg["model"]["config"]["s2_dec_use"]["w_c"], - ), - loss_weights=MultiVAELossWeights( - sen1=TranslationLossWeights( - nll=cfg["model"]["config"]["loss_weights"]["sen1"]["nll"], - lpips=cfg["model"]["config"]["loss_weights"]["sen1"]["lpips"], - sam=cfg["model"]["config"]["loss_weights"]["sen1"]["sam"], - gradients=cfg["model"]["config"]["loss_weights"]["sen1"][ - "gradients" - ], - ), - sen2=TranslationLossWeights( - nll=cfg["model"]["config"]["loss_weights"]["sen2"]["nll"], - lpips=cfg["model"]["config"]["loss_weights"]["sen2"]["lpips"], - sam=cfg["model"]["config"]["loss_weights"]["sen2"]["sam"], - gradients=cfg["model"]["config"]["loss_weights"]["sen2"][ - "gradients" - ], - ), - s1_s2=TranslationLossWeights( - nll=cfg["model"]["config"]["loss_weights"]["s1_s2"]["nll"], - lpips=cfg["model"]["config"]["loss_weights"]["s1_s2"]["lpips"], - sam=cfg["model"]["config"]["loss_weights"]["s1_s2"]["sam"], - gradients=cfg["model"]["config"]["loss_weights"]["s1_s2"][ - "gradients" - ], - ), - s2_s1=TranslationLossWeights( - nll=cfg["model"]["config"]["loss_weights"]["s2_s1"]["nll"], - lpips=cfg["model"]["config"]["loss_weights"]["s2_s1"]["lpips"], - sam=cfg["model"]["config"]["loss_weights"]["s2_s1"]["sam"], - gradients=cfg["model"]["config"]["loss_weights"]["s2_s1"][ - "gradients" - ], - ), - forward=cfg["model"]["config"]["loss_weights"]["forward"], - cross=cfg["model"]["config"]["loss_weights"]["forward"], - latent=cfg["model"]["config"]["loss_weights"]["latent"], - ), + # read the config file + mmdc_full_config = get_mmdc_full_config( + mmdc_full_config_path, inference_tile=inference_tile ) - # Init the model + # Init the model passing the config mmdc_full_model = MMDCFullModule( mmdc_full_config, ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # print(f"device detected : {device}") - # print(f"nb_gpu's detected : {torch.cuda.device_count()}") + logger.debug(f"device detected : {device}") + logger.debug(f"nb_gpu's detected : {torch.cuda.device_count()}") # load state_dict lightning_checkpoint = torch.load(checkpoint, map_location=device) @@ -276,7 +109,7 @@ def predict_mmdc_model( # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # print(device) + logger.debug(device) prediction = model.predict( s2_ref.to(device), @@ -290,10 +123,10 @@ def predict_mmdc_model( srtm.to(device), ) - # print("prediction.shape :", prediction[0].latent.latent_s2.mean.shape) - # print("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape) - # print("prediction.shape :", prediction[0].latent.latent_s1.mean.shape) - # print("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape) + logger.debug("prediction.shape :", prediction[0].latent.latent_s2.mean.shape) + logger.debug("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape) + logger.debug("prediction.shape :", prediction[0].latent.latent_s1.mean.shape) + logger.debug("prediction.shape :", 
prediction[0].latent.latent_s1.logvar.shape) # init the output latent spaces as # empty dataclass @@ -301,154 +134,184 @@ def predict_mmdc_model( # fullfit with a matchcase # match - match sensors: - # Cases - case ["S2L2A", "S1FULL"]: - # logger.info("S2 captured & S1 full captured") - - latent_space.latent_s2 = prediction[0].latent.latent_s2 - latent_space.latent_s1 = prediction[0].latent.latent_s1 - - latent_space_stack = torch.cat( - ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - latent_space.latent_s1.mean, - latent_space.latent_s1.logvar, - ), - 1, - ) - - # print( - # "latent_space.latent_s2.mean :", - # latent_space.latent_s2.mean.shape, - # ) - # print( - # "latent_space.latent_s2.logvar :", - # latent_space.latent_s2.logvar.shape, - # ) - # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - # print( - # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - # ) - - case ["S2L2A", "S1ASC"]: - # logger.info("S2 captured & S1 asc captured") - - latent_space.latent_s2 = prediction[0].latent.latent_s2 - latent_space.latent_s1 = prediction[0].latent.latent_s1 - - latent_space_stack = torch.cat( - ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - latent_space.latent_s1.mean, - latent_space.latent_s1.logvar, - ), - 1, - ) - # print( - # "latent_space.latent_s2.mean :", - # latent_space.latent_s2.mean.shape, - # ) - # print( - # "latent_space.latent_s2.logvar :", - # latent_space.latent_s2.logvar.shape, - # ) - # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - # print( - # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - # ) - - case ["S2L2A", "S1DESC"]: - # logger.info("S2 captured & S1 desc captured") - - latent_space.latent_s2 = prediction[0].latent.latent_s2 - latent_space.latent_s1 = prediction[0].latent.latent_s1 - - latent_space_stack = torch.cat( - ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - latent_space.latent_s1.mean, - latent_space.latent_s1.logvar, - ), - 1, - ) - # print( - # "latent_space.latent_s2.mean :", - # latent_space.latent_s2.mean.shape, - # ) - # print( - # "latent_space.latent_s2.logvar :", - # latent_space.latent_s2.logvar.shape, - # ) - # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - # print( - # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - # ) - - case ["S2L2A"]: - # logger.info("Only S2 captured") - - latent_space.latent_s2 = prediction[0].latent.latent_s2 - - latent_space_stack = torch.cat( - ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - ), - 1, - ) - # print( - # "latent_space.latent_s2.mean :", - # latent_space.latent_s2.mean.shape, - # ) - # print( - # "latent_space.latent_s2.logvar :", - # latent_space.latent_s2.logvar.shape, - # ) - - case ["S1FULL"]: - # logger.info("Only S1 full captured") - - latent_space.latent_s1 = prediction[0].latent.latent_s1 - - latent_space_stack = torch.cat( - (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), - 1, - ) - # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - # print( - # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - # ) - - case ["S1ASC"]: - # logger.info("Only S1ASC captured") - - latent_space.latent_s1 = prediction[0].latent.latent_s1 - - latent_space_stack = torch.cat( - (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), - 1, - ) - # print("latent_space.latent_s1.mean :", 
latent_space.latent_s1.mean.shape) - # print( - # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - # ) - - case ["S1DESC"]: - # logger.info("Only S1DESC captured") - - latent_space.latent_s1 = prediction[0].latent.latent_s1 - - latent_space_stack = torch.cat( - (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), - 1, - ) - # print("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - # print( - # "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - # ) + # match sensors: + # Cases + if sensors == ["S2L2A", "S1FULL"]: + logger.debug("S2 captured & S1 full captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mean, + latent_space.latent_s2.logvar, + latent_space.latent_s1.mean, + latent_space.latent_s1.logvar, + ), + 1, + ) + + logger.debug( + "latent_space.latent_s2.mean :", + latent_space.latent_s2.mean.shape, + ) + logger.debug( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar.shape, + ) + logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + logger.debug( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) + + elif sensors == ["S2L2A", "S1ASC"]: + # logger.debug("S2 captured & S1 asc captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mean, + latent_space.latent_s2.logvar, + latent_space.latent_s1.mean, + latent_space.latent_s1.logvar, + ), + 1, + ) + logger.debug( + "latent_space.latent_s2.mean :", + latent_space.latent_s2.mean.shape, + ) + logger.debug( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar.shape, + ) + logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + logger.debug( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) + + elif sensors == ["S2L2A", "S1DESC"]: + logger.debug("S2 captured & S1 desc captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mean, + latent_space.latent_s2.logvar, + latent_space.latent_s1.mean, + latent_space.latent_s1.logvar, + ), + 1, + ) + logger.debug( + "latent_space.latent_s2.mean :", + latent_space.latent_s2.mean.shape, + ) + logger.debug( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar.shape, + ) + logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + logger.debug( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) + + elif sensors == ["S2L2A"]: + # logger.debug("Only S2 captured") + + latent_space.latent_s2 = prediction[0].latent.latent_s2 + + latent_space_stack = torch.cat( + ( + latent_space.latent_s2.mean, + latent_space.latent_s2.logvar, + ), + 1, + ) + logger.debug( + "latent_space.latent_s2.mean :", + latent_space.latent_s2.mean.shape, + ) + logger.debug( + "latent_space.latent_s2.logvar :", + latent_space.latent_s2.logvar.shape, + ) + + elif sensors == ["S1FULL"]: + # logger.debug("Only S1 full captured") + + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), + 1, + ) + logger.debug("latent_space.latent_s1.mean :", 
latent_space.latent_s1.mean.shape) + logger.debug( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) + + elif sensors == ["S1ASC"]: + # logger.debug("Only S1ASC captured") + + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), + 1, + ) + logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + logger.debug( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) + + elif sensors == ["S1DESC"]: + # logger.debug("Only S1DESC captured") + + latent_space.latent_s1 = prediction[0].latent.latent_s1 + + latent_space_stack = torch.cat( + (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), + 1, + ) + logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) + logger.debug( + "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + ) + else: + raise Exception(f"Not a valid sensors ({sensors}) are given") return latent_space_stack # latent_space + + +# Switch to match case in python 3.10 o superior +# fullfit with a matchcase +# for captor in sensors: +# # match +# match captor: +# # Cases +# case ["S2L2A", "S1FULL"]: +# print("S1 full captured & S2 captured") + +# case ["S2L2A", "S1ASC"]: +# print("S1 asc captured & S2 captured") + +# case ["S2L2A", "S1DESC"]: +# print("S1 desc captured & S2 captured") + +# case ["S2L2A"]: +# print("Only S2 captured") + +# case ["S1FULL"]: +# print("Only S1 full captured") + +# case ["S1ASC"]: +# print("Only S1ASC captured") + +# case ["S1DESC"]: +# print("Only S1DESC captured") -- GitLab From cd5117f4fbaf60fe879d83a946bdc6ed7ec5ec0e Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Tue, 21 Mar 2023 16:33:48 +0000 Subject: [PATCH 66/81] put config logic in other function --- src/mmdc_singledate/inference/components/inference_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index 30fb377e..c514a8fc 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -13,7 +13,6 @@ from typing import Literal import torch from torch import nn -# from mmdc_singledate.datamodules.types import MMDCDataChannels from mmdc_singledate.inference.utils import get_mmdc_full_config from mmdc_singledate.models.mmdc_full_module import MMDCFullModule from mmdc_singledate.models.types import S1S2VAELatentSpace -- GitLab From 854df0bac2dc3d0cc4d5fff7a4cba3c3a9fad3f2 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Tue, 21 Mar 2023 16:37:42 +0000 Subject: [PATCH 67/81] update config quit match case --- .../inference/mmdc_tile_inference.py | 164 +++++++++--------- 1 file changed, 86 insertions(+), 78 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py index ebc619f5..307e3007 100644 --- a/src/mmdc_singledate/inference/mmdc_tile_inference.py +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -120,70 +120,75 @@ def predict_single_date_tile( chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines) # get the windows from the chunks rois = [rio.windows.Window(*chunk) for chunk in chunks] - # logger.info(f"chunk size : ({rois[0].width}, 
{rois[0].height}) ") + logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ") # init the dataset - # logger.info("Reading S2 data") + logger.debug("Reading S2 data") s2_data = read_img_tile( filename=input_data.s2_filename, rois=rois, availabitity=input_data.s2_availabitity, sensor_func=read_s2_img_tile, ) - # logger.info("Reading S1 ASC data") + logger.debug("Reading S1 ASC data") s1_asc_data = read_img_tile( filename=input_data.s1_asc_filename, rois=rois, availabitity=input_data.s1_asc_availability, sensor_func=read_s1_img_tile, ) - # logger.info("Reading S1 DESC data") + logger.debug("Reading S1 DESC data") s1_desc_data = read_img_tile( filename=input_data.s1_desc_filename, rois=rois, availabitity=input_data.s1_desc_availability, sensor_func=read_s1_img_tile, ) - # logger.info("Reading SRTM data") + logger.debug("Reading SRTM data") srtm_data = read_img_tile( filename=input_data.srtm_filename, rois=rois, availabitity=True, sensor_func=read_srtm_img_tile, ) - # logger.info("Reading WorldClim data") + logger.debug("Reading WorldClim data") worldclim_data = concat_worldclim_components( wc_filename=input_data.wc_filename, rois=rois, availabitity=True ) - # print("input_data.s2_availabitity=", input_data.s2_availabitity) - # print("s2_data=", s2_data) - # print("s1_asc_data=", s1_asc_data) - # print("s1_desc_data=", s1_desc_data) - # print("srtm_data=", srtm_data) - # print("worldclim_data=", worldclim_data) - # logger.info("Export Init") + logger.debug("input_data.s2_availabitity=", input_data.s2_availabitity) + logger.debug("s2_data=", s2_data) + logger.debug("s1_asc_data=", s1_asc_data) + logger.debug("s1_desc_data=", s1_desc_data) + logger.debug("srtm_data=", srtm_data) + logger.debug("worldclim_data=", worldclim_data) + logger.debug("Export Init") with rio.open(export_path, "w", **meta) as prediction: # iterate over the windows - for roi, s2, s1_asc, s1_desc, srtm, wc in zip( - rois, - s2_data, - s1_asc_data, - s1_desc_data, - srtm_data, - worldclim_data, + for roi, s2, s1_asc, s1_desc, srtm, wc in tqdm( + zip( + rois, + s2_data, + s1_asc_data, + s1_desc_data, + srtm_data, + worldclim_data, + ), + total=len(rois), + position=1, ): - # print(" original size : ", s2.s2_reflectances.shape) - # print(" original size : ", s2.s2_angles.shape) - # print(" original size : ", s2.s2_mask.shape) - # print(" original size : ", s1_asc.s1_backscatter.shape) - # print(" original size : ", s1_asc.s1_valmask.shape) - # print(" original size : ", s1_asc.s1_lia_angles.shape) - # print(" original size : ", s1_desc.s1_backscatter.shape) - # print(" original size : ", s1_desc.s1_valmask.shape) - # print(" original size : ", s1_desc.s1_lia_angles.shape) - # print(" original size : ", srtm.srtm.shape) - # print(" original size : ", wc.worldclim.shape) + # export one tile + logger.debug(" original size : ", s2.s2_reflectances.shape) + logger.debug(" original size : ", s2.s2_angles.shape) + logger.debug(" original size : ", s2.s2_mask.shape) + logger.debug(" original size : ", s1_asc.s1_backscatter.shape) + logger.debug(" original size : ", s1_asc.s1_valmask.shape) + logger.debug(" original size : ", s1_asc.s1_lia_angles.shape) + logger.debug(" original size : ", s1_desc.s1_backscatter.shape) + logger.debug(" original size : ", s1_desc.s1_valmask.shape) + logger.debug(" original size : ", s1_desc.s1_lia_angles.shape) + logger.debug(" original size : ", srtm.srtm.shape) + logger.debug(" original size : ", wc.worldclim.shape) # Concat S1 Data s1_backscatter = torch.cat( @@ -199,8 +204,8 @@ def 
predict_single_date_tile( ).shape s2_s2_mask_shape = s2.s2_mask.shape # [1, 1024, 10980] - # print("s2_mask_patch_size :", s2_mask_patch_size) - # print("s2_s2_mask_shape:", s2_s2_mask_shape) + logger.debug("s2_mask_patch_size :", s2_mask_patch_size) + logger.debug("s2_s2_mask_shape:", s2_s2_mask_shape) # reshape the data s2_refl_patch = patches.flatten2d( patches.patchify(s2.s2_reflectances, process.patch_size) @@ -231,25 +236,24 @@ def predict_single_date_tile( s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat( s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 ) - # torch.flatten(start_dim=0, end_dim=1) - # print("s2_patches", s2_refl_patch.shape) - # print("s2_angles_patches", s2_ang_patch.shape) - # print("s2_mask_patch", s2_mask_patch.shape) - # print("s1_patch", s1_patch.shape) - # print("s1_valmask_patch", s1_valmask_patch.shape) - # print( - # "s1_lia_asc patches", - # s1_asc.s1_lia_angles.shape, - # s1asc_lia_patch.shape, - # ) - # print( - # "s1_lia_desc patches", - # s1_desc.s1_lia_angles.shape, - # s1desc_lia_patch.shape, - # ) - # print("srtm patches", srtm_patch.shape) - # print("wc patches", wc_patch.shape) + logger.debug("s2_patches", s2_refl_patch.shape) + logger.debug("s2_angles_patches", s2_ang_patch.shape) + logger.debug("s2_mask_patch", s2_mask_patch.shape) + logger.debug("s1_patch", s1_patch.shape) + logger.debug("s1_valmask_patch", s1_valmask_patch.shape) + logger.debug( + "s1_lia_asc patches", + s1_asc.s1_lia_angles.shape, + s1asc_lia_patch.shape, + ) + logger.debug( + "s1_lia_desc patches", + s1_desc.s1_lia_angles.shape, + s1desc_lia_patch.shape, + ) + logger.debug("srtm patches", srtm_patch.shape) + logger.debug("wc patches", wc_patch.shape) # apply predict function # should return a s1s2vaelatentspace @@ -267,8 +271,8 @@ def predict_single_date_tile( wc_patch, srtm_patch, ) - # print(type(pred_vaelatentspace)) - # print("latent space sizes : ", pred_vaelatentspace.shape) + logger.debug(type(pred_vaelatentspace)) + logger.debug("latent space sizes : ", pred_vaelatentspace.shape) # unpatchify pred_vaelatentspace_unpatchify = patches.unpatchify( @@ -279,8 +283,8 @@ def predict_single_date_tile( ) )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] - # print("pred_tensor :", pred_vaelatentspace_unpatchify.shape) - # print("process.count : ", process.count) + logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + logger.debug("process.count : ", process.count) # check the pred and the ask are the same assert process.count == pred_vaelatentspace_unpatchify.shape[0] @@ -289,7 +293,7 @@ def predict_single_date_tile( window=roi, indexes=process.count, ) - # logger.info(("Export tile", f"filename :{export_path}")) + logger.debug(("Export tile", f"filename :{export_path}")) def estimate_nb_latent_spaces( @@ -326,22 +330,24 @@ def determinate_sensor_list( patchasc_s1_availability, patchdesc_s1_availability, ] - match sensors_availability: - case [True, True, True]: - sensors = ["S2L2A", "S1FULL"] - case [True, True, False]: - sensors = ["S2L2A", "S1ASC"] - case [True, False, True]: - sensors = ["S2L2A", "S1DESC"] - case [False, True, True]: - sensors = ["S1FULL"] - case [True, False, False]: - sensors = ["S2L2A"] - case [False, True, False]: - sensors = ["S1ASC"] - case [False, False, True]: - sensors = ["S1DESC"] - + if sensors_availability == [True, True, True]: + sensors = ["S2L2A", "S1FULL"] + elif sensors_availability == [True, True, False]: + sensors = ["S2L2A", "S1ASC"] + elif sensors_availability == [True, False, True]: + sensors = ["S2L2A", 
"S1DESC"] + elif sensors_availability == [False, True, True]: + sensors = ["S1FULL"] + elif sensors_availability == [True, False, False]: + sensors = ["S2L2A"] + elif sensors_availability == [False, True, False]: + sensors = ["S1ASC"] + elif sensors_availability == [False, False, True]: + sensors = ["S1DESC"] + else: + raise Exception( + (f"The sensor {sensors_availability}", " is not in the available list") + ) return sensors @@ -349,6 +355,7 @@ def mmdc_tile_inference( inference_dataframe: pd.DataFrame, export_path: Path, model_checkpoint_path: Path, + model_config_path: Path, patch_size: int = 256, nb_lines: int = 1024, ) -> None: @@ -372,11 +379,13 @@ def mmdc_tile_inference( # TODO Uncomment for test purpuse # inference_dataframe = inference_dataframe.head() - # print(inference_dataframe.shape) + # logger.debug(inference_dataframe.shape) # instance the model and get the pred func model = get_mmdc_full_model( checkpoint=model_checkpoint_path, + mmdc_full_config_path=model_config_path, + inference_tile=True, ) scales = get_scales() model.set_scales(scales) @@ -385,7 +394,7 @@ def mmdc_tile_inference( # iterate over the dates in the time serie for tuile, df_row in tqdm( - inference_dataframe.iterrows(), total=inference_dataframe.shape[0] + inference_dataframe.iterrows(), total=inference_dataframe.shape[0], position=0 ): # estimate nb latent spaces latent_space_size = estimate_nb_latent_spaces( @@ -401,9 +410,8 @@ def mmdc_tile_inference( # export path date_ = df_row["date"].strftime("%Y-%m-%d") export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif" - # - # print(tuile, df_row) - # print(latent_space_size) + logger.debug(tuile, df_row) + logger.debug(latent_space_size) # define process mmdc_process = MMDCProcess( count=latent_space_size, @@ -431,4 +439,4 @@ def mmdc_tile_inference( process=mmdc_process, ) - # logger.info("Export Finish !!!") + logger.debug("Export Finish !!!") -- GitLab From 49ad2680225b503e14c3fd2f03402295c1f46a87 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 22 Mar 2023 15:32:37 +0000 Subject: [PATCH 68/81] delete iota2 functions --- .../inference/components/inference_utils.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py index 1d79f1f0..ef85dbdb 100644 --- a/src/mmdc_singledate/inference/components/inference_utils.py +++ b/src/mmdc_singledate/inference/components/inference_utils.py @@ -243,27 +243,6 @@ def read_s2_img_tile( return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) -# def read_s2_img_iota2( -# s2_tensor: torch.Tensor, -# *args: Any, -# **kwargs: Any, -# ) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: -# """ -# read a patch of sentinel 2 data -# contruct the masks and yield the patch -# of data -# """ -# # copy the masks -# # cloud_mask = -# # sat_mask = -# # edge_mask = -# mask = torch.concat((cloud_mask, sat_mask, edge_mask), axis=0) -# angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...]) -# image_s2 = s2_tensor[:10, ...] 
-
-#     return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask)
-
-
 # TODO get out the s1_lia computation from this function
 def read_s1_img_tile(
     s1_tensor: torch.Tensor,
@@ -302,32 +281,6 @@ def read_s1_img_tile(
     )
 
 
-# def read_s1_img_iota2(
-#     s1_tensor: torch.Tensor,
-#     *args: Any,
-#     **kwargs: Any,
-# ) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-#     """
-#     read a patch of s1 construct the datasets associated
-#     and yield the result
-#     """
-#     # get the vv,vh, vv/vh bands
-#     img_s1 = torch.cat(
-#         (s1_tensor, (s1_tensor[1, ...] / s1_tensor[0, ...]).unsqueeze(0))
-#     )
-#     s1_lia = get_s1_acquisition_angles(s1_filename)
-#     # compute validity mask
-#     s1_valmask = torch.ones(img_s1.shape)
-#     # compute edge mask
-#     s1_backscatter = apply_log_to_s1(img_s1)
-
-#     return S1Components(
-#         s1_backscatter=s1_backscatter,
-#         s1_valmask=s1_valmask,
-#         s1_lia_angles=s1_lia,
-#     )
-
-
 def read_srtm_img_tile(
     srtm_tensor: torch.Tensor,
     *args: Any,
-- 
GitLab


From 0c6d60660a596d903a343cc692760546efb76373 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Wed, 22 Mar 2023 18:24:27 +0000
Subject: [PATCH 69/81] WIP parsing data and dates

---
 .../inference/mmdc_iota2_inference.py         | 475 +++++++++++++++---
 1 file changed, 409 insertions(+), 66 deletions(-)

diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
index 4a40ff74..5bd385e4 100644
--- a/src/mmdc_singledate/inference/mmdc_iota2_inference.py
+++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py
@@ -6,29 +6,58 @@ Inference API with Iota-2
 Inspired by:
 https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py
 """
+# imports
+import logging
+import os
+
+import numpy as np
+import pandas as pd  # used by acquisition_dataframe below
 import torch
+from dateutil import parser as date_parser
+from iota2.configuration_files import read_config_file as rcf
 from torchutils import patches
 
-# from mmdc_singledate.datamodules.components.datamodule_utils import (
-#     apply_log_to_s1,
-#     join_even_odd_s2_angles,
-#     srtm_height_aspect,
-# )
+# from mmdc_singledate.models.mmdc_full_module import MMDCFullModule
+# from mmdc_singledate.models.types import S1S2VAELatentSpace
+from mmdc_singledate.datamodules.components.datamodule_utils import (
+    apply_log_to_s1,
+    srtm_height_aspect,
+)
+from .components.inference_components import get_mmdc_full_model, predict_mmdc_model
+from .components.inference_utils import (
+    GeoTiffDataset,
+    S1Components,
+    S2Components,
+    SRTMComponents,
+    WorldClimComponents,
+    get_mmdc_full_config,
+    read_s1_img_tile,
+    read_s2_img_tile,
+    read_srtm_img_tile,
+    read_worldclim_img_tile,
+)
+from .utils import get_scales
 
-def apply_mmdc_full_mode(
+# from tqdm import tqdm
+# from pathlib import Path
+
+
+# Configure the logger
+NUMERIC_LEVEL = getattr(logging, "INFO", None)
+logging.basicConfig(
+    level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def read_s2_img_iota2(
     self,
-    checkpoint_path: str,
-    checkpoint_epoch: int = 100,
-    patch_size: int = 256,
-):
+) -> S2Components:  # [torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Apply MMDC with Iota-2.py
+    read a patch of sentinel 2 data,
+    construct the masks and yield the patch
+    of data
     """
-    # How manage the S1 acquisition dates for construct the mask validity
-    # in time
-
     # DONE Get the data in the same order as
     # sentinel2.Sentinel2.GROUP_10M
+    # sensorsio.sentinel2.GROUP_20M
@@ -46,9 +75,37 @@ def apply_mmdc_full_mode(
         self.get_interpolated_Sentinel2_B12(),
     ]
 
-    # TODO Masks contruction for S2
-    list_s2_mask = self.get_Sentinel2_binary_masks()
+    # DONE Mask construction for S2
+    list_s2_mask = [
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+        self.get_Sentinel2_binary_masks(),
+    ]
+
+    # TODO Add S2 Angles
+    list_bands_s2_angles = [
+        self.get_Sentinel2_angles(),
+    ]
+
+    # extend the band list with the masks and the angles
+    list_bands_s2.extend(list_s2_mask)
+    list_bands_s2.extend(list_bands_s2_angles)
+
+    # NB: torch.tensor (not torch.Tensor) accepts dtype/device; assumes the
+    # iota2 accessors return numpy arrays and that `device` is defined here
+    bands_s2 = torch.tensor(
+        list_bands_s2, dtype=torch.float32, device=device
+    ).permute(-1, 0, 1, 2)
+
+    return read_s2_img_tile(bands_s2)
+
+
+def read_s1_img_iota2(
+    self,
+) -> S1Components:  # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    read a patch of s1, construct the associated datasets
+    and yield the result
+    """
     # TODO Manage S1 ASC and S1 DESC ?
     list_bands_s1 = [
         self.get_interpolated_Sentinel1_ASC_vh(),
         self.get_interpolated_Sentinel1_ASC_vv(),
         self.get_interpolated_Sentinel1_ASC_vh()
         / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4),
         self.get_interpolated_Sentinel1_DES_vh(),
         self.get_interpolated_Sentinel1_DES_vv(),
         self.get_interpolated_Sentinel1_DES_vh()
         / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4),
     ]
 
-    # TODO Read SRTM data
-    # TODO Read Worldclim Data
-
-    with torch.no_grad():
-        # Permute dimensions to fit patchify
-        # Shape before permutation is C,H,W,D. D being the dates
-        # Shape after permutation is D,C,H,W
-        bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2)
+    bands_s1 = torch.tensor(list_bands_s1).permute(-1, 0, 1, 2)
+    bands_s1 = apply_log_to_s1(bands_s1)  # already D,C,H,W: no second permute
+
+    # get the vv,vh, vv/vh bands
+    # img_s1 = torch.cat(
+    #     (s1_tensor, (s1_tensor[1, ...]
/ s1_tensor[0, ...]).unsqueeze(0)) + # ) + + list_bands_s1_lia = [ + self.get_interpolated_, + self.get_interpolated_, + self.get_interpolated_, + ] + + # s1_lia = get_s1_acquisition_angles(s1_filename) + # compute validity mask + s1_valmask = torch.ones(img_s1.shape) + # compute edge mask + s1_backscatter = apply_log_to_s1(img_s1) + + return read_s1_img_tile() + + +def read_srtm_img_iota2( + self, +) -> SRTMComponents: # [torch.Tensor]: + """ + read srtm patch + """ + # read the patch as tensor + + list_bands_srtm = [ + self.get_interpolated_userFeatures_aspect(), + self.get_interpolated_, + self.get_interpolated_, + ] + + srtm_tensor = torch.Tensor(list_bands_srtm).permute(-1, 0, 1, 2) + + return read_srtm_img_tile(srtm_height_aspect(srtm_tensor)) + + +def read_worldclim_img_iota2( + self, +) -> WorldClimComponents: # [torch.tensor]: + """ + read worldclim subset + """ + wc_tensor = torch.Tensor(self.get_interpolated_).permute(-1, 0, 1, 2) + return read_worldclim_img_tile(wc_tensor) + + +def predictc_iota2( + self, +): + """ + Apply MMDC with Iota-2.py + """ + + # retrieve the metadata + # meta = input_data.metadata.copy() + # get nb outputs + # meta.update({"count": process.count}) + # calculate rois + # chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines) + # get the windows from the chunks + # rois = [rio.windows.Window(*chunk) for chunk in chunks] + # logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ") + + # init the dataset + logger.debug("Reading S2 data") + s2_data = read_s2_img_iota2() + logger.debug("Reading S1 ASC data") + s1_asc_data = read_s1_img_iota2() + logger.debug("Reading S1 DESC data") + s1_desc_data = read_s1_img_iota2() + logger.debug("Reading SRTM data") + srtm_data = read_srtm_img_iota2() + logger.debug("Reading WorldClim data") + worldclim_data = read_worldclim_img_iota2() + + # Concat S1 Data + s1_backscatter = torch.cat((s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0) + # Multiply the masks + s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask) + # keep the sizes for recover the original size + s2_mask_patch_size = patches.patchify(s2.s2_mask, process.patch_size).shape + s2_s2_mask_shape = s2.s2_mask.shape # [1, 1024, 10980] + + # reshape the data + s2_refl_patch = patches.flatten2d( + patches.patchify(s2_data.s2_reflectances, process.patch_size) + ) + s2_ang_patch = patches.flatten2d( + patches.patchify(s2_data.s2_angles, process.patch_size) + ) + s2_mask_patch = patches.flatten2d( + patches.patchify(s2_data.s2_mask, process.patch_size) + ) + s1_patch = patches.flatten2d(patches.patchify(s1_backscatter, process.patch_size)) + s1_valmask_patch = patches.flatten2d( + patches.patchify(s1_valmask, process.patch_size) + ) + srtm_patch = patches.flatten2d(patches.patchify(srtm_data.srtm, process.patch_size)) + wc_patch = patches.flatten2d( + patches.patchify(worldclim_data.worldclim, process.patch_size) + ) + # Expand the angles to fit the sizes + s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) + s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) + + logger.debug("s2_patches", s2_refl_patch.shape) + logger.debug("s2_angles_patches", s2_ang_patch.shape) + logger.debug("s2_mask_patch", s2_mask_patch.shape) + logger.debug("s1_patch", s1_patch.shape) + logger.debug("s1_valmask_patch", s1_valmask_patch.shape) + logger.debug( + "s1_lia_asc patches", + s1_asc.s1_lia_angles.shape, + 
s1asc_lia_patch.shape, + ) + logger.debug( + "s1_lia_desc patches", + s1_desc.s1_lia_angles.shape, + s1desc_lia_patch.shape, + ) + logger.debug("srtm patches", srtm_patch.shape) + logger.debug("wc patches", wc_patch.shape) + + # apply predict function + # should return a s1s2vaelatentspace + # object + pred_vaelatentspace = process.process( + process.model, + sensors, + s2_refl_patch, + s2_mask_patch, + s2_ang_patch, + s1_patch, + s1_valmask_patch, + s1asc_lia_patch, + s1desc_lia_patch, + wc_patch, + srtm_patch, + ) + logger.debug(type(pred_vaelatentspace)) + logger.debug("latent space sizes : ", pred_vaelatentspace.shape) + + # unpatchify + pred_vaelatentspace_unpatchify = patches.unpatchify( + patches.unflatten2d( + pred_vaelatentspace, + s2_mask_patch_size[0], + s2_mask_patch_size[1], + ) + )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] + + logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + logger.debug("process.count : ", process.count) + # check the pred and the ask are the same + assert process.count == pred_vaelatentspace_unpatchify.shape[0] + + predition_array = np.array(pred_vaelatentspace_unpatchify[0, ...]) + + logger.debug(("Export tile", f"filename :{export_path}")) + + return predition_array, labels + + +def acquisition_dataframe( + self, +) -> pd.DataFrame: + """ + Construct a DataFrame with the acquisition dates + for the predictions + """ + # use Iota2 API + # acquisition_dates = self.get_raw_dates() + + # Get the acquisition dates + acquisition_dates = { + "Sentinel1_DES_vv": ["20151231"], + "Sentinel1_DES_vh": ["20151231"], + "Sentinel1_ASC_vv": ["20170518", "20170519"], + "Sentinel1_ASC_vh": ["20170518", "20170519"], + "Sentinel2": ["20151130", "20151203"], + } + + # get the acquisition dates for each sensor-mode + sentinel2_acquisition_dates = pd.DataFrame( + {"Sentinel2": acquisition_dates["Sentinel2"]}, dtype=str + ) + sentinel1_asc_acquisition_dates = pd.DataFrame( + { + "Sentinel1_ASC_vv": acquisition_dates["Sentinel1_ASC_vv"], + "Sentinel1_ASC_vh": acquisition_dates["Sentinel1_ASC_vh"], + }, + dtype=str, + ) + sentinel1_des_acquisition_dates = pd.DataFrame( + { + "Sentinel1_DES_vv": acquisition_dates["Sentinel1_DES_vv"], + "Sentinel1_DES_vh": acquisition_dates["Sentinel1_DES_vh"], + }, + dtype=str, + ) + + # convert acquistion dates from str to datetime object + sentinel2_acquisition_dates = sentinel2_acquisition_dates["Sentinel2"].apply( + lambda x: date_parser.parse(x) + ) + sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[ + "Sentinel1_ASC_vv" + ].apply(lambda x: date_parser.parse(x)) + sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[ + "Sentinel1_ASC_vh" + ].apply(lambda x: date_parser.parse(x)) + sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[ + "Sentinel1_DES_vv" + ].apply(lambda x: date_parser.parse(x)) + sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[ + "Sentinel1_DES_vh" + ].apply(lambda x: date_parser.parse(x)) + + # merge the dataframes + + # concat_acquisitions = pd.concat() + + # return + + +def apply_mmdc_full_mode( + # checkpoint_path: str, # read from the cfg file + # checkpoint_epoch: int = 100, # read from the cfg file + # patch_size: int = 256, # read from the cfg file +): + """ + Entry Point + """ + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # read the configuration file + config = rcf("configs/iota2/externalfeatures_with_userfeatures.cfg") + # return a list with a tuple of functions and arguments + 
+    mmdc_parameters = config.getParam("external_features", "functions")
+    padding_size_x = config.getParam("python_data_managing", "padding_size_x")
+    padding_size_y = config.getParam("python_data_managing", "padding_size_y")
+
+    checkpoint_path = mmdc_parameters[0][1]  # TODO fix this parameter
+
+    # Get the acquisition dates
+    # dict_dates = {
+    #     'Sentinel1_DES_vv': ['20151231'],
+    #     'Sentinel1_DES_vh': ['20151231'],
+    #     'Sentinel1_ASC_vv': ['20170518', '20170519'],
+    #     'Sentinel1_ASC_vh': ['20170518', '20170519'],
+    #     'Sentinel2': ['20151130', '20151203']}
+
+    # model
     mmdc_full_model = get_mmdc_full_model(
-        os.path.join(checkpoint_path, checkpoint_filename)
-    )
-
-    # apply model
-    # latent_variable = mmdc_full_model.predict(
-    #     s2_x,
-    #     s2_m,
-    #     s2_angles_x,
-    #     s1_x,
-    #     s1_vm,
-    #     s1_asc_angles_x,
-    #     s1_desc_angles_x,
-    #     worldclim_x,
-    #     srtm_x,
+        checkpoint=model_checkpoint_path,
+        mmdc_full_config_path=model_config_path,
+        inference_tile=False,
+    )
+    scales = get_scales()
+    mmdc_full_model.set_scales(scales)
+    mmdc_full_model.to(device)
+
+    # process class
+    mmdc_process = MMDCProcess(
+        count=latent_space_size,
+        nb_lines=nb_lines,
+        patch_size=patch_size,
+        process=predict_mmdc_model,
+        model=mmdc_full_model,
+    )
+    # create export filename
+    export_path = f"{export_path}/latent_singledate_{'_'.join(sensors)}.tif"
+
+    # # How to manage the S1 acquisition dates to construct the validity mask
+    # # in time
+
+    # # TODO Read SRTM data
+    # # TODO Read Worldclim Data
+
+    # # with torch.no_grad():
+    # # Permute dimensions to fit patchify
+    # # Shape before permutation is C,H,W,D. D being the dates
+    # # Shape after permutation is D,C,H,W
+    # bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2)
+    # bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2)
+    # bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2)
+
+    # # TODO Mask construction for S1
+    # # use the build_s1_image_and_masks function from the datamodule components?
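+    # A minimal sketch of that idea (an assumption, not the actual datamodule
+    # API): derive a per-date validity mask from the backscatter itself by
+    # flagging pixels where any polarisation is NaN, e.g.
+    #   s1_valmask = (~torch.isnan(bands_s1)).all(dim=1, keepdim=True).float()
+    # computed before the nan_to_num() call below.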
+
+    # # Replace nan by 0
+    # bands_s1 = bands_s1.nan_to_num()
+    # # These dimensions are useful for unpatchify
+    # # Keep the dimensions of height and width found in chunk
+    # band_h, band_w = bands_s2.shape[-2:]
+    # # Keep the number of patches of patch_size
+    # # in rows and cols found in chunk
+    # h, w = patches.patchify(
+    #     bands_s2[0, ...], patch_size=patch_size, margin=patch_margin
+    # ).shape[:2]
+
+    # # TODO Apply patchify
+
+    # # Get the model
+    # mmdc_full_model = get_mmdc_full_model(
+    #     os.path.join(checkpoint_path, checkpoint_filename)
     # )
 
-    # TODO unpatchify
+    # # apply model
+    # # latent_variable = mmdc_full_model.predict(
+    # #     s2_x,
+    # #     s2_m,
+    # #     s2_angles_x,
+    # #     s1_x,
+    # #     s1_vm,
+    # #     s1_asc_angles_x,
+    # #     s1_desc_angles_x,
+    # #     worldclim_x,
+    # #     srtm_x,
+    # #     )
+
+    # # TODO unpatchify
 
-    # TODO crop padding
+    # # TODO crop padding
 
-    # TODO Depending of the configuration return a unique latent variable
-    # or a stack of laten variables
-    # coef = pass
-    # labels = pass
-    # return coef, labels
+    # # TODO Depending on the configuration, return a unique latent variable
+    # # or a stack of latent variables
+    # # coef = pass
+    # # labels = pass
+    # # return coef, labels
--
GitLab

From 4ae4317bf4be86d5be3e74147fc8f707c8e35899 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 28 Mar 2023 15:20:38 +0000
Subject: [PATCH 70/81] add jobs

---
 jobs/iota2_clasif_mmdc_full.pbs  | 120 +++++++++++++++++++++++++++++++
 jobs/iota2_tile_aux_features.pbs |  18 +++++
 2 files changed, 138 insertions(+)
 create mode 100644 jobs/iota2_clasif_mmdc_full.pbs
 create mode 100644 jobs/iota2_tile_aux_features.pbs

diff --git a/jobs/iota2_clasif_mmdc_full.pbs b/jobs/iota2_clasif_mmdc_full.pbs
new file mode 100644
index 00000000..e3ec1e97
--- /dev/null
+++ b/jobs/iota2_clasif_mmdc_full.pbs
@@ -0,0 +1,120 @@
+#!/bin/bash
+#PBS -N iota2-class
+#PBS -l select=1:ncpus=8:mem=12G
+#PBS -l walltime=10:00:00
+
+# be sure no modules are loaded
+conda deactivate
+module purge
+
+
+ulimit -u 5000
+
+# export ROOT_PATH="$(dirname $( cd "$(dirname "${BASH_SOURCE[0]}")"; pwd -P ))"
+# export OTB_APPLICATION_PATH=$OTB_APPLICATION_PATH:/work/scratch/$USER/virtualenv/mmdc-iota2/lib/otb/applications//
+# export PATH=$PATH:$ROOT_PATH/bin:$OTB_APPLICATION_PATH
+# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$ROOT_PATH/lib:$OTB_APPLICATION_PATH #
+#
+# load modules
+module load conda/4.12.0
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+# Apply Iota2
+Iota2.py \
+    -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_mmdc_full.cfg \
+    -scheduler_type PBS \
+    -config_ressources /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_ressources.cfg
+    # -scheduler_type debug \
+    # -only_summary
+    # -nb_parallel_tasks 1 \
+
+
+# New
+# Functions available
+#['__class__',
+#'__delattr__',
+#'__dict__',
+#'__dir__',
+#'__doc__',
+#'__eq__',
+#'__format__',
+#'__ge__',
+#'__getattribute__',
+#'__gt__',
+#'__hash__',
+#'__init__',
+#'__init_subclass__',
+#'__le__',
+#'__lt__',
+#'__module__',
+#'__ne__',
+#'__new__',
+#'__reduce__',
+#'__reduce_ex__',
+#'__repr__',
+#'__setattr__',
+#'__sizeof__',
+#'__str__',
+#'__subclasshook__',
+#'__weakref__',
+#'all_dates',
+#'allow_nans',
+#'binary_masks',
+#'concat_mode',
+#'dim_ts',
+#'enabled_gap',
+#'exogeneous_data_array',
+#'exogeneous_data_name',
+#'external_functions',
+#'fill_missing_dates',
+#'func_param',
+#'get_Sentinel1_ASC_vh_binary_masks',
+#'get_Sentinel1_ASC_vv_binary_masks',
+#'get_Sentinel1_DES_vh_binary_masks',
+#'get_Sentinel1_DES_vv_binary_masks',
+#'get_Sentinel2_binary_masks',
+#'get_filled_masks',
+#'get_filled_stack',
+#'get_interpolated_Sentinel1_ASC_vh',
+#'get_interpolated_Sentinel1_ASC_vv',
+#'get_interpolated_Sentinel1_DES_vh',
+#'get_interpolated_Sentinel1_DES_vv',
+#'get_interpolated_Sentinel2_B11',
+#'get_interpolated_Sentinel2_B12',
+#'get_interpolated_Sentinel2_B2',
+#'get_interpolated_Sentinel2_B3',
+#'get_interpolated_Sentinel2_B4',
+#'get_interpolated_Sentinel2_B5',
+#'get_interpolated_Sentinel2_B6',
+#'get_interpolated_Sentinel2_B7',
+#'get_interpolated_Sentinel2_B8',
+#'get_interpolated_Sentinel2_B8A',
+#'get_interpolated_Sentinel2_Brightness',
+#'get_interpolated_Sentinel2_NDVI',
+#'get_interpolated_Sentinel2_NDWI',
+#'get_interpolated_dates',
+#'get_interpolated_userFeatures_aspect',
+#'get_raw_Sentinel1_ASC_vh',
+#'get_raw_Sentinel1_ASC_vv',
+#'get_raw_Sentinel1_DES_vh',
+#'get_raw_Sentinel1_DES_vv',
+#'get_raw_Sentinel2_B11',
+#'get_raw_Sentinel2_B12',
+#'get_raw_Sentinel2_B2',
+#'get_raw_Sentinel2_B3',
+#'get_raw_Sentinel2_B4',
+#'get_raw_Sentinel2_B5',
+#'get_raw_Sentinel2_B6',
+#'get_raw_Sentinel2_B7',
+#'get_raw_Sentinel2_B8',
+#'get_raw_Sentinel2_B8A',
+#'get_raw_dates',
+#'get_raw_userFeatures_aspect',
+#'interpolated_data',
+#'interpolated_dates',
+#'missing_masks_values',
+#'missing_refl_values',
+#'out_data',
+#'process',
+#'raw_data',
+#'raw_dates',
+#'test_user_feature_with_fake_data']
diff --git a/jobs/iota2_tile_aux_features.pbs b/jobs/iota2_tile_aux_features.pbs
new file mode 100644
index 00000000..2507bf36
--- /dev/null
+++ b/jobs/iota2_tile_aux_features.pbs
@@ -0,0 +1,18 @@
+#!/bin/bash
+#PBS -N iota2-tiler
+#PBS -l select=1:ncpus=8:mem=120G
+#PBS -l walltime=4:00:00
+
+# be sure no modules are loaded
+conda deactivate
+module purge
+
+# load modules
+module load conda
+conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2
+
+
+Iota2.py -scheduler_type debug \
+         -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_tiler.cfg \
+         -nb_parallel_tasks 1 \
+         -only_summary
--
GitLab

From d9f063ccabfbc409d7e9f96c397754d09f45da42 Mon Sep 17 00:00:00 2001
From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com>
Date: Tue, 28 Mar 2023 15:21:26 +0000
Subject: [PATCH 71/81] downgrade libsvm

---
 create-iota2-env.sh | 95 ++-------------------------------------------
 1 file changed, 4 insertions(+), 91 deletions(-)

diff --git a/create-iota2-env.sh b/create-iota2-env.sh
index 06e9ba1f..5f3b99ee 100644
--- a/create-iota2-env.sh
+++ b/create-iota2-env.sh
@@ -45,6 +45,9 @@ conda install mamba
 which python
 python --version
 
+# install libsvm
+mamba install -c conda-forge libsvm=325
+
 # Install iota2
 #mamba install iota2_develop=257da617 -c iota2
 mamba install iota2 -c iota2
@@ -67,7 +70,7 @@ ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2
 
 cd ~/src/MMDC/mmdc-singledate
 
-conda install -c conda-forge pydantic
+mamba install -c conda-forge pydantic
 
 # install missing dependencies
 pip install -r requirements-mmdc-iota2.txt
@@ -87,93 +90,3 @@ pip install -e .[testing]
 
 # End
 conda deactivate
-
-# #!/usr/bin/env bash
-
-# export python_version="3.9"
-# export name="mmdc-iota2"
-# if ! [ -z "$1" ]
-# then
-# export name=$1
-# fi
-
-# source ~/set_proxy_iota2.sh
-# if [ -z "$https_proxy" ]
-# then
-# echo "Please set https_proxy environment variable before running this script"
-# exit 1
-# fi
-
-# export target=/work/scratch/$USER/virtualenv/$name
-
-# if !
[ -z "$2" ] -# then -# export target="$2/$name" -# fi - -# echo "Installing $name in $target ..." - -# if [ -d "$target" ]; then -# echo "Cleaning previous conda env" -# rm -rf $target -# fi - -# # Create blank virtualenv -# module purge -# module load conda -# module load gcc -# conda activate -# conda create --yes --prefix $target python==${python_version} pip - -# # Enter virtualenv -# conda deactivate -# conda activate $target - -# # install mamba -# conda install mamba - -# which python -# python --version - -# # Install iota2 -# #mamba install iota2_develop=257da617 -c iota2 -# mamba install iota2 -c iota2 - -# clone the lastest version of iota2 -# rm -rf iota2_thirdparties/iota2 -# git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2 -# cd iota2_thirdparties/iota2 -# git checkout issue#600/tile_exogenous_features -# git pull origin issue#600/tile_exogenous_features -# cd ../../ -# pwd - -# # create a backup -# cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/ -# mv iota2 iota2_backup - -# # create a symbolic link -# ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2 - -# cd ~/src/MMDC/mmdc-singledate - -# conda install -c conda-forge pydantic - -# # install missing dependancies -# pip install -r requirements-mmdc-iota2.txt #--upgrade --no-cache-dir - -# # Install sensorsio -# rm -rf iota2_thirdparties/sensorsio -# git clone https://src.koda.cnrs.fr/mmdc/sensorsio.git iota2_thirdparties/sensorsio -# pip install iota2_thirdparties/sensorsio - -# # Install torchutils -# rm -rf iota2_thirdparties/torchutils -# git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils -# pip install iota2_thirdparties/torchutils - -# # Install the current project in edit mode -# pip install -e .[testing] - -# # End -# conda deactivate -- GitLab From 56f9992dcb6b1603cfcd89aacbe00c88cdb51245 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Tue, 28 Mar 2023 15:21:58 +0000 Subject: [PATCH 72/81] add time series availability --- .../inference/mmdc_iota2_inference.py | 143 ++++++++++++------ 1 file changed, 97 insertions(+), 46 deletions(-) diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py index 5bd385e4..742f2892 100644 --- a/src/mmdc_singledate/inference/mmdc_iota2_inference.py +++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py @@ -18,25 +18,25 @@ from torchutils import patches # from mmdc_singledate.models.mmdc_full_module import MMDCFullModule # from mmdc_singledate.models.types import S1S2VAELatentSpace -from mmdc_singledate.datamodules.components.datamodule_utils import ( - apply_log_to_s1, - srtm_height_aspect, -) - -from .components.inference_components import get_mmdc_full_model, predict_mmdc_model -from .components.inference_utils import ( - GeoTiffDataset, - S1Components, - S2Components, - SRTMComponents, - WorldClimComponents, - get_mmdc_full_config, - read_s1_img_tile, - read_s2_img_tile, - read_srtm_img_tile, - read_worldclim_img_tile, -) -from .utils import get_scales +# from mmdc_singledate.datamodules.components.datamodule_utils import ( +# apply_log_to_s1, +# srtm_height_aspect, +# ) + +# from .components.inference_components import get_mmdc_full_model, predict_mmdc_model +# from .components.inference_utils import ( +# GeoTiffDataset, +# S1Components, +# S2Components, +# SRTMComponents, +# WorldClimComponents, 
+# get_mmdc_full_config, +# read_s1_img_tile, +# read_s2_img_tile, +# read_srtm_img_tile, +# read_worldclim_img_tile, +# ) +# from .utils import get_scales # from tqdm import tqdm # from pathlib import Path @@ -51,8 +51,9 @@ logger = logging.getLogger(__name__) def read_s2_img_iota2( - self, -) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: + self, +): +# ) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: """ read a patch of sentinel 2 data contruct the masks and yield the patch @@ -83,14 +84,19 @@ def read_s2_img_iota2( self.get_Sentinel2_binary_masks(), ] - # TODO Add S2 Angles + # DONE Add S2 Angles list_bands_s2_angles = [ - self.get_Sentinel2_angles(), + self.get_interpolated_userFeatures_ZENITH(), + self.get_interpolated_userFeatures_AZIMUTH(), + self.get_interpolated_userFeatures_EVEN_ZENITH(), + self.get_interpolated_userFeatures_ODD_ZENITH(), + self.get_interpolated_userFeatures_EVEN_AZIMUTH(), + self.get_interpolated_userFeatures_ODD_AZIMUTH(), ] # extend the list list_bands_s2.extend(list_s2_mask) - list_bands_s2.extend(list_s2_mask) + list_bands_s2.extend(list_bands_s2_angles) bands_s2 = torch.Tensor( list_bands_s2, dtype=torch.torch.float32, device=device @@ -101,7 +107,8 @@ def read_s2_img_iota2( def read_s1_img_iota2( self, -) -> S1Components: # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +): +# ) -> S1Components: # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ read a patch of s1 construct the datasets associated and yield the result @@ -143,30 +150,39 @@ def read_s1_img_iota2( def read_srtm_img_iota2( self, -) -> SRTMComponents: # [torch.Tensor]: +): +# ) -> SRTMComponents: """ - read srtm patch + read srtm patch from Iota2 API """ - # read the patch as tensor + # read the patch as numpy array list_bands_srtm = [ + self.get_interpolated_userFeatures_elevation(), + self.get_interpolated_userFeatures_slope(), self.get_interpolated_userFeatures_aspect(), - self.get_interpolated_, - self.get_interpolated_, ] - + # convert to tensor and permute the bands srtm_tensor = torch.Tensor(list_bands_srtm).permute(-1, 0, 1, 2) - return read_srtm_img_tile(srtm_height_aspect(srtm_tensor)) + return read_srtm_img_tile( + srtm_height_aspect( + srtm_tensor + ) + ) + def read_worldclim_img_iota2( self, -) -> WorldClimComponents: # [torch.tensor]: +): +# ) -> WorldClimComponents: # [torch.tensor]: """ - read worldclim subset + read worldclim patch from Iota2 API """ - wc_tensor = torch.Tensor(self.get_interpolated_).permute(-1, 0, 1, 2) + wc_tensor = torch.Tensor( + self.get_interpolated_userFeatures_worldclim() + ).permute(-1, 0, 1, 2) return read_worldclim_img_tile(wc_tensor) @@ -330,26 +346,34 @@ def acquisition_dataframe( ) # convert acquistion dates from str to datetime object - sentinel2_acquisition_dates = sentinel2_acquisition_dates["Sentinel2"].apply( + sentinel2_acquisition_dates["Sentinel2"] = sentinel2_acquisition_dates["Sentinel2"].apply( lambda x: date_parser.parse(x) ) - sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[ + sentinel1_asc_acquisition_dates["Sentinel1_ASC_vv"] = sentinel1_asc_acquisition_dates[ "Sentinel1_ASC_vv" - ].apply(lambda x: date_parser.parse(x)) - sentinel1_asc_acquisition_dates = sentinel1_asc_acquisition_dates[ + ].apply(lambda x: date_parser.parse(x)) + sentinel1_asc_acquisition_dates["Sentinel1_ASC_vh"] = sentinel1_asc_acquisition_dates[ "Sentinel1_ASC_vh" - ].apply(lambda x: date_parser.parse(x)) - sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[ + 
].apply(lambda x: date_parser.parse(x)) + sentinel1_des_acquisition_dates["Sentinel1_DES_vv"] = sentinel1_des_acquisition_dates[ "Sentinel1_DES_vv" - ].apply(lambda x: date_parser.parse(x)) - sentinel1_desc_acquisition_dates = sentinel1_des_acquisition_dates[ + ].apply(lambda x: date_parser.parse(x)) + sentinel1_des_acquisition_dates["Sentinel1_DES_vh"] = sentinel1_des_acquisition_dates[ "Sentinel1_DES_vh" - ].apply(lambda x: date_parser.parse(x)) + ].apply(lambda x: date_parser.parse(x)) # merge the dataframes + sensors_join = pd.concat([ + sentinel2_acquisition_dates, + sentinel1_asc_acquisition_dates, + sentinel1_des_acquisition_dates, + ] + ) - # concat_acquisitions = pd.concat() - + # get acquisition availability + sensors_join["s2_availability"] = pd.isna(sensors_join["Sentinel2"]) + sensors_join["s1_availability_asc"] = pd.isna(sensors_join["Sentinel1_ASC_vv"]) + sensors_join["s1_availability_desc"] = pd.isna(sensors_join["Sentinel1_DES_vv"]) # return @@ -459,3 +483,30 @@ def apply_mmdc_full_mode( # # coef = pass # # labels = pass # # return coef, labels + + +def test_mnt(self, arg1): + """ + compute the Soil Composition Index + """ + print(dir(self)) + print("raw dates :", self.get_raw_dates()) + print("interpolated dates :",self.get_interpolated_dates()) + print("interpolated dates :",self.interpolated_dates) + print("exogeneous_data_name:",self.exogeneous_data_name) + print("fill_missing_dates:",self.fill_missing_dates) + print(":",None) + + mnt = self.get_Sentinel2_binary_masks() #self.get_raw_userFeatures_mnt_band_0()#self.get_raw_userFeatures_mnt2() + # mnt = self.get_raw_userFeatures_mnt2_band_0() + # mnt = self.get_raw_userFeatures_slope() + print(f"type mnt : {type(mnt)}") + print(f"size mnt : {mnt.shape}") + labels = [f"mnt_{cpt}" for cpt in range(mnt.shape[-1])] + print(f"labels mnt : {labels}") + + # s2_b5 = self.get_interpolated_Sentinel2_B5() + # print(f"type s2_b5 : {type(s2_b5)}") # 4, 3, 2 # 26,50,2 + # print(f"size s2_b5 : {s2_b5.shape}") # 4, 3, 1 # 26,50,2 + + return mnt, labels -- GitLab From a160400bb6dc0618780b2024b22bfde7b1745294 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 11:39:14 +0000 Subject: [PATCH 73/81] config ressources --- configs/iota2/config_ressources.cfg | 434 ++++++++++++++++++++++++++++ 1 file changed, 434 insertions(+) create mode 100644 configs/iota2/config_ressources.cfg diff --git a/configs/iota2/config_ressources.cfg b/configs/iota2/config_ressources.cfg new file mode 100644 index 00000000..9ef37b3e --- /dev/null +++ b/configs/iota2/config_ressources.cfg @@ -0,0 +1,434 @@ +################################################################################ +# Configuration file use to set HPC ressources request + +# All steps are mandatory +# +# TEMPLATE +# step_name:{ +# name : "IOTA2_dir" #no space characters +# nb_cpu : 1 +# ram : "5gb" +# walltime : "00:10:00"#HH:MM:SS +# process_min : 1 #not mandatory (default = 1) +# process_max : 9 #not mandatory (default = number of tasks) +# } +# +################################################################################ + +iota2_chain:{ + name : "IOTA2" + nb_cpu : 1 + ram : "5gb" + walltime : "10:00:00" + } + +iota2_dir:{ + name : "IOTA2_dir" + nb_cpu : 1 + ram : "5gb" + walltime : "00:10:00" + process_min : 1 + } + +preprocess_data : { + name:"preprocess_data" + nb_cpu:40 + ram:"180gb" + walltime:"10:00:00" + process_min : 1 + } + +get_common_mask : { + name:"CommonMasks" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + 
process_min : 1 + } + +coregistration : { + name:"Coregistration" + nb_cpu:4 + ram:"20gb" + walltime:"05:00:00" + process_min : 1 + } + +get_pixValidity : { + name:"NbView" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +envelope : { + name:"Envelope" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +regionShape : { + name:"regionShape" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +splitRegions : { + name:"splitRegions" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +extract_data_region_tiles : { + name:"extract_data_region_tiles" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +split_learning_val : { + name:"split_learning_val" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +split_learning_val_sub : { + name:"split_learning_val_sub" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +split_samples : { + name:"split_samples" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesFormatting : { + name:"samples_formatting" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesMerge: { + name:"samples_merges" + nb_cpu:2 + ram:"10gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesStatistics : { + name:"samples_stats" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesSelection : { + name:"samples_selection" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +samplesselection_tiles : { + name:"samplesSelection_tiles" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +samplesManagement : { + name:"samples_management" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +vectorSampler : { + name:"vectorSampler" + nb_cpu:40 + ram:"180gb" + walltime:"10:00:00" + process_min : 1 + } + +dimensionalityReduction : { + name:"dimensionalityReduction" + nb_cpu:3 + nb_MPI_process:2 + ram:"15gb" + nb_chunk:1 + walltime:"00:20:00"} + +samplesAugmentation : { + name:"samplesAugmentation" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + } + +mergeSample : { + name:"mergeSample" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +stats_by_models : { + name:"stats_by_models" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +training : { + name:"training" + nb_cpu:40 + ram:"180gb" + walltime:"10:00:00" + process_min : 1 + } + +#training regression +TrainRegressionPytorch : { + name:"training" + nb_cpu:40 + ram:"180gb" + walltime:"30:00:00" + } + +PredictRegressionScikit: { + name:"prediction" + nb_cpu:40 + ram:"180gb" + walltime:"30:00:00" + } + +#training regression +PredictRegressionPytorch : { + name:"choice" + nb_cpu:40 + ram:"180gb" + walltime:"30:00:00" + } + + +cmdClassifications : { + name:"cmdClassifications" + nb_cpu:10 + ram:"60gb" + walltime:"10:00:00" + process_min : 1 + } + +classifications : { + name:"classifications" + nb_cpu:10 + ram:"60gb" + walltime:"10:00:00" + process_min : 1 + } + +classifShaping : { + name:"classifShaping" + nb_cpu:10 + ram:"60gb" + walltime:"10:00:00" + process_min : 1 + } + +gen_confusionMatrix : { + name:"genCmdconfusionMatrix" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +confusionMatrix : { + name:"confusionMatrix" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +#regression +GenerateRegressionMetrics : { + name:"classification" + nb_cpu:10 + ram:"50gb" + walltime:"40:00:00" + } + + +fusion : { + 
name:"fusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +noData : { + name:"noData" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +statsReport : { + name:"statsReport" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +confusionMatrixFusion : { + name:"confusionMatrixFusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +merge_final_classifications : { + name:"merge_final_classifications" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + } + +reportGen : { + name:"reportGeneration" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +mergeOutStats : { + name:"mergeOutStats" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +SAROptConfusionMatrix : { + name:"SAROptConfusionMatrix" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +SAROptConfusionMatrixFusion : { + name:"SAROptConfusionMatrixFusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +SAROptFusion : { + name:"SAROptFusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +clump : { + name:"clump" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +grid : { + name:"grid" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +vectorisation : { + name:"vectorisation" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +crownsearch : { + name:"crownsearch" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +crownbuild : { + name:"crownbuild" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1} + +statistics : { + name:"statistics" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +join : { + name:"join" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +tiler : { + name:"tiler" + nb_cpu:8 + ram:"120gb" + walltime:"01:00:00" + process_min:1 + +} +features_tiler : { + name:"features_tiler" + nb_cpu:8 + ram:"120gb" + walltime:"01:00:00" + process_min:1 + +} -- GitLab From 4f39a09b7c8e3f2f626959a3f9f0356032ff4fcc Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 11:40:41 +0000 Subject: [PATCH 74/81] add configs --- configs/iota2/config_sar.cfg | 19 ++++++++ configs/iota2/i2_tiler.cfg | 22 +++++++++ configs/iota2/iota2_mmdc_full.cfg | 76 +++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+) create mode 100644 configs/iota2/config_sar.cfg create mode 100644 configs/iota2/i2_tiler.cfg create mode 100644 configs/iota2/iota2_mmdc_full.cfg diff --git a/configs/iota2/config_sar.cfg b/configs/iota2/config_sar.cfg new file mode 100644 index 00000000..07512be0 --- /dev/null +++ b/configs/iota2/config_sar.cfg @@ -0,0 +1,19 @@ +[Paths] +output = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_preprocess +s1images = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake/ +srtm = /datalake/static_aux/MNT/SRTM_30_hgt +geoidfile = /home/uz/vinascj/src/MMDC/mmdc-data/Geoid/egm96.grd + +[Processing] +tiles = 31TCJ +tilesshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/mgrs_tiles.shp +referencesfolder = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/ +srtmshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/srtm.shp +rasterpattern = T31TCJ.tif +gapFilling_interpolation = spline +temporalresolution = 10 +borderthreshold = 1e-3 +ramperprocess = 5000 + +[Filtering] +window_radius = 2 diff --git a/configs/iota2/i2_tiler.cfg b/configs/iota2/i2_tiler.cfg new file 
mode 100644 index 00000000..193d757d --- /dev/null +++ b/configs/iota2/i2_tiler.cfg @@ -0,0 +1,22 @@ +chain : +{ + output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data' + spatial_resolution : 10 + first_step : 'tiler' + last_step : 'tiler' + + proj : 'EPSG:2154' + rasters_grid_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/T31TCJ' + tile_field : 'Name' + list_tile : 'T31TCJ' + srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' + worldclim_path:'/datalake/static_aux/worldclim-2.0' + s2_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s2_data_datalake' + # add S1 data + s1_dir:"/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake" +} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} diff --git a/configs/iota2/iota2_mmdc_full.cfg b/configs/iota2/iota2_mmdc_full.cfg new file mode 100644 index 00000000..fb651a84 --- /dev/null +++ b/configs/iota2/iota2_mmdc_full.cfg @@ -0,0 +1,76 @@ +# Entry Point Iota2 Config +chain : +{ + # Input and Output data + output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/iota2_results' + remove_output_path : False + check_inputs : False + list_tile : 'T31TCJ' + data_field : 'code' + s2_path : '/work/scratch/vinascj/MMDC/iota2/i2_training_data/s2_data_full' + ground_truth : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/reference_data.shp' + + # Add S1 data + s1_path : "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_sar.cfg" + + # Classification related parameters + spatial_resolution : 10 + color_table : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/colorFile.txt' + nomenclature_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/nomenclature.txt' + first_step : 'init' + last_step : 'validation' + proj : 'EPSG:2154' + + # Path to the User data + user_feat_path: '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data' #/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data/T31TCJ +} + +# Parametrize the classifier +arg_train : +{ + random_seed : 0 # Set the random seed to split the ground truth + runs : 1 + classifier : 'sharkrf' + otb_classifier_options : {'classifier.sharkrf.nodesize': 25} + sample_selection : + { + sampler : 'random' + strategy : 'percent' + 'strategy.percent.p' : 0.1 + } +} + +# Use User Provided data +userFeat: +{ + arbo:"/*" + patterns:"elevation,aspect,slope,worldclim,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending," +} + +# Parameters for the user provided data +python_data_managing : +{ + padding_size_x : 1 + padding_size_y : 1 + chunk_size_mode:"split_number" + number_of_chunks:4 + data_mode_access: "both" +} + +# Process the User and Iota2 data +external_features : +{ + # module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py" + module:"/home/uz/vinascj/src/MMDC/mmdc-singledate/src/mmdc_singledate/inference/mmdc_iota2_inference.py" + functions : [['test_mnt', {"arg1":1}]] + concat_mode : True + external_features_flag:True +} + +# Parametrize nodes +task_retry_limits : +{ + allowed_retry : 0 + maximum_ram : 180.0 + maximum_cpu : 40 +} -- GitLab From 3d75244afe891beb613a30e3547bfd55547c9095 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 12:32:36 +0000 Subject: [PATCH 75/81] fix import --- .../inference/components/inference_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index c514a8fc..6331e26f 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -14,7 +14,7 @@ import torch from torch import nn from mmdc_singledate.inference.utils import get_mmdc_full_config -from mmdc_singledate.models.mmdc_full_module import MMDCFullModule +from mmdc_singledate.models.torch.full import MMDCFullModule from mmdc_singledate.models.types import S1S2VAELatentSpace # define sensors variable for typing -- GitLab From 8b7aa4065072748c94fc3b250372968b1b74e855 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 12:47:36 +0000 Subject: [PATCH 76/81] update config file --- test/test_mmdc_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index fd15d093..0f5085a9 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -270,7 +270,7 @@ def test_predict_single_date_tile_no_data( mmdc_full_model = get_mmdc_full_model( os.path.join(checkpoint_path, checkpoint_filename), - mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json", + mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/test/full_module_config.json", inference_tile=True, ) # move to device -- GitLab From 5671620dacf2a476efcf09ce13267591dc3cc0dc Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 12:48:43 +0000 Subject: [PATCH 77/81] update config file --- test/test_mmdc_inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py index 0f5085a9..6d8a5f4f 100644 --- a/test/test_mmdc_inference.py +++ b/test/test_mmdc_inference.py @@ -18,16 +18,17 @@ from mmdc_singledate.inference.components.inference_utils import ( GeoTiffDataset, generate_chunks, ) -from mmdc_singledate.inference.mmdc_tile_inference import ( +from mmdc_singledate.inference.mmdc_tile_inference import ( # inference_dataframe, MMDCProcess, - inference_dataframe, mmdc_tile_inference, predict_single_date_tile, ) -from mmdc_singledate.models.types import VAELatentSpace from .utils import get_scales, setup_data +# from mmdc_singledate.models.types import VAELatentSpace + + # usefull variables dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" -- GitLab From 7ee2a504503af9e97cda31b89ae60da511819efc Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 13:42:51 +0000 Subject: [PATCH 78/81] WIP minor fix --- Makefile | 3 +- test/full_module_config.json | 197 +++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 test/full_module_config.json diff --git a/Makefile b/Makefile index bc64a93b..d112a0ae 100644 --- a/Makefile +++ b/Makefile @@ -74,10 +74,11 @@ test_no_PIL: test_mask_loss: $(CONDA) && pytest -vv test/test_masked_losses.py -PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py" test_inference: $(CONDA) && pytest -vv test/test_mmdc_inference.py +PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py" + #.PHONY: pylint: 
$(CONDA) && pylint --ignore=$(PYLINT_IGNORED) src/ diff --git a/test/full_module_config.json b/test/full_module_config.json new file mode 100644 index 00000000..db4ed172 --- /dev/null +++ b/test/full_module_config.json @@ -0,0 +1,197 @@ +{ + "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullLitModule", + "lr": 0.0001, + "model": { + "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullModule", + "config": { + "_target_": "mmdc_singledate.models.types.MultiVAEConfig", + "data_sizes": { + "_target_": "mmdc_singledate.models.types.MMDCDataChannels", + "sen1": 6, + "s1_angles": 6, + "sen2": 10, + "s2_angles": 6, + "srtm": 4, + "w_c": 103 + }, + "embeddings": { + "_target_": "mmdc_singledate.models.types.ModularEmbeddingConfig", + "s1_angles": { + "_target_": "mmdc_singledate.models.types.MLPConfig", + "hidden_layers": [ + 9, + 12, + 9 + ], + "out_channels": 3 + }, + "s1_srtm": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 32, + 64, + 128 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "s2_angles": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 3, + "sizes": [ + 16, + 8 + ], + "kernel_sizes": [ + 1, + 1, + 1 + ] + }, + "s2_srtm": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 32, + 64, + 128 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "w_c": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 3, + "sizes": [ + 64, + 32, + 16 + ], + "kernel_sizes": [ + 1, + 1, + 1, + 1 + ] + } + }, + "s1_encoder": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 64, + 128, + 256, + 512 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "s2_encoder": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 64, + 128, + 256, + 512 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "s1_decoder": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 0, + "sizes": [ + 64, + 32, + 16 + ], + "kernel_sizes": [ + 3, + 3, + 3, + 3 + ] + }, + "s2_decoder": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 0, + "sizes": [ + 64, + 32, + 16 + ], + "kernel_sizes": [ + 3, + 3, + 3, + 3 + ] + }, + "s1_enc_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "s2_enc_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "s1_dec_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "s2_dec_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "loss_weights": { + "_target_": "mmdc_singledate.models.types.MultiVAELossWeights", + "sen1": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "sen2": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "s1_s2": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "s2_s1": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", 
+ "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "forward": 1, + "cross": 1, + "latent": 1 + } + } + } +} -- GitLab From d412ff10112cce56c020f949e69639b166c3ded3 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 13:47:42 +0000 Subject: [PATCH 79/81] WIP minor fix --- src/mmdc_singledate/inference/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py index 6b6a3c45..4c24cb63 100644 --- a/src/mmdc_singledate/inference/utils.py +++ b/src/mmdc_singledate/inference/utils.py @@ -25,6 +25,7 @@ from mmdc_singledate.models.types import ( # S1S2VAELatentSpace, ) +# TODO Complexify the scales used for the inference def get_scales(): """ Read Scales for Inference @@ -108,7 +109,7 @@ def parse_from_cfg(path_to_cfg: str): return cfg_dict -def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool): +def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool) -> MultiVAEConfig: """ " Get the parameters for instantiate the mmdc_full module based on the target inference mode -- GitLab From 93b63ae2ba60d2c6f2abf7fb41420831ae297327 Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 13:50:48 +0000 Subject: [PATCH 80/81] WIP minor fix --- .../components/inference_components.py | 132 ++++++++---------- 1 file changed, 60 insertions(+), 72 deletions(-) diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py index 6331e26f..42cdc604 100644 --- a/src/mmdc_singledate/inference/components/inference_components.py +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -122,14 +122,14 @@ def predict_mmdc_model( srtm.to(device), ) - logger.debug("prediction.shape :", prediction[0].latent.latent_s2.mean.shape) - logger.debug("prediction.shape :", prediction[0].latent.latent_s2.logvar.shape) - logger.debug("prediction.shape :", prediction[0].latent.latent_s1.mean.shape) - logger.debug("prediction.shape :", prediction[0].latent.latent_s1.logvar.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen2.mean.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen2.logvar.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen1.mean.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen1.logvar.shape) # init the output latent spaces as # empty dataclass - latent_space = S1S2VAELatentSpace(latent_s1=None, latent_s2=None) + latent_space = S1S2VAELatentSpace(sen1=None, sen2=None) # fullfit with a matchcase # match @@ -138,150 +138,138 @@ def predict_mmdc_model( if sensors == ["S2L2A", "S1FULL"]: logger.debug("S2 captured & S1 full captured") - latent_space.latent_s2 = prediction[0].latent.latent_s2 - latent_space.latent_s1 = prediction[0].latent.latent_s1 + latent_space.sen2 = prediction[0].latent.sen2 + latent_space.sen1 = prediction[0].latent.sen1 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - latent_space.latent_s1.mean, - latent_space.latent_s1.logvar, + latent_space.sen2.mean, + latent_space.sen2.logvar, + latent_space.sen1.mean, + latent_space.sen1.logvar, ), 1, ) logger.debug( - "latent_space.latent_s2.mean :", - latent_space.latent_s2.mean.shape, + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, ) logger.debug( - "latent_space.latent_s2.logvar :", - 
latent_space.latent_s2.logvar.shape, - ) - logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - logger.debug( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) elif sensors == ["S2L2A", "S1ASC"]: # logger.debug("S2 captured & S1 asc captured") - latent_space.latent_s2 = prediction[0].latent.latent_s2 - latent_space.latent_s1 = prediction[0].latent.latent_s1 + latent_space.sen2 = prediction[0].latent.sen2 + latent_space.sen1 = prediction[0].latent.sen1 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - latent_space.latent_s1.mean, - latent_space.latent_s1.logvar, + latent_space.sen2.mean, + latent_space.sen2.logvar, + latent_space.sen1.mean, + latent_space.sen1.logvar, ), 1, ) logger.debug( - "latent_space.latent_s2.mean :", - latent_space.latent_s2.mean.shape, - ) - logger.debug( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, ) - logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) logger.debug( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) elif sensors == ["S2L2A", "S1DESC"]: logger.debug("S2 captured & S1 desc captured") - latent_space.latent_s2 = prediction[0].latent.latent_s2 - latent_space.latent_s1 = prediction[0].latent.latent_s1 + latent_space.sen2 = prediction[0].latent.sen2 + latent_space.sen1 = prediction[0].latent.sen1 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, - latent_space.latent_s1.mean, - latent_space.latent_s1.logvar, + latent_space.sen2.mean, + latent_space.sen2.logvar, + latent_space.sen1.mean, + latent_space.sen1.logvar, ), 1, ) logger.debug( - "latent_space.latent_s2.mean :", - latent_space.latent_s2.mean.shape, - ) - logger.debug( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, ) - logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) logger.debug( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) elif sensors == ["S2L2A"]: # logger.debug("Only S2 captured") - latent_space.latent_s2 = prediction[0].latent.latent_s2 + latent_space.sen2 = prediction[0].latent.sen2 latent_space_stack = torch.cat( ( - latent_space.latent_s2.mean, - latent_space.latent_s2.logvar, + latent_space.sen2.mean, + latent_space.sen2.logvar, ), 1, ) logger.debug( - "latent_space.latent_s2.mean :", - latent_space.latent_s2.mean.shape, + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, ) logger.debug( - "latent_space.latent_s2.logvar :", - latent_space.latent_s2.logvar.shape, + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, ) elif sensors == ["S1FULL"]: # logger.debug("Only S1 full captured") - 
latent_space.latent_s1 = prediction[0].latent.latent_s1 + latent_space.sen1 = prediction[0].latent.sen1 latent_space_stack = torch.cat( - (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), + (latent_space.sen1.mean, latent_space.sen1.logvar), 1, ) - logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - logger.debug( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) elif sensors == ["S1ASC"]: # logger.debug("Only S1ASC captured") - latent_space.latent_s1 = prediction[0].latent.latent_s1 + latent_space.sen1 = prediction[0].latent.sen1 latent_space_stack = torch.cat( - (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), + (latent_space.sen1.mean, latent_space.sen1.logvar), 1, ) - logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - logger.debug( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) elif sensors == ["S1DESC"]: # logger.debug("Only S1DESC captured") - latent_space.latent_s1 = prediction[0].latent.latent_s1 + latent_space.sen1 = prediction[0].latent.sen1 latent_space_stack = torch.cat( - (latent_space.latent_s1.mean, latent_space.latent_s1.logvar), + (latent_space.sen1.mean, latent_space.sen1.logvar), 1, ) - logger.debug("latent_space.latent_s1.mean :", latent_space.latent_s1.mean.shape) - logger.debug( - "latent_space.latent_s1.logvar :", latent_space.latent_s1.logvar.shape - ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) else: raise Exception(f"Not a valid sensors ({sensors}) are given") -- GitLab From 7ab7d9e0dc900e2d3dea37014b6d3f08f9e3c58d Mon Sep 17 00:00:00 2001 From: Juan Sebastian Vinasco-Salinas <js.vinasco.s@gmail.com> Date: Wed, 29 Mar 2023 13:59:11 +0000 Subject: [PATCH 81/81] WIP adding entry points --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index eb6eb8f6..3cce8fa6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,6 +72,8 @@ console_scripts = mmdc_generate_dataset = bin.generate_mmdc_dataset:main mmdc_visualize = bin.visualize_mmdc_ds:main mmdc_split_dataset = bin.split_mmdc_dataset:main + mmdc_inference_singledate = bin.inference_mmdc_singledate:main + mmdc_inference_time_serie = bin.inference_mmdc_timeserie:main # mmdc_inference = bin.infer_mmdc_ds:main # Add here console scripts like: -- GitLab
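The two new console scripts above expect matching main() functions in
bin/inference_mmdc_singledate.py and bin/inference_mmdc_timeserie.py, which are
not part of this patch series yet. A minimal sketch of what such a module could
look like, following the get_parser()/main() pattern the earlier patches
converge on (the argument names and the commented-out call are assumptions,
not the final API):

# bin/inference_mmdc_singledate.py -- hypothetical sketch
import argparse
import logging


def get_parser() -> argparse.ArgumentParser:
    """Define the CLI arguments for single-date inference"""
    arg_parser = argparse.ArgumentParser(description="MMDC single-date inference")
    arg_parser.add_argument("--input_path", type=str, required=True, help="Tile data directory")
    arg_parser.add_argument("--export_path", type=str, required=True, help="Output directory")
    arg_parser.add_argument("--loglevel", default="INFO", help="Logger level (DEBUG, INFO, ...)")
    return arg_parser


def main():
    """Entry point registered in setup.cfg"""
    args = get_parser().parse_args()
    logging.basicConfig(level=getattr(logging, args.loglevel.upper(), logging.INFO))
    logging.info("input directory : %s", args.input_path)
    logging.info("output directory : %s", args.export_path)
    # the actual inference call would go here, e.g. something like
    # mmdc_tile_inference(Path(args.input_path), Path(args.export_path), ...)


if __name__ == "__main__":
    main()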