diff --git a/.gitignore b/.gitignore index 2015cd5d54387d056b86f31ffd8e464d784072d4..26aaaf09a17d237bb39b785b8eb06847dd7bf7b7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,9 @@ _lightning_logs/* src/models/*.ipynb *.swp *~ -thirdparties +thirdparties/* +iota2_thirdparties .coverage src/MMDC_SingleDate.egg-info/ .projectile +iota2_thirdparties/* diff --git a/Makefile b/Makefile index 5692cfc81f56d500719a680d88cdb810aa0a04de..d112a0ae67321f6a6ab3887a00a2155e5d993644 100644 --- a/Makefile +++ b/Makefile @@ -74,8 +74,11 @@ test_no_PIL: test_mask_loss: $(CONDA) && pytest -vv test/test_masked_losses.py +test_inference: + $(CONDA) && pytest -vv test/test_mmdc_inference.py PYLINT_IGNORED = "pix2pix_module.py,pix2pix_networks.py,mmdc_residual_module.py" + #.PHONY: pylint: $(CONDA) && pylint --ignore=$(PYLINT_IGNORED) src/ diff --git a/README.rst b/README.rst new file mode 100644 index 0000000000000000000000000000000000000000..5bc068184ef56f4b256ff867dfc8fb66e4a0230c --- /dev/null +++ b/README.rst @@ -0,0 +1 @@ +# MMDC-SingleDate diff --git a/configs/iota2/config_ressources.cfg b/configs/iota2/config_ressources.cfg new file mode 100644 index 0000000000000000000000000000000000000000..9ef37b3e8271362bc58573623eeee90ac5b68494 --- /dev/null +++ b/configs/iota2/config_ressources.cfg @@ -0,0 +1,434 @@ +################################################################################ +# Configuration file use to set HPC ressources request + +# All steps are mandatory +# +# TEMPLATE +# step_name:{ +# name : "IOTA2_dir" #no space characters +# nb_cpu : 1 +# ram : "5gb" +# walltime : "00:10:00"#HH:MM:SS +# process_min : 1 #not mandatory (default = 1) +# process_max : 9 #not mandatory (default = number of tasks) +# } +# +################################################################################ + +iota2_chain:{ + name : "IOTA2" + nb_cpu : 1 + ram : "5gb" + walltime : "10:00:00" + } + +iota2_dir:{ + name : "IOTA2_dir" + nb_cpu : 1 + ram : "5gb" + walltime : "00:10:00" + process_min : 1 + } + +preprocess_data : { + name:"preprocess_data" + nb_cpu:40 + ram:"180gb" + walltime:"10:00:00" + process_min : 1 + } + +get_common_mask : { + name:"CommonMasks" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +coregistration : { + name:"Coregistration" + nb_cpu:4 + ram:"20gb" + walltime:"05:00:00" + process_min : 1 + } + +get_pixValidity : { + name:"NbView" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +envelope : { + name:"Envelope" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +regionShape : { + name:"regionShape" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +splitRegions : { + name:"splitRegions" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +extract_data_region_tiles : { + name:"extract_data_region_tiles" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +split_learning_val : { + name:"split_learning_val" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +split_learning_val_sub : { + name:"split_learning_val_sub" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +split_samples : { + name:"split_samples" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesFormatting : { + name:"samples_formatting" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesMerge: { + name:"samples_merges" + nb_cpu:2 + ram:"10gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesStatistics : { + 
name:"samples_stats" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +samplesSelection : { + name:"samples_selection" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +samplesselection_tiles : { + name:"samplesSelection_tiles" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +samplesManagement : { + name:"samples_management" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +vectorSampler : { + name:"vectorSampler" + nb_cpu:40 + ram:"180gb" + walltime:"10:00:00" + process_min : 1 + } + +dimensionalityReduction : { + name:"dimensionalityReduction" + nb_cpu:3 + nb_MPI_process:2 + ram:"15gb" + nb_chunk:1 + walltime:"00:20:00"} + +samplesAugmentation : { + name:"samplesAugmentation" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + } + +mergeSample : { + name:"mergeSample" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +stats_by_models : { + name:"stats_by_models" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +training : { + name:"training" + nb_cpu:40 + ram:"180gb" + walltime:"10:00:00" + process_min : 1 + } + +#training regression +TrainRegressionPytorch : { + name:"training" + nb_cpu:40 + ram:"180gb" + walltime:"30:00:00" + } + +PredictRegressionScikit: { + name:"prediction" + nb_cpu:40 + ram:"180gb" + walltime:"30:00:00" + } + +#training regression +PredictRegressionPytorch : { + name:"choice" + nb_cpu:40 + ram:"180gb" + walltime:"30:00:00" + } + + +cmdClassifications : { + name:"cmdClassifications" + nb_cpu:10 + ram:"60gb" + walltime:"10:00:00" + process_min : 1 + } + +classifications : { + name:"classifications" + nb_cpu:10 + ram:"60gb" + walltime:"10:00:00" + process_min : 1 + } + +classifShaping : { + name:"classifShaping" + nb_cpu:10 + ram:"60gb" + walltime:"10:00:00" + process_min : 1 + } + +gen_confusionMatrix : { + name:"genCmdconfusionMatrix" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +confusionMatrix : { + name:"confusionMatrix" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +#regression +GenerateRegressionMetrics : { + name:"classification" + nb_cpu:10 + ram:"50gb" + walltime:"40:00:00" + } + + +fusion : { + name:"fusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +noData : { + name:"noData" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +statsReport : { + name:"statsReport" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +confusionMatrixFusion : { + name:"confusionMatrixFusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +merge_final_classifications : { + name:"merge_final_classifications" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + } + +reportGen : { + name:"reportGeneration" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +mergeOutStats : { + name:"mergeOutStats" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +SAROptConfusionMatrix : { + name:"SAROptConfusionMatrix" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +SAROptConfusionMatrixFusion : { + name:"SAROptConfusionMatrixFusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } + +SAROptFusion : { + name:"SAROptFusion" + nb_cpu:3 + ram:"15gb" + walltime:"00:10:00" + process_min : 1 + } +clump : { + name:"clump" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +grid : { + name:"grid" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + 
+vectorisation : { + name:"vectorisation" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +crownsearch : { + name:"crownsearch" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +crownbuild : { + name:"crownbuild" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1} + +statistics : { + name:"statistics" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +join : { + name:"join" + nb_cpu:1 + ram:"5gb" + walltime:"00:10:00" + process_min:1 + } + +tiler : { + name:"tiler" + nb_cpu:8 + ram:"120gb" + walltime:"01:00:00" + process_min:1 + +} +features_tiler : { + name:"features_tiler" + nb_cpu:8 + ram:"120gb" + walltime:"01:00:00" + process_min:1 + +} diff --git a/configs/iota2/config_sar.cfg b/configs/iota2/config_sar.cfg new file mode 100644 index 0000000000000000000000000000000000000000..07512be098cb3c53f0bef5d62cabb56d8691d435 --- /dev/null +++ b/configs/iota2/config_sar.cfg @@ -0,0 +1,19 @@ +[Paths] +output = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_preprocess +s1images = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake/ +srtm = /datalake/static_aux/MNT/SRTM_30_hgt +geoidfile = /home/uz/vinascj/src/MMDC/mmdc-data/Geoid/egm96.grd + +[Processing] +tiles = 31TCJ +tilesshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/mgrs_tiles.shp +referencesfolder = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/ +srtmshapefile = /work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/srtm.shp +rasterpattern = T31TCJ.tif +gapFilling_interpolation = spline +temporalresolution = 10 +borderthreshold = 1e-3 +ramperprocess = 5000 + +[Filtering] +window_radius = 2 diff --git a/configs/iota2/externalfeatures_with_userfeatures.cfg b/configs/iota2/externalfeatures_with_userfeatures.cfg new file mode 100644 index 0000000000000000000000000000000000000000..b6eb0faba181ae73ec023d94019fcc3a0063c74d --- /dev/null +++ b/configs/iota2/externalfeatures_with_userfeatures.cfg @@ -0,0 +1,24 @@ +# iota2 configuration file to launch iota2 as a tiler of features. +# paths with 'XXX' must be replaced by local user paths. + + +chain : +{ + output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' + + first_step : 'tiler' + last_step : 'tiler' + + proj : 'EPSG:2154' + # rasters_grid_path : '/XXX/grid_raster' + grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid + list_tile : 'T31TCJ' # T31TCK T31TDJ' + features_path : '/XXX/non_tiled_features' + + spatial_resolution : 10 # mandatory but not used in this case. The output spatial resolution will be the one found in 'rasters_grid_path' directory. 
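+    # Note: across the configs in this repository the target tiling comes either
+    # from 'grid' (the MGRS tiles shapefile used here and, with 'tile_field', in
+    # i2_grid.cfg) or from 'rasters_grid_path' (a directory of already tiled
+    # reference rasters, as in i2_tiler.cfg and iota2_grid_full.cfg).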
+} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} diff --git a/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg new file mode 100644 index 0000000000000000000000000000000000000000..71bc61d1d25887b551723e21b110a3afe6bb05db --- /dev/null +++ b/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg @@ -0,0 +1,77 @@ +chain : +{ + output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/results/external_user_features' + remove_output_path : False + check_inputs : False + + nomenclature_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/nomenclature_grosse.txt' + color_table : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/colorFile_grosse.txt' + + list_tile : 'T31TCJ' + ground_truth : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/vector_data/S2_50x50.shp' + data_field : 'groscode' + s2_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_symlink' + # user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/mnt_50x50' + user_feat_path: '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' + # s2_output_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/raster_data/S2_2dates_50x50_output' + s1_path : '/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/config/SAR_test_T31TCJ.cfg' + + srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' + worldclim_path:'/datalake/static_aux/worldclim-2.0' + grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid + tile_field : 'Name' + + spatial_resolution : 10 + first_step : 'init' + last_step : 'validation' + proj : 'EPSG:2154' +} +userFeat: +{ + arbo:"/*" + patterns:"aspect,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending,slope" +} +python_data_managing : +{ + padding_size_x : 1 + padding_size_y : 1 + chunk_size_mode:"split_number" + number_of_chunks:2 + data_mode_access: "both" +} + +external_features : +{ + module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py" + functions : [['test_mnt', {"arg1":1}]] + # functions : [['test_s2_concat', {"arg1":1}]] + concat_mode : True + external_features_flag:True +} +arg_classification: +{ + enable_probability_map:True + generate_final_probability_map:True +} +arg_train : +{ + + runs : 1 + random_seed : 1 + classifier:"sharkrf" + + sample_selection : + { + sampler : 'random' + strategy : 'percent' + 'strategy.percent.p' : 1.0 + ram : 4000 + } +} + +task_retry_limits : +{ + allowed_retry : 0 + maximum_ram : 180.0 + maximum_cpu : 40 +} diff --git a/configs/iota2/i2_classification.cfg b/configs/iota2/i2_classification.cfg new file mode 100644 index 0000000000000000000000000000000000000000..71c1b7284e9759e95b520311ec2e0c3beb72c2d9 --- /dev/null +++ b/configs/iota2/i2_classification.cfg @@ -0,0 +1,61 @@ +chain : +{ + output_path : '/work/scratch/sowm/IOTA2_TEST_S2/IOTA2_Outputs/Results_classif' + remove_output_path : True + nomenclature_path : '/work/scratch/sowm/IOTA2_TEST_S2/nomenclature23.txt' + list_tile : 'T31TCJ T31TDJ' + s2_path : '/work/scratch/sowm/IOTA2_TEST_S2/sensor_data' + s1_path : '/work/scratch/sowm/test-data-cube/iota2/s1_full.cfg' + ground_truth : '/work/scratch/sowm/IOTA2_TEST_S2/vector_data/reference_data.shp' + data_field : 
'code' + spatial_resolution : 10 + color_table : '/work/scratch/sowm/IOTA2_TEST_S2/colorFile.txt' + proj : 'EPSG:2154' + first_step: "init" + last_step: "validation" +} +sensors_data_interpolation: +{ + write_outputs: False +} + +external_features: +{ + + module: "/home/qt/sowm/scratch/test-data-cube/iota2/external_iota2_code.py" + functions: [["apply_convae", + { "checkpoint_path": "/home/qt/sowm/scratch/results/Convae_split_roi_psz_128_ep_150_bs_16_lr_1e-05_s1_[64, 128, 256, 512, 1024]_[8]_s2_[64, 128, 256, 512, 1024]_[8]_21916652.admin01", + "checkpoint_epoch": 100, + "patch_size": 74 + }]] + concat_mode: True +} + +python_data_managing: +{ + chunk_size_mode: "split_number" + number_of_chunks: 20 + padding_size_y: 27 + +} + +arg_train : +{ + classifier : 'sharkrf' + otb_classifier_options : {"classifier.sharkrf.ntrees" : 100 } + sample_selection: {"sampler": "random", + "strategy": "percent", + "strategy.percent.p": 0.1} + random_seed: 42 + +} +arg_classification : +{ + classif_mode : 'separate' +} +task_retry_limits : +{ + allowed_retry : 0 + maximum_ram : 180.0 + maximum_cpu : 40 +} diff --git a/configs/iota2/i2_grid.cfg b/configs/iota2/i2_grid.cfg new file mode 100644 index 0000000000000000000000000000000000000000..e13f87883d2ddfe06eb41e448d2a53c83b2911d9 --- /dev/null +++ b/configs/iota2/i2_grid.cfg @@ -0,0 +1,19 @@ +chain : +{ + output_path : '/work/scratch/vinascj/MMDC/iota2/grid_test' + spatial_resolution : 100 # output target resolution + first_step : 'tiler' # do not change + last_step : 'tiler' # do not change + + proj : 'EPSG:2154' # output target crs + grid : '/home/uz/vinascj/src/MMDC/mmdc-singledate/thirdparties/sensorsio/src/sensorsio/data/sentinel2/mgrs_tiles.shp' #'/XXXX/Features.shp' # MGRS file providing tiles grid + tile_field : 'Name' + list_tile : '31TCK 31TCJ' # list of tiles in grid to build + srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' + worldclim_path:'/datalake/static_aux/worldclim-2.0' +} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} diff --git a/configs/iota2/i2_tiler.cfg b/configs/iota2/i2_tiler.cfg new file mode 100644 index 0000000000000000000000000000000000000000..193d757d7bf28c475f36cf673d4cb32c8ff0c2a1 --- /dev/null +++ b/configs/iota2/i2_tiler.cfg @@ -0,0 +1,22 @@ +chain : +{ + output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data' + spatial_resolution : 10 + first_step : 'tiler' + last_step : 'tiler' + + proj : 'EPSG:2154' + rasters_grid_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_reference/T31TCJ' + tile_field : 'Name' + list_tile : 'T31TCJ' + srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' + worldclim_path:'/datalake/static_aux/worldclim-2.0' + s2_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s2_data_datalake' + # add S1 data + s1_dir:"/work/CESBIO/projects/MAESTRIA/iota2_mmdc/s1_data_datalake" +} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} diff --git a/configs/iota2/iota2_grid_full.cfg b/configs/iota2/iota2_grid_full.cfg new file mode 100644 index 0000000000000000000000000000000000000000..d5379d4997ca4edf95e8ea310d76361999f93483 --- /dev/null +++ b/configs/iota2/iota2_grid_full.cfg @@ -0,0 +1,21 @@ +chain : +{ + output_path : '/home/uz/vinascj/scratch/MMDC/iota2/grid_test_full' + spatial_resolution : 10 + first_step : 'tiler' + last_step : 'tiler' + + proj : 'EPSG:2154' + rasters_grid_path : '/work/OT/theia/oso/arthur/TMP/raster_grid' + tile_field : 'Name' + list_tile : 'T31TCJ' + srtm_path:'/datalake/static_aux/MNT/SRTM_30_hgt' + 
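+    # srtm_path (above) and worldclim_path (below) point to the static auxiliary
+    # rasters cut to the tile grid by this run; they correspond to the
+    # elevation/aspect/slope and worldclim entries of the userFeat patterns in
+    # iota2_mmdc_full.cfg.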
worldclim_path:'/datalake/static_aux/worldclim-2.0' + s2_path:'/work/OT/theia/oso/arthur/TMP/test_s2_angles' + s1_dir:'/work/OT/theia/oso/arthur/s1_emma/all_tiles' +} + +builders: +{ +builders_class_name : ["i2_features_to_grid"] +} diff --git a/configs/iota2/iota2_mmdc_full.cfg b/configs/iota2/iota2_mmdc_full.cfg new file mode 100644 index 0000000000000000000000000000000000000000..fb651a84712caf32e05b65e6a5b196aca2e2357a --- /dev/null +++ b/configs/iota2/iota2_mmdc_full.cfg @@ -0,0 +1,76 @@ +# Entry Point Iota2 Config +chain : +{ + # Input and Output data + output_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/iota2_results' + remove_output_path : False + check_inputs : False + list_tile : 'T31TCJ' + data_field : 'code' + s2_path : '/work/scratch/vinascj/MMDC/iota2/i2_training_data/s2_data_full' + ground_truth : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/vector_db/reference_data.shp' + + # Add S1 data + s1_path : "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_sar.cfg" + + # Classification related parameters + spatial_resolution : 10 + color_table : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/colorFile.txt' + nomenclature_path : '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/nomenclature.txt' + first_step : 'init' + last_step : 'validation' + proj : 'EPSG:2154' + + # Path to the User data + user_feat_path: '/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data' #/work/CESBIO/projects/MAESTRIA/iota2_mmdc/tiled_auxilar_data/T31TCJ +} + +# Parametrize the classifier +arg_train : +{ + random_seed : 0 # Set the random seed to split the ground truth + runs : 1 + classifier : 'sharkrf' + otb_classifier_options : {'classifier.sharkrf.nodesize': 25} + sample_selection : + { + sampler : 'random' + strategy : 'percent' + 'strategy.percent.p' : 0.1 + } +} + +# Use User Provided data +userFeat: +{ + arbo:"/*" + patterns:"elevation,aspect,slope,worldclim,EVEN_AZIMUTH,ODD_ZENITH,s1_incidence_ascending,worldclim,AZIMUTH,EVEN_ZENITH,s1_azimuth_ascending,s1_incidence_descending,ZENITH,elevation,ODD_AZIMUTH,s1_azimuth_descending," +} + +# Parameters for the user provided data +python_data_managing : +{ + padding_size_x : 1 + padding_size_y : 1 + chunk_size_mode:"split_number" + number_of_chunks:4 + data_mode_access: "both" +} + +# Process the User and Iota2 data +external_features : +{ + # module:"/home/uz/vinascj/scratch/MMDC/iota2/juan_i2_data/inputs/other/soi.py" + module:"/home/uz/vinascj/src/MMDC/mmdc-singledate/src/mmdc_singledate/inference/mmdc_iota2_inference.py" + functions : [['test_mnt', {"arg1":1}]] + concat_mode : True + external_features_flag:True +} + +# Parametrize nodes +task_retry_limits : +{ + allowed_retry : 0 + maximum_ram : 180.0 + maximum_cpu : 40 +} diff --git a/create-conda-env.sh b/create-conda-env.sh index 36d9e095ac741a8f583be21825b83183102951ce..61f3fc75559987423d0242d841d339532bea7f91 100755 --- a/create-conda-env.sh +++ b/create-conda-env.sh @@ -51,6 +51,8 @@ mamba install --yes "pytorch>=2.0.0.dev202302=py3.10_cuda11.7_cudnn8.5.0_0" "tor conda deactivate conda activate $target +# proxy +source ~/set_proxy_iota2.sh # Requirements pip install -r requirements-mmdc-sgld-test.txt diff --git a/create-iota2-env.sh b/create-iota2-env.sh new file mode 100644 index 0000000000000000000000000000000000000000..5f3b99ee1a40f098a0648e34b6892558a7dbde8c --- /dev/null +++ b/create-iota2-env.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash + +export python_version="3.9" +export name="mmdc-iota2" +if ! 
[ -z "$1" ] +then + export name=$1 +fi + +source ~/set_proxy_iota2.sh +if [ -z "$https_proxy" ] +then + echo "Please set https_proxy environment variable before running this script" + exit 1 +fi + +export target=/work/scratch/$USER/virtualenv/$name + +if ! [ -z "$2" ] +then + export target="$2/$name" +fi + +echo "Installing $name in $target ..." + +if [ -d "$target" ]; then + echo "Cleaning previous conda env" + rm -rf $target +fi + +# Create blank virtualenv +module purge +module load conda +module load gcc +conda activate +conda create --yes --prefix $target python==${python_version} pip + +# Enter virtualenv +conda deactivate +conda activate $target + +# install mamba +conda install mamba + +which python +python --version + +# install libsvm +mamba install -c conda-forge libsvm=325 + +# Install iota2 +#mamba install iota2_develop=257da617 -c iota2 +mamba install iota2 -c iota2 + +# clone the lastest version of iota2 +rm -rf iota2_thirdparties/iota2 +git clone -b issue#600/tile_exogenous_features https://framagit.org/iota2-project/iota2.git iota2_thirdparties/iota2 +cd iota2_thirdparties/iota2 +git checkout issue#600/tile_exogenous_features +git pull origin issue#600/tile_exogenous_features +cd ../../ +pwd + +# create a backup +cd /home/uz/$USER/scratch/virtualenv/mmdc-iota2/lib/python3.9/site-packages/iota2-0.0.0-py3.9.egg/ +mv iota2 iota2_backup + +# create a symbolic link +ln -s ~/src/MMDC/mmdc-singledate/iota2_thirdparties/iota2/iota2/ iota2 + +cd ~/src/MMDC/mmdc-singledate + +mamba install -c conda-forge pydantic + +# install missing dependancies +pip install -r requirements-mmdc-iota2.txt + +# Install sensorsio +rm -rf iota2_thirdparties/sensorsio +git clone -b handle_local_srtm https://framagit.org/iota2-project/sensorsio.git iota2_thirdparties/sensorsio +pip install iota2_thirdparties/sensorsio + +# Install torchutils +rm -rf iota2_thirdparties/torchutils +git clone https://src.koda.cnrs.fr/mmdc/torchutils.git iota2_thirdparties/torchutils +pip install iota2_thirdparties/torchutils + +# Install the current project in edit mode +pip install -e .[testing] + +# End +conda deactivate diff --git a/iota2_init.sh b/iota2_init.sh new file mode 100644 index 0000000000000000000000000000000000000000..4d1f386fd738ade3e88092fc7c9009e21d9af48a --- /dev/null +++ b/iota2_init.sh @@ -0,0 +1,3 @@ +module purge +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 diff --git a/jobs/inference_single_date.pbs b/jobs/inference_single_date.pbs new file mode 100644 index 0000000000000000000000000000000000000000..1374be032667395ef6155e5acae186735a34b252 --- /dev/null +++ b/jobs/inference_single_date.pbs @@ -0,0 +1,35 @@ +#!/bin/bash +#PBS -N infer_sgdt +#PBS -q qgpgpu +#PBS -l select=1:ncpus=8:mem=92G:ngpus=1 +#PBS -l walltime=1:00:00 + +# be sure no modules loaded +module purge + +export SRCDIR=${HOME}/src/MMDC/mmdc-singledate +export MMDC_INIT=${SRCDIR}/mmdc_init.sh +export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs +export DATA_DIR=/work/CESBIO/projects/MAESTRIA/test_onetile/total/ +export SCRATCH_DIR=/work/scratch/${USER}/ + +cd ${SRCDIR} + +mkdir ${SCRATCH_DIR}/MMDC/inference/ +mkdir ${SCRATCH_DIR}/MMDC/inference/singledate + +source ${MMDC_INIT} + +python ${SRCDIR}/src/bin/inference_mmdc_singledate.py \ + --s2_filename ${DATA_DIR}/T31TCJ/SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif \ + --s1_asc_filename ${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif \ + --s1_desc_filename 
${DATA_DIR}/T31TCJ/S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif \ + --srtm_filename ${DATA_DIR}/T31TCJ/srtm_T31TCJ_roi_0.tif \ + --worldclim_filename ${DATA_DIR}/T31TCJ/wc_clim_1_T31TCJ_roi_0.tif \ + --export_path ${SCRATCH_DIR}/MMDC/inference/singledate \ + --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \ + --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \ + --latent_space_size 16 \ + --patch_size 128 \ + --nb_lines 256 \ + --sensors S2L2A S1FULL \ diff --git a/jobs/inference_time_serie.pbs b/jobs/inference_time_serie.pbs new file mode 100644 index 0000000000000000000000000000000000000000..0aaf57e27a722be7a9463c1067a8425e7c5f53ee --- /dev/null +++ b/jobs/inference_time_serie.pbs @@ -0,0 +1,32 @@ +#!/bin/bash +#PBS -N infer_sgdt +#PBS -q qgpgpu +#PBS -l select=1:ncpus=8:mem=92G:ngpus=1 +#PBS -l walltime=2:00:00 + +# be sure no modules loaded +module purge + +export SRCDIR=${HOME}/src/MMDC/mmdc-singledate +export MMDC_INIT=${SRCDIR}/mmdc_init.sh +export WORKING_DIR=/work/scratch/${USER}/MMDC/jobs +export DATA_DIR=/work/CESBIO/projects/MAESTRIA/training_dataset2/ +export SCRATCH_DIR=/work/scratch/${USER}/ + + +mkdir ${SCRATCH_DIR}/MMDC/inference/ +mkdir ${SCRATCH_DIR}/MMDC/inference/time_serie + + +cd ${SRCDIR} +source ${MMDC_INIT} + +python ${SRCDIR}/src/bin/inference_mmdc_timeserie.py \ + --input_path ${DATA_DIR} \ + --export_path ${SCRATCH_DIR}/MMDC/inference/time_serie \ + --tile_list T33TUL \ + --model_checkpoint_path /home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt \ + --model_config_path /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml \ + --days_gap 0 \ + --patch_size 128 \ + --nb_lines 256 \ diff --git a/jobs/iota2_aux_angles_test.pbs b/jobs/iota2_aux_angles_test.pbs new file mode 100644 index 0000000000000000000000000000000000000000..6eaacfe2b80aa45250bccc366d768aae11968d32 --- /dev/null +++ b/jobs/iota2_aux_angles_test.pbs @@ -0,0 +1,20 @@ +#!/bin/bash +#PBS -N iota2-test +#PBS -l select=1:ncpus=8:mem=12G +#PBS -l walltime=1:00:00 + +# be sure no modules loaded +conda deactivate +module purge + +# load modules +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 + + +Iota2.py -scheduler_type debug \ + -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_angles.cfg + #-nb_parallel_tasks 1 \ + #-only_summary + +# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \ diff --git a/jobs/iota2_aux_full.pbs b/jobs/iota2_aux_full.pbs new file mode 100644 index 0000000000000000000000000000000000000000..2e438ed230ab02a57c2423491c4f0b270b4215b7 --- /dev/null +++ b/jobs/iota2_aux_full.pbs @@ -0,0 +1,20 @@ +#!/bin/bash +#PBS -N iota2-test +#PBS -l select=1:ncpus=8:mem=12G +#PBS -l walltime=1:00:00 + +# be sure no modules loaded +conda deactivate +module purge + +# load modules +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 + + +Iota2.py -scheduler_type debug \ + -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_grid_full.cfg + #-nb_parallel_tasks 1 \ + #-only_summary + +# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \ diff --git a/jobs/iota2_aux_test.pbs b/jobs/iota2_aux_test.pbs new file mode 100644 index 
0000000000000000000000000000000000000000..3b3eed36098a0dadbac5933a6100d5321cdd6b7c --- /dev/null +++ b/jobs/iota2_aux_test.pbs @@ -0,0 +1,20 @@ +#!/bin/bash +#PBS -N iota2-test +#PBS -l select=1:ncpus=8:mem=12G +#PBS -l walltime=1:00:00 + +# be sure no modules loaded +conda deactivate +module purge + +# load modules +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 + + +Iota2.py -scheduler_type debug \ + -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_grid.cfg + #-nb_parallel_tasks 1 \ + #-only_summary + +# /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \ diff --git a/jobs/iota2_clasif_mmdc_full.pbs b/jobs/iota2_clasif_mmdc_full.pbs new file mode 100644 index 0000000000000000000000000000000000000000..e3ec1e9731521f71acb8957c61d5e4f6c3840d5b --- /dev/null +++ b/jobs/iota2_clasif_mmdc_full.pbs @@ -0,0 +1,120 @@ +#!/bin/bash +#PBS -N iota2-class +#PBS -l select=1:ncpus=8:mem=12G +#PBS -l walltime=10:00:00 + +# be sure no modules loaded +conda deactivate +module purge + + +ulimit -u 5000 + +# export ROOT_PATH="$(dirname $( cd "$(dirname "${BASH_SOURCE[0]}")"; pwd -P ))" +# export OTB_APPLICATION_PATH=$OTB_APPLICATION_PATH:/work/scratch/$USER/virtualenv/mmdc-iota2/lib/otb/applications// +# export PATH=$PATH:$ROOT_PATH/bin:$OTB_APPLICATION_PATH +# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$ROOT_PATH/lib:$OTB_APPLICATION_PATH # +# +# load modules +module load conda/4.12.0 +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 + +# Apply Iota2 +Iota2.py \ + -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/iota2_mmdc_full.cfg \ + -scheduler_type PBS \ + -config_ressources /home/uz/vinascj/src/MMDC/mmdc-singledate/configs/iota2/config_ressources.cfg \ + # -scheduler_type debug \ + # -only_summary + # -nb_parallel_tasks 1 \ + + +# New +# Functions availables +#['__class__', +#'__delattr__', +#'__dict__', +#'__dir__', +#'__doc__', +#'__eq__', +#'__format__', +#'__ge__', +#'__getattribute__', +#'__gt__', +#'__hash__', +#'__init__', +#'__init_subclass__', +#'__le__', +#'__lt__', +#'__module__', +#'__ne__', +#'__new__', +#'__reduce__', +#'__reduce_ex__', +#'__repr__', +#'__setattr__', +#'__sizeof__', +#'__str__', +#'__subclasshook__', +#'__weakref__', +#'all_dates', +#'allow_nans', +#'binary_masks', +#'concat_mode', +#'dim_ts', +#'enabled_gap', +#'exogeneous_data_array', +#'exogeneous_data_name', +#'external_functions', +#'fill_missing_dates', +#'func_param', +#'get_Sentinel1_ASC_vh_binary_masks', +#'get_Sentinel1_ASC_vv_binary_masks', +#'get_Sentinel1_DES_vh_binary_masks', +#'get_Sentinel1_DES_vv_binary_masks', +#'get_Sentinel2_binary_masks', +#'get_filled_masks', +#'get_filled_stack', +#'get_interpolated_Sentinel1_ASC_vh', +#'get_interpolated_Sentinel1_ASC_vv', +#'get_interpolated_Sentinel1_DES_vh', +#'get_interpolated_Sentinel1_DES_vv', +#'get_interpolated_Sentinel2_B11', +#'get_interpolated_Sentinel2_B12', +#'get_interpolated_Sentinel2_B2', +#'get_interpolated_Sentinel2_B3', +#'get_interpolated_Sentinel2_B4', +#'get_interpolated_Sentinel2_B5', +#'get_interpolated_Sentinel2_B6', +#'get_interpolated_Sentinel2_B7', +#'get_interpolated_Sentinel2_B8', +#'get_interpolated_Sentinel2_B8A', +#'get_interpolated_Sentinel2_Brightness', +#'get_interpolated_Sentinel2_NDVI', +#'get_interpolated_Sentinel2_NDWI', +#'get_interpolated_dates', +#'get_interpolated_userFeatures_aspect', +#'get_raw_Sentinel1_ASC_vh', +#'get_raw_Sentinel1_ASC_vv', +#'get_raw_Sentinel1_DES_vh', 'get_raw_Sentinel1_DES_vv', 
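+#
+# A minimal, hypothetical sketch of an external feature function using the
+# accessors listed here (wired in through the 'external_features' block of the
+# .cfg files: module + functions + concat_mode). The function name, the kwarg
+# and the (array, band-labels) return convention are assumptions for
+# illustration, not code from this repository:
+#
+#   def simple_ratio(self, arg1=1):
+#       """Toy feature: NIR / red simple ratio from interpolated S2 bands."""
+#       red = self.get_interpolated_Sentinel2_B4()
+#       nir = self.get_interpolated_Sentinel2_B8()
+#       ratio = nir / (red + 1e-6)          # one band per interpolated date
+#       labels = [f"sr_{d}" for d in self.get_interpolated_dates()]
+#       return ratio, labels
+#
+# In the configuration it would be declared as
+#   functions : [['simple_ratio', {"arg1": 1}]]
+# and, with concat_mode : True, its bands are presumably appended to the
+# standard interpolated feature stack. (The remaining accessors continue below.)
+#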
+#'get_raw_Sentinel2_B11', +#'get_raw_Sentinel2_B12', +#'get_raw_Sentinel2_B2', +#'get_raw_Sentinel2_B3', +#'get_raw_Sentinel2_B4', +#'get_raw_Sentinel2_B5', +#'get_raw_Sentinel2_B6', +#'get_raw_Sentinel2_B7', +#'get_raw_Sentinel2_B8', +#'get_raw_Sentinel2_B8A', +#'get_raw_dates', +#'get_raw_userFeatures_aspect', +#'interpolated_data', +#'interpolated_dates', +#'missing_masks_values', +#'missing_refl_values', +#'out_data', +#'process', +#'raw_data', +#'raw_dates', +#'test_user_feature_with_fake_data'] diff --git a/jobs/iota2_external_feature_test.pbs b/jobs/iota2_external_feature_test.pbs new file mode 100644 index 0000000000000000000000000000000000000000..f8b6953c9eccb6981125153d07ae459d9d51a2e0 --- /dev/null +++ b/jobs/iota2_external_feature_test.pbs @@ -0,0 +1,45 @@ +#!/bin/bash +#PBS -N iota2-test +#PBS -l select=1:ncpus=8:mem=12G +#PBS -l walltime=1:00:00 + +# be sure no modules loaded +conda deactivate +module purge + +# load modules +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 +# conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2-develop + + +Iota2.py -scheduler_type debug \ + -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/externalfeatures_with_userfeatures_50x50.cfg \ + -nb_parallel_tasks 1 \ + # -only_summary + + + +# Functions availables +# ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', +# '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', +# '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', +# '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', +# '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', +# '__weakref__', 'all_dates', 'allow_nans', 'binary_masks', 'concat_mode', +# 'dim_ts', 'enabled_gap', 'exogeneous_data_array', 'exogeneous_data_name', +# 'external_functions', 'fill_missing_dates', 'func_param', +# 'get_Sentinel2_binary_masks', 'get_filled_masks', 'get_filled_stack', +# 'get_interpolated_Sentinel2_B11', 'get_interpolated_Sentinel2_B12', +# 'get_interpolated_Sentinel2_B2', 'get_interpolated_Sentinel2_B3', +# 'get_interpolated_Sentinel2_B4', 'get_interpolated_Sentinel2_B5', +# 'get_interpolated_Sentinel2_B6', 'get_interpolated_Sentinel2_B7', +# 'get_interpolated_Sentinel2_B8', 'get_interpolated_Sentinel2_B8A', +# 'get_interpolated_Sentinel2_Brightness', 'get_interpolated_Sentinel2_NDVI', +# 'get_interpolated_Sentinel2_NDWI', 'get_interpolated_dates', 'get_raw_Sentinel2_B11', +# 'get_raw_Sentinel2_B12', 'get_raw_Sentinel2_B2', 'get_raw_Sentinel2_B3', +# 'get_raw_Sentinel2_B4', 'get_raw_Sentinel2_B5', 'get_raw_Sentinel2_B6', +# 'get_raw_Sentinel2_B7', 'get_raw_Sentinel2_B8', 'get_raw_Sentinel2_B8A', +# 'get_raw_dates', 'interpolated_data', 'interpolated_dates', 'missing_masks_values', +# 'missing_refl_values', 'out_data', 'process', 'raw_data', 'raw_dates', +# 'test_user_feature_with_fake_data'] diff --git a/jobs/iota2_tile_aux_features.pbs b/jobs/iota2_tile_aux_features.pbs new file mode 100644 index 0000000000000000000000000000000000000000..2507bf3622735f451fd4f5ddca3bcf957378f1b6 --- /dev/null +++ b/jobs/iota2_tile_aux_features.pbs @@ -0,0 +1,18 @@ +#!/bin/bash +#PBS -N iota2-tiler +#PBS -l select=1:ncpus=8:mem=120G +#PBS -l walltime=4:00:00 + +# be sure no modules loaded +conda deactivate +module purge + +# load modules +module load conda +conda activate /work/scratch/${USER}/virtualenv/mmdc-iota2 + + +Iota2.py -scheduler_type debug \ + -config ${HOME}/src/MMDC/mmdc-singledate/configs/iota2/i2_tiler.cfg + -nb_parallel_tasks 1 \ + 
-only_summary diff --git a/requirements-mmdc-iota2.txt b/requirements-mmdc-iota2.txt new file mode 100644 index 0000000000000000000000000000000000000000..619064374381c3f90897b1041ece6b7f67ac77eb --- /dev/null +++ b/requirements-mmdc-iota2.txt @@ -0,0 +1,3 @@ +config==0.5.1 +itk +pydantic diff --git a/requirements-mmdc-sgld.txt b/requirements-mmdc-sgld.txt index 02defbac5b66055577c13f6a2bdbb69a8d055c06..c4b218d12589530ee4bc2a8bb5323b799f95346e 100644 --- a/requirements-mmdc-sgld.txt +++ b/requirements-mmdc-sgld.txt @@ -1,6 +1,7 @@ affine auto-walrus black # code formatting +config findpeaks flake8 # code analysis geopandas @@ -18,6 +19,8 @@ numpy pandas pre-commit # hooks for applying linters on commit pudb # debugger +pydantic # validate configurations in iota2 +pydantic-yaml # read yml configs for construct the networks pytest # tests python-dotenv # loading env variables from .env file python-lsp-server[all] diff --git a/setup.cfg b/setup.cfg index 9c5469ae7b2d4c444e78501d06a55285fa81dc5e..3cce8fa609ddec17b800acd1c9e5742fad738506 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,6 +68,14 @@ testing = pytest-cov [options.entry_points] +console_scripts = + mmdc_generate_dataset = bin.generate_mmdc_dataset:main + mmdc_visualize = bin.visualize_mmdc_ds:main + mmdc_split_dataset = bin.split_mmdc_dataset:main + mmdc_inference_singledate = bin.inference_mmdc_singledate:main + mmdc_inference_time_serie = bin.inference_mmdc_timeserie:main + +# mmdc_inference = bin.infer_mmdc_ds:main # Add here console scripts like: # console_scripts = # script_name = mmdc-singledate.module:function diff --git a/src/bin/generate_mmdc_dataset.py b/src/bin/generate_mmdc_dataset.py index d801c56694537e2f86e830bfb7a10ba024b524c8..ada7c11b64e258f1b8f3a9cacd0c0b683b0eb1d6 100644 --- a/src/bin/generate_mmdc_dataset.py +++ b/src/bin/generate_mmdc_dataset.py @@ -81,7 +81,7 @@ def get_parser() -> argparse.ArgumentParser: return arg_parser -if __name__ == "__main__": +def main(): # Parser arguments parser = get_parser() args = parser.parse_args() @@ -124,3 +124,7 @@ if __name__ == "__main__": args.threshold, ), ) + + +if __name__ == "__main__": + main() diff --git a/src/bin/inference_mmdc_singledate.py b/src/bin/inference_mmdc_singledate.py new file mode 100644 index 0000000000000000000000000000000000000000..5df740b3fb860dcfbd91be4e7d79581306f888ff --- /dev/null +++ b/src/bin/inference_mmdc_singledate.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Inference code for MMDC +""" + + +# imports +import argparse +import logging +import os + +import torch + +from mmdc_singledate.inference.components.inference_components import ( + get_mmdc_full_model, + predict_mmdc_model, +) +from mmdc_singledate.inference.components.inference_utils import GeoTiffDataset +from mmdc_singledate.inference.mmdc_tile_inference import ( + MMDCProcess, + predict_single_date_tile, +) +from mmdc_singledate.inference.utils import get_scales + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate argument parser for CLI + """ + arg_parser = argparse.ArgumentParser( + os.path.basename(__file__), + description="Compute the inference a selected number of latent spaces " + " and exported as GTiff for a given tile", + ) + + arg_parser.add_argument( + "--loglevel", + default="INFO", + choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"), + help="Logger level (default: INFO. 
Should be one of " + "(DEBUG, INFO, WARNING, ERROR, CRITICAL)", + ) + + arg_parser.add_argument( + "--s2_filename", type=str, help="s2 input file", required=False + ) + + arg_parser.add_argument( + "--s1_asc_filename", type=str, help="s1 asc input file", required=False + ) + + arg_parser.add_argument( + "--s1_desc_filename", type=str, help="s1 desc input file", required=False + ) + + arg_parser.add_argument( + "--srtm_filename", type=str, help="srtm input file", required=True + ) + + arg_parser.add_argument( + "--worldclim_filename", type=str, help="worldclim input file", required=True + ) + + arg_parser.add_argument( + "--export_path", type=str, help="export path ", required=True + ) + + arg_parser.add_argument( + "--model_checkpoint_path", + type=str, + help="model checkpoint path ", + required=True, + ) + + arg_parser.add_argument( + "--model_config_path", + type=str, + help="model configuration path ", + required=True, + ) + + arg_parser.add_argument( + "--latent_space_size", + type=int, + help="Number of channels in the output file", + required=False, + ) + + arg_parser.add_argument( + "--nb_lines", + type=int, + help="Number of Line to read in each window reading", + required=False, + ) + + arg_parser.add_argument( + "--patch_size", + type=int, + help="Number of Line to read in each window reading", + required=True, + ) + + arg_parser.add_argument( + "--sensors", + dest="sensors", + nargs="+", + help=( + "List of sensors S2L2A | S1FULL | S1ASC | S1DESC", + " (Is sensible to the order S2L2A always firts)", + ), + required=True, + ) + + return arg_parser + + +def main(): + """ + Entry Point + """ + # Parser arguments + parser = get_parser() + args = parser.parse_args() + + # Configure logging + numeric_level = getattr(logging, args.loglevel.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError("Invalid log level: %s" % args.loglevel) + + logging.basicConfig( + level=numeric_level, + datefmt="%y-%m-%d %H:%M:%S", + format="%(asctime)s :: %(levelname)s :: %(message)s", + ) + + # device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # model + mmdc_full_model = get_mmdc_full_model( + checkpoint=args.model_checkpoint_path, + mmdc_full_config_path=args.model_config_path, + inference_tile=True, + ) + scales = get_scales() + mmdc_full_model.set_scales(scales) + mmdc_full_model.to(device) + # process class + mmdc_process = MMDCProcess( + count=args.latent_space_size, + nb_lines=args.nb_lines, + patch_size=args.patch_size, + process=predict_mmdc_model, + model=mmdc_full_model, + ) + + # create export filename + export_path = f"{args.export_path}/latent_singledate_{'_'.join(args.sensors)}.tif" + + # GeoTiff + geotiff_dataclass = GeoTiffDataset( + s2_filename=args.s2_filename, + s1_asc_filename=args.s1_asc_filename, + s1_desc_filename=args.s1_desc_filename, + srtm_filename=args.srtm_filename, + wc_filename=args.worldclim_filename, + ) + + predict_single_date_tile( + input_data=geotiff_dataclass, + export_path=export_path, + sensors=args.sensors, + process=mmdc_process, + ) + + +if __name__ == "__main__": + main() diff --git a/src/bin/inference_mmdc_timeserie.py b/src/bin/inference_mmdc_timeserie.py new file mode 100644 index 0000000000000000000000000000000000000000..4db518f2b3d1d7a4702e7522469f591160e3cb06 --- /dev/null +++ b/src/bin/inference_mmdc_timeserie.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Inference code for MMDC +""" + + +# imports +import argparse +import 
logging +import os +from pathlib import Path + +from mmdc_singledate.inference.mmdc_tile_inference import ( + inference_dataframe, + mmdc_tile_inference, +) + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate argument parser for CLI + """ + arg_parser = argparse.ArgumentParser( + os.path.basename(__file__), + description="Compute the inference a selected number of latent spaces " + " and exported as GTiff for a given tile", + ) + + arg_parser.add_argument( + "--loglevel", + default="INFO", + choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"), + help="Logger level (default: INFO. Should be one of " + "(DEBUG, INFO, WARNING, ERROR, CRITICAL)", + ) + + arg_parser.add_argument( + "--input_path", type=str, help="data input folder", required=True + ) + + arg_parser.add_argument( + "--export_path", type=str, help="export path ", required=True + ) + + arg_parser.add_argument( + "--days_gap", + type=int, + help="difference between S1 and S2 acquisitions in days", + required=True, + ) + + arg_parser.add_argument( + "--model_checkpoint_path", + type=str, + help="model checkpoint path ", + required=True, + ) + + arg_parser.add_argument( + "--model_config_path", + type=str, + help="model configuration path ", + required=True, + ) + + arg_parser.add_argument( + "--tile_list", + nargs="+", + type=str, + help="List of tiles to produce", + required=False, + ) + + arg_parser.add_argument( + "--patch_size", + type=int, + help="Sizes of the patches to make the inference", + required=False, + ) + + arg_parser.add_argument( + "--nb_lines", + type=int, + help="Number of Line to read in each window reading", + required=False, + ) + + return arg_parser + + +def main(): + """ + Entry Point + """ + # Parser arguments + parser = get_parser() + args = parser.parse_args() + + # Configure logging + numeric_level = getattr(logging, args.loglevel.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError("Invalid log level: %s" % args.loglevel) + + logging.basicConfig( + level=numeric_level, + datefmt="%y-%m-%d %H:%M:%S", + format="%(asctime)s :: %(levelname)s :: %(message)s", + ) + + if not os.path.isdir(args.export_path): + logging.info( + f"The directory is not present. Creating a new one.. 
" f"{args.export_path}" + ) + Path(args.export_path).mkdir() + + # calculate time serie + tile_df = inference_dataframe( + samples_dir=args.input_path, + input_tile_list=args.tile_list, + days_gap=args.days_gap, + ) + logging.info(f"nb of dates to infer : {tile_df.shape[0]}") + # entry point + mmdc_tile_inference( + inference_dataframe=tile_df, + export_path=Path(args.export_path), + model_checkpoint_path=Path(args.model_checkpoint_path), + model_config_path=Path(args.model_config_path), + patch_size=args.patch_size, + nb_lines=args.nb_lines, + ) + + logging.info("All the data has exported succesfully") + + +if __name__ == "__main__": + main() diff --git a/src/bin/split_mmdc_dataset.py b/src/bin/split_mmdc_dataset.py index de429bf5c5a539a2c76e8d0af1bc1a42401b7c10..b205a24971293c051f942134bfa3c3ede81d3f4c 100644 --- a/src/bin/split_mmdc_dataset.py +++ b/src/bin/split_mmdc_dataset.py @@ -141,16 +141,10 @@ def split_tile_rois( roi_intersect.writelines(sorted([f"{roi}\n" for roi in roi_list])) -def main( - input_dir: Path, input_file: Path, test_percentage: int, random_state: int -) -> None: +def main(): """ Split the tiles/rois between train/val/test """ - split_tile_rois(input_dir, input_file, test_percentage, random_state) - - -if __name__ == "__main__": # Parser arguments parser = get_parser() args = parser.parse_args() @@ -171,10 +165,14 @@ if __name__ == "__main__": logging.info("test percentage selected : %s", args.test_percentage) logging.info("random state value : %s", args.random_state) - # Go to main - main( + # Go to entry point + split_tile_rois( Path(args.tensors_dir), Path(args.roi_intersections_file), args.test_percentage, args.random_state, ) + + +if __name__ == "__main__": + main() diff --git a/src/bin/visualize_mmdc_ds.py b/src/bin/visualize_mmdc_ds.py index 01da86c662bac0e91ad3cf8c654e0ec3d39d4e38..a69bc924abb3550c0e6e5e263a5b14e5004f72fa 100644 --- a/src/bin/visualize_mmdc_ds.py +++ b/src/bin/visualize_mmdc_ds.py @@ -71,34 +71,10 @@ def get_parser() -> argparse.ArgumentParser: return arg_parser -def main( - input_directory: Path, - input_tiles: Path, - patch_index: list[int] | None, - nb_patches: int | None, - export_path: Path, -) -> None: +def main(): """ Entry point for the visualization """ - try: - export_visualization_graphs( - input_directory=input_directory, - input_tiles=input_tiles, - patch_index=patch_index, - nb_patches=nb_patches, - export_path=export_path, - ) - - except FileNotFoundError: - if not input_directory.exists(): - print(f"Folder {input_directory} does not exist!!") - - if not export_path.exists(): - print(f"Folder {export_path} does not exist!!") - - -if __name__ == "__main__": # Parser arguments parser = get_parser() args = parser.parse_args() @@ -126,13 +102,26 @@ if __name__ == "__main__": logging.info(" index patches : %s", args.patch_index) logging.info(" output directory : %s", args.export_path) - # Go to main - main( - input_directory=Path(args.input_path), - input_tiles=Path(args.input_tiles), - patch_index=args.patch_index, - nb_patches=args.nb_patches, - export_path=Path(args.export_path), - ) + # Go to entry point + + try: + export_visualization_graphs( + input_directory=Path(args.input_path), + input_tiles=Path(args.input_tiles), + patch_index=args.patch_index, + nb_patches=args.nb_patches, + export_path=Path(args.export_path), + ) + + except FileNotFoundError: + if not args.input_directory.exists(): + print(f"Folder {args.input_directory} does not exist!!") + + if not args.export_path.exists(): + print(f"Folder {args.export_path} does 
not exist!!") logging.info("Visualization export finished") + + +if __name__ == "__main__": + main() diff --git a/src/mmdc_singledate/inference/__init__.py b/src/mmdc_singledate/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/src/mmdc_singledate/inference/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/src/mmdc_singledate/inference/components/__init__.py b/src/mmdc_singledate/inference/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/src/mmdc_singledate/inference/components/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/src/mmdc_singledate/inference/components/inference_components.py b/src/mmdc_singledate/inference/components/inference_components.py new file mode 100644 index 0000000000000000000000000000000000000000..42cdc604a0cbd1fc7d9bd47b8325bb49ef290189 --- /dev/null +++ b/src/mmdc_singledate/inference/components/inference_components.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Functions components for get the +latent spaces for a given tile +""" +# imports +import logging +from pathlib import Path +from typing import Literal + +import torch +from torch import nn + +from mmdc_singledate.inference.utils import get_mmdc_full_config +from mmdc_singledate.models.torch.full import MMDCFullModule +from mmdc_singledate.models.types import S1S2VAELatentSpace + +# define sensors variable for typing +SENSORS = Literal["S2L2A", "S1FULL", "S1ASC", "S1DESC"] + +# Configure the logger +NUMERIC_LEVEL = getattr(logging, "INFO", None) +logging.basicConfig( + level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s" +) +logger = logging.getLogger(__name__) + + +@torch.no_grad() +def get_mmdc_full_model( + checkpoint: str, + mmdc_full_config_path: str, + inference_tile: bool, +) -> nn.Module: + """ + Instantiate the network + + """ + if not Path(checkpoint).exists: + logger.info("No checkpoint path was give. ") + return 1 + else: + # read the config file + mmdc_full_config = get_mmdc_full_config( + mmdc_full_config_path, inference_tile=inference_tile + ) + # Init the model passing the config + mmdc_full_model = MMDCFullModule( + mmdc_full_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.debug(f"device detected : {device}") + logger.debug(f"nb_gpu's detected : {torch.cuda.device_count()}") + # load state_dict + lightning_checkpoint = torch.load(checkpoint, map_location=device) + + # delete "model" from the loaded checkpoint + lightning_checkpoint_cleaning = { + key.split("model.")[-1]: item + for key, item in lightning_checkpoint["state_dict"].items() + if key.startswith("model.") + } + + # load the state dict + mmdc_full_model.load_state_dict(lightning_checkpoint_cleaning) + + # disble randomness, dropout, etc... 
+ mmdc_full_model.eval() + + return mmdc_full_model + + +@torch.no_grad() +def predict_mmdc_model( + model: nn.Module, + sensors: SENSORS, + s2_ref, + s2_mask, + s2_angles, + s1_back, + s1_vm, + s1_asc_angles, + s1_desc_angles, + worldclim, + srtm, +) -> torch.Tensor: + """ + This function apply the predict method to a + batch of mmdc data and get the latent space + by sensor + + :args: model : mmdc_full model intance + :args: sensors : Sensor acquisition + :args: s2_ref : + :args: s2_angles : + :args: s1_back : + :args: s1_asc_angles : + :args: s1_desc_angles : + :args: worldclim : + :args: srtm : + + :return: VAELatentSpace + """ + # TODO check if the sensors are in the current supported list + # available_sensors = ["S2L2A", "S1FULL", "S1ASC", "S1DESC"] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.debug(device) + + prediction = model.predict( + s2_ref.to(device), + s2_mask.to(device), + s2_angles.to(device), + s1_back.to(device), + s1_vm.to(device), + s1_asc_angles.to(device), + s1_desc_angles.to(device), + worldclim.to(device), + srtm.to(device), + ) + + logger.debug("prediction.shape :", prediction[0].latent.sen2.mean.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen2.logvar.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen1.mean.shape) + logger.debug("prediction.shape :", prediction[0].latent.sen1.logvar.shape) + + # init the output latent spaces as + # empty dataclass + latent_space = S1S2VAELatentSpace(sen1=None, sen2=None) + + # fullfit with a matchcase + # match + # match sensors: + # Cases + if sensors == ["S2L2A", "S1FULL"]: + logger.debug("S2 captured & S1 full captured") + + latent_space.sen2 = prediction[0].latent.sen2 + latent_space.sen1 = prediction[0].latent.sen1 + + latent_space_stack = torch.cat( + ( + latent_space.sen2.mean, + latent_space.sen2.logvar, + latent_space.sen1.mean, + latent_space.sen1.logvar, + ), + 1, + ) + + logger.debug( + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, + ) + logger.debug( + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, + ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) + + elif sensors == ["S2L2A", "S1ASC"]: + # logger.debug("S2 captured & S1 asc captured") + + latent_space.sen2 = prediction[0].latent.sen2 + latent_space.sen1 = prediction[0].latent.sen1 + + latent_space_stack = torch.cat( + ( + latent_space.sen2.mean, + latent_space.sen2.logvar, + latent_space.sen1.mean, + latent_space.sen1.logvar, + ), + 1, + ) + logger.debug( + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, + ) + logger.debug( + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, + ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) + + elif sensors == ["S2L2A", "S1DESC"]: + logger.debug("S2 captured & S1 desc captured") + + latent_space.sen2 = prediction[0].latent.sen2 + latent_space.sen1 = prediction[0].latent.sen1 + + latent_space_stack = torch.cat( + ( + latent_space.sen2.mean, + latent_space.sen2.logvar, + latent_space.sen1.mean, + latent_space.sen1.logvar, + ), + 1, + ) + logger.debug( + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, + ) + logger.debug( + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, + ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + 
logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) + + elif sensors == ["S2L2A"]: + # logger.debug("Only S2 captured") + + latent_space.sen2 = prediction[0].latent.sen2 + + latent_space_stack = torch.cat( + ( + latent_space.sen2.mean, + latent_space.sen2.logvar, + ), + 1, + ) + logger.debug( + "latent_space.sen2.mean :", + latent_space.sen2.mean.shape, + ) + logger.debug( + "latent_space.sen2.logvar :", + latent_space.sen2.logvar.shape, + ) + + elif sensors == ["S1FULL"]: + # logger.debug("Only S1 full captured") + + latent_space.sen1 = prediction[0].latent.sen1 + + latent_space_stack = torch.cat( + (latent_space.sen1.mean, latent_space.sen1.logvar), + 1, + ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) + + elif sensors == ["S1ASC"]: + # logger.debug("Only S1ASC captured") + + latent_space.sen1 = prediction[0].latent.sen1 + + latent_space_stack = torch.cat( + (latent_space.sen1.mean, latent_space.sen1.logvar), + 1, + ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) + + elif sensors == ["S1DESC"]: + # logger.debug("Only S1DESC captured") + + latent_space.sen1 = prediction[0].latent.sen1 + + latent_space_stack = torch.cat( + (latent_space.sen1.mean, latent_space.sen1.logvar), + 1, + ) + logger.debug("latent_space.sen1.mean :", latent_space.sen1.mean.shape) + logger.debug("latent_space.sen1.logvar :", latent_space.sen1.logvar.shape) + else: + raise Exception(f"Not a valid sensors ({sensors}) are given") + + return latent_space_stack # latent_space + + +# Switch to match case in python 3.10 o superior +# fullfit with a matchcase +# for captor in sensors: +# # match +# match captor: +# # Cases +# case ["S2L2A", "S1FULL"]: +# print("S1 full captured & S2 captured") + +# case ["S2L2A", "S1ASC"]: +# print("S1 asc captured & S2 captured") + +# case ["S2L2A", "S1DESC"]: +# print("S1 desc captured & S2 captured") + +# case ["S2L2A"]: +# print("Only S2 captured") + +# case ["S1FULL"]: +# print("Only S1 full captured") + +# case ["S1ASC"]: +# print("Only S1ASC captured") + +# case ["S1DESC"]: +# print("Only S1DESC captured") diff --git a/src/mmdc_singledate/inference/components/inference_utils.py b/src/mmdc_singledate/inference/components/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef85dbdb889e4515bbd634e0f855ce91f5e22121 --- /dev/null +++ b/src/mmdc_singledate/inference/components/inference_utils.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Infereces utils functions + +""" +from collections.abc import Callable, Generator +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import rasterio as rio +import torch +from rasterio.windows import Window as rio_window + +from mmdc_singledate.datamodules.components.datamodule_components import ( + get_worldclim_filenames, +) +from mmdc_singledate.datamodules.components.datamodule_utils import ( + apply_log_to_s1, + get_s1_acquisition_angles, + join_even_odd_s2_angles, + srtm_height_aspect, +) + +# dataclasses +Chunk = tuple[int, int, int, int] + +# hold the data + + +@dataclass +class GeoTiffDataset: + + """ + hold the tif's filenames + stored in disk and his metadata + """ + + s2_filename: str + s1_asc_filename: str + s1_desc_filename: str + srtm_filename: str + wc_filename: 
str + metadata: dict = field(init=False) + s2_availabitity: bool | None = field(default=None, init=True) + s1_asc_availability: bool | None = field(default=None, init=True) + s1_desc_availability: bool | None = field(default=None, init=True) + + def __post_init__(self) -> None: + # check if the data exists + if not Path(self.srtm_filename).is_file(): + raise Exception(f"{self.srtm_filename} do not exists!") + if not Path(self.wc_filename).is_file(): + raise Exception(f"{self.wc_filename} do not exists!") + if Path(self.s1_asc_filename).is_file(): + self.s1_asc_availability = True + else: + self.s1_asc_availability = False + if Path(self.s1_desc_filename).is_file(): + self.s1_desc_availability = True + else: + self.s1_desc_availability = False + if Path(self.s2_filename).is_file(): + self.s2_availabitity = True + else: + self.s2_availabitity = False + # check if the filenames + # provide are in the same + # footprint and spatial resolution + # rio.open( + # self.s1_asc_filename + # ) as s1_asc, rio.open(self.s1_desc_filename) as s1_desc, + # rio.open(self.s2_filename) as s2, + with rio.open(self.srtm_filename) as srtm, rio.open(self.wc_filename) as wc: + # recover the metadata + self.metadata = srtm.meta + + # create a python set to check + crss = { + srtm.meta["crs"], + wc.meta["crs"], + # s2.meta["crs"], + # s1_asc.meta["crs"], + # s1_desc.meta["crs"], + } + + heights = { + srtm.meta["height"], + wc.meta["height"], + # s2.meta["height"], + # s1_asc.meta["height"], + # s1_desc.meta["height"], + } + + widths = { + srtm.meta["width"], + wc.meta["width"], + # s2.meta["width"], + # s1_asc.meta["width"], + # s1_desc.meta["width"], + } + + # check crs + if len(crss) > 1: + raise Exception("Data should be in the same CRS") + + # check height + if len(heights) > 1: + raise Exception("Data should be have the same height") + + # check width + if len(widths) > 1: + raise Exception("Data should be have the same width") + + +@dataclass +class S2Components: + """ + dataclass for hold the s2 related data + """ + + s2_reflectances: torch.Tensor + s2_angles: torch.Tensor + s2_mask: torch.Tensor + + +@dataclass +class S1Components: + """ + dataclass for hold the s1 related data + """ + + s1_backscatter: torch.Tensor + s1_valmask: torch.Tensor + s1_lia_angles: torch.Tensor + + +@dataclass +class SRTMComponents: + """ + dataclass for hold srtm related data + """ + + srtm: torch.Tensor + + +@dataclass +class WorldClimComponents: + """ + dataclass for hold worldclim related data + """ + + worldclim: torch.Tensor + + +# chunks logic +def generate_chunks( + width: int, height: int, nlines: int, test_area: tuple[int, int] | None = None +) -> list[Chunk]: + """Given the width and height of an image and a number of lines per chunk, + generate the list of chunks for the image. The chunks span all the width. 
A + chunk is encoded as 4 values: (x0, y0, width, height) + :param width: image width + :param height: image height + :param nlines: number of lines per chunk + :param test_area: generate only the chunks between (ya, yb) + :returns: a list of chunks + """ + chunks = [] + for iter_height in range(height // nlines - 1): + y0 = iter_height * nlines + y1 = y0 + nlines + if y1 > height: + y1 = height + append_chunk = (test_area is None) or (y0 > test_area[0] and y1 < test_area[1]) + if append_chunk: + chunks.append((0, y0, width, y1 - y0)) + return chunks + + +# read data +def read_img_tile( + filename: str, + rois: list[rio_window], + availabitity: bool, + sensor_func: Callable[[torch.Tensor, Any], Any], +) -> Generator[Any, None, None]: + """ + read a patch of a abstract sensor data + contruct the auxiliary data and yield the patch + of data + """ + # check if the filename exists + # if exist proceed to yield the data + if availabitity: + with rio.open(filename) as raster: + for roi in rois: + # read the patch as tensor + tensor = torch.tensor(raster.read(window=roi), requires_grad=False) + # compute some transformation to the tensor and return dataclass + sensor_data = sensor_func(tensor, availabitity, filename) + yield sensor_data + # if not exists create a zeros tensor and yield + else: + # print("Creating fake data") + for roi in rois: + null_data = torch.ones(roi.width, roi.height) + sensor_data = sensor_func(null_data, availabitity, filename) + yield sensor_data + + +def read_s2_img_tile( + s2_tensor: torch.Tensor, + availabitity: bool, + *args: Any, + **kwargs: Any, +) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: + """ + read a patch of sentinel 2 data + contruct the masks and yield the patch + of data + """ + if availabitity: + # extract masks + cloud_mask = s2_tensor[11, ...].to(torch.uint8) + cloud_mask[cloud_mask > 0] = 1 + sat_mask = s2_tensor[10, ...].to(torch.uint8) + edge_mask = s2_tensor[12, ...].to(torch.uint8) + # create the validity mask + mask = torch.logical_or( + cloud_mask, torch.logical_or(edge_mask, sat_mask) + ).unsqueeze(0) + angles_s2 = join_even_odd_s2_angles(s2_tensor[14:, ...]) + image_s2 = s2_tensor[:10, ...] + else: + image_s2 = torch.ones(10, s2_tensor.shape[0], s2_tensor.shape[1]) + angles_s2 = torch.ones(6, s2_tensor.shape[0], s2_tensor.shape[1]) + mask = torch.zeros(1, s2_tensor.shape[0], s2_tensor.shape[1]) + # print("passing zero data") + + return S2Components(s2_reflectances=image_s2, s2_angles=angles_s2, s2_mask=mask) + + +# TODO get out the s1_lia computation from this function +def read_s1_img_tile( + s1_tensor: torch.Tensor, + availability: bool, + s1_filename: str, + *args: Any, + **kwargs: Any, +) -> S1Components: # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + read a patch of s1 construct the datasets associated + and yield the result + + inputs : s1_tensor read from read_img_tile + """ + # if s1_tensor.shape[0] == 2: + if availability: + # get the vv,vh, vv/vh bands + img_s1 = torch.cat( + (s1_tensor, (s1_tensor[1, ...] 
/ s1_tensor[0, ...]).unsqueeze(0)) + ) + # get the local incidencce angle (lia) + s1_lia = get_s1_acquisition_angles(s1_filename) + # compute validity mask + s1_valmask = torch.ones(1, s1_tensor.shape[1], s1_tensor.shape[2]) + # compute edge mask + s1_backscatter = apply_log_to_s1(img_s1) + else: + s1_backscatter = torch.ones(3, s1_tensor.shape[1], s1_tensor.shape[0]) + s1_valmask = torch.zeros(1, s1_tensor.shape[1], s1_tensor.shape[0]) + s1_lia = torch.ones(3) + + return S1Components( + s1_backscatter=s1_backscatter, + s1_valmask=s1_valmask, + s1_lia_angles=s1_lia, + ) + + +def read_srtm_img_tile( + srtm_tensor: torch.Tensor, + *args: Any, + **kwargs: Any, +) -> SRTMComponents: # [torch.Tensor]: + """ + read srtm patch + """ + # read the patch as tensor + img_srtm = srtm_height_aspect(srtm_tensor) + return SRTMComponents(img_srtm) + + +def read_worldclim_img_tile( + wc_tensor: torch.Tensor, + *args: Any, + **kwargs: Any, +) -> WorldClimComponents: # [torch.tensor]: + """ + read worldclim subset + """ + return WorldClimComponents(wc_tensor) + + +def expand_worldclim_filenames( + wc_filename: str, +) -> list[str]: + """ + given the firts worldclim filename expand the filenames + to the others files + """ + name_wc = [ + get_worldclim_filenames(wc_filename, str(idx), True) for idx in range(1, 13, 1) + ] + + name_wc_bio = get_worldclim_filenames(wc_filename, None, False) + name_wc.append(name_wc_bio) + + return name_wc + + +def concat_worldclim_components( + wc_filename: str, + rois: list[rio_window], + availabitity, + *args: Any, + **kwargs: Any, +) -> Generator[Any, None, None]: + """ + Compose function for apply the read_img_tile general function + to the different worldclim files + """ + # get all filenames + wc_filenames = expand_worldclim_filenames(wc_filename) + # read the tensors + wc_tensor = [ + next( + read_img_tile(wc_filename, rois, availabitity, read_worldclim_img_tile) + ).worldclim + for idx, wc_filename in enumerate(wc_filenames) + ] + yield WorldClimComponents(torch.cat(wc_tensor)) diff --git a/src/mmdc_singledate/inference/mmdc_iota2_inference.py b/src/mmdc_singledate/inference/mmdc_iota2_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..742f2892e94d794519dfb309ff1c20f6357d8db7 --- /dev/null +++ b/src/mmdc_singledate/inference/mmdc_iota2_inference.py @@ -0,0 +1,512 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Infereces API with Iota-2 +Inspired by: +https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py +""" +# imports +import logging +import os + +import numpy as np +import torch +from dateutil import parser as date_parser +from iota2.configuration_files import read_config_file as rcf +from torchutils import patches + +# from mmdc_singledate.models.mmdc_full_module import MMDCFullModule +# from mmdc_singledate.models.types import S1S2VAELatentSpace +# from mmdc_singledate.datamodules.components.datamodule_utils import ( +# apply_log_to_s1, +# srtm_height_aspect, +# ) + +# from .components.inference_components import get_mmdc_full_model, predict_mmdc_model +# from .components.inference_utils import ( +# GeoTiffDataset, +# S1Components, +# S2Components, +# SRTMComponents, +# WorldClimComponents, +# get_mmdc_full_config, +# read_s1_img_tile, +# read_s2_img_tile, +# read_srtm_img_tile, +# read_worldclim_img_tile, +# ) +# from .utils import get_scales + +# from tqdm import tqdm +# from pathlib import Path + + +# Configure 
the logger +NUMERIC_LEVEL = getattr(logging, "INFO", None) +logging.basicConfig( + level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s" +) +logger = logging.getLogger(__name__) + + +def read_s2_img_iota2( + self, +): +# ) -> S2Components: # [torch.Tensor, torch.Tensor, torch.Tensor]: + """ + read a patch of sentinel 2 data + contruct the masks and yield the patch + of data + """ + # DONE Get the data in the same order as + # sentinel2.Sentinel2.GROUP_10M + + # sensorsio.sentinel2.GROUP_20M + # This data correspond to (Chunk_Size, Img_Size, nb_dates) + list_bands_s2 = [ + self.get_interpolated_Sentinel2_B2(), + self.get_interpolated_Sentinel2_B3(), + self.get_interpolated_Sentinel2_B4(), + self.get_interpolated_Sentinel2_B8(), + self.get_interpolated_Sentinel2_B5(), + self.get_interpolated_Sentinel2_B6(), + self.get_interpolated_Sentinel2_B7(), + self.get_interpolated_Sentinel2_B8A(), + self.get_interpolated_Sentinel2_B11(), + self.get_interpolated_Sentinel2_B12(), + ] + + # DONE Masks contruction for S2 + list_s2_mask = [ + self.get_Sentinel2_binary_masks(), + self.get_Sentinel2_binary_masks(), + self.get_Sentinel2_binary_masks(), + self.get_Sentinel2_binary_masks(), + ] + + # DONE Add S2 Angles + list_bands_s2_angles = [ + self.get_interpolated_userFeatures_ZENITH(), + self.get_interpolated_userFeatures_AZIMUTH(), + self.get_interpolated_userFeatures_EVEN_ZENITH(), + self.get_interpolated_userFeatures_ODD_ZENITH(), + self.get_interpolated_userFeatures_EVEN_AZIMUTH(), + self.get_interpolated_userFeatures_ODD_AZIMUTH(), + ] + + # extend the list + list_bands_s2.extend(list_s2_mask) + list_bands_s2.extend(list_bands_s2_angles) + + bands_s2 = torch.Tensor( + list_bands_s2, dtype=torch.torch.float32, device=device + ).permute(-1, 0, 1, 2) + + return read_s2_img_tile(bands_s2) + + +def read_s1_img_iota2( + self, +): +# ) -> S1Components: # [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + read a patch of s1 construct the datasets associated + and yield the result + """ + # TODO Manage S1 ASC and S1 DESC ? + list_bands_s1 = [ + self.get_interpolated_Sentinel1_ASC_vh(), + self.get_interpolated_Sentinel1_ASC_vv(), + self.get_interpolated_Sentinel1_ASC_vh() + / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4), + self.get_interpolated_Sentinel1_DES_vh(), + self.get_interpolated_Sentinel1_DES_vv(), + self.get_interpolated_Sentinel1_DES_vh() + / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4), + ] + + bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2) + bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2) + + # get the vv,vh, vv/vh bands + # img_s1 = torch.cat( + # (s1_tensor, (s1_tensor[1, ...] 
/ s1_tensor[0, ...]).unsqueeze(0)) + # ) + + list_bands_s1_lia = [ + self.get_interpolated_, + self.get_interpolated_, + self.get_interpolated_, + ] + + # s1_lia = get_s1_acquisition_angles(s1_filename) + # compute validity mask + s1_valmask = torch.ones(img_s1.shape) + # compute edge mask + s1_backscatter = apply_log_to_s1(img_s1) + + return read_s1_img_tile() + + +def read_srtm_img_iota2( + self, +): +# ) -> SRTMComponents: + """ + read srtm patch from Iota2 API + """ + # read the patch as numpy array + + list_bands_srtm = [ + self.get_interpolated_userFeatures_elevation(), + self.get_interpolated_userFeatures_slope(), + self.get_interpolated_userFeatures_aspect(), + ] + # convert to tensor and permute the bands + srtm_tensor = torch.Tensor(list_bands_srtm).permute(-1, 0, 1, 2) + + return read_srtm_img_tile( + srtm_height_aspect( + srtm_tensor + ) + ) + + + +def read_worldclim_img_iota2( + self, +): +# ) -> WorldClimComponents: # [torch.tensor]: + """ + read worldclim patch from Iota2 API + """ + wc_tensor = torch.Tensor( + self.get_interpolated_userFeatures_worldclim() + ).permute(-1, 0, 1, 2) + return read_worldclim_img_tile(wc_tensor) + + +def predictc_iota2( + self, +): + """ + Apply MMDC with Iota-2.py + """ + + # retrieve the metadata + # meta = input_data.metadata.copy() + # get nb outputs + # meta.update({"count": process.count}) + # calculate rois + # chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines) + # get the windows from the chunks + # rois = [rio.windows.Window(*chunk) for chunk in chunks] + # logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ") + + # init the dataset + logger.debug("Reading S2 data") + s2_data = read_s2_img_iota2() + logger.debug("Reading S1 ASC data") + s1_asc_data = read_s1_img_iota2() + logger.debug("Reading S1 DESC data") + s1_desc_data = read_s1_img_iota2() + logger.debug("Reading SRTM data") + srtm_data = read_srtm_img_iota2() + logger.debug("Reading WorldClim data") + worldclim_data = read_worldclim_img_iota2() + + # Concat S1 Data + s1_backscatter = torch.cat((s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0) + # Multiply the masks + s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask) + # keep the sizes for recover the original size + s2_mask_patch_size = patches.patchify(s2.s2_mask, process.patch_size).shape + s2_s2_mask_shape = s2.s2_mask.shape # [1, 1024, 10980] + + # reshape the data + s2_refl_patch = patches.flatten2d( + patches.patchify(s2_data.s2_reflectances, process.patch_size) + ) + s2_ang_patch = patches.flatten2d( + patches.patchify(s2_data.s2_angles, process.patch_size) + ) + s2_mask_patch = patches.flatten2d( + patches.patchify(s2_data.s2_mask, process.patch_size) + ) + s1_patch = patches.flatten2d(patches.patchify(s1_backscatter, process.patch_size)) + s1_valmask_patch = patches.flatten2d( + patches.patchify(s1_valmask, process.patch_size) + ) + srtm_patch = patches.flatten2d(patches.patchify(srtm_data.srtm, process.patch_size)) + wc_patch = patches.flatten2d( + patches.patchify(worldclim_data.worldclim, process.patch_size) + ) + # Expand the angles to fit the sizes + s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) + s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) + + logger.debug("s2_patches", s2_refl_patch.shape) + logger.debug("s2_angles_patches", s2_ang_patch.shape) + logger.debug("s2_mask_patch", s2_mask_patch.shape) + logger.debug("s1_patch", 
s1_patch.shape) + logger.debug("s1_valmask_patch", s1_valmask_patch.shape) + logger.debug( + "s1_lia_asc patches", + s1_asc.s1_lia_angles.shape, + s1asc_lia_patch.shape, + ) + logger.debug( + "s1_lia_desc patches", + s1_desc.s1_lia_angles.shape, + s1desc_lia_patch.shape, + ) + logger.debug("srtm patches", srtm_patch.shape) + logger.debug("wc patches", wc_patch.shape) + + # apply predict function + # should return a s1s2vaelatentspace + # object + pred_vaelatentspace = process.process( + process.model, + sensors, + s2_refl_patch, + s2_mask_patch, + s2_ang_patch, + s1_patch, + s1_valmask_patch, + s1asc_lia_patch, + s1desc_lia_patch, + wc_patch, + srtm_patch, + ) + logger.debug(type(pred_vaelatentspace)) + logger.debug("latent space sizes : ", pred_vaelatentspace.shape) + + # unpatchify + pred_vaelatentspace_unpatchify = patches.unpatchify( + patches.unflatten2d( + pred_vaelatentspace, + s2_mask_patch_size[0], + s2_mask_patch_size[1], + ) + )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] + + logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + logger.debug("process.count : ", process.count) + # check the pred and the ask are the same + assert process.count == pred_vaelatentspace_unpatchify.shape[0] + + predition_array = np.array(pred_vaelatentspace_unpatchify[0, ...]) + + logger.debug(("Export tile", f"filename :{export_path}")) + + return predition_array, labels + + +def acquisition_dataframe( + self, +) -> pd.DataFrame: + """ + Construct a DataFrame with the acquisition dates + for the predictions + """ + # use Iota2 API + # acquisition_dates = self.get_raw_dates() + + # Get the acquisition dates + acquisition_dates = { + "Sentinel1_DES_vv": ["20151231"], + "Sentinel1_DES_vh": ["20151231"], + "Sentinel1_ASC_vv": ["20170518", "20170519"], + "Sentinel1_ASC_vh": ["20170518", "20170519"], + "Sentinel2": ["20151130", "20151203"], + } + + # get the acquisition dates for each sensor-mode + sentinel2_acquisition_dates = pd.DataFrame( + {"Sentinel2": acquisition_dates["Sentinel2"]}, dtype=str + ) + sentinel1_asc_acquisition_dates = pd.DataFrame( + { + "Sentinel1_ASC_vv": acquisition_dates["Sentinel1_ASC_vv"], + "Sentinel1_ASC_vh": acquisition_dates["Sentinel1_ASC_vh"], + }, + dtype=str, + ) + sentinel1_des_acquisition_dates = pd.DataFrame( + { + "Sentinel1_DES_vv": acquisition_dates["Sentinel1_DES_vv"], + "Sentinel1_DES_vh": acquisition_dates["Sentinel1_DES_vh"], + }, + dtype=str, + ) + + # convert acquistion dates from str to datetime object + sentinel2_acquisition_dates["Sentinel2"] = sentinel2_acquisition_dates["Sentinel2"].apply( + lambda x: date_parser.parse(x) + ) + sentinel1_asc_acquisition_dates["Sentinel1_ASC_vv"] = sentinel1_asc_acquisition_dates[ + "Sentinel1_ASC_vv" + ].apply(lambda x: date_parser.parse(x)) + sentinel1_asc_acquisition_dates["Sentinel1_ASC_vh"] = sentinel1_asc_acquisition_dates[ + "Sentinel1_ASC_vh" + ].apply(lambda x: date_parser.parse(x)) + sentinel1_des_acquisition_dates["Sentinel1_DES_vv"] = sentinel1_des_acquisition_dates[ + "Sentinel1_DES_vv" + ].apply(lambda x: date_parser.parse(x)) + sentinel1_des_acquisition_dates["Sentinel1_DES_vh"] = sentinel1_des_acquisition_dates[ + "Sentinel1_DES_vh" + ].apply(lambda x: date_parser.parse(x)) + + # merge the dataframes + sensors_join = pd.concat([ + sentinel2_acquisition_dates, + sentinel1_asc_acquisition_dates, + sentinel1_des_acquisition_dates, + ] + ) + + # get acquisition availability + sensors_join["s2_availability"] = pd.isna(sensors_join["Sentinel2"]) + sensors_join["s1_availability_asc"] = 
pd.isna(sensors_join["Sentinel1_ASC_vv"]) + sensors_join["s1_availability_desc"] = pd.isna(sensors_join["Sentinel1_DES_vv"]) + # return + + +def apply_mmdc_full_mode( + # checkpoint_path: str, # read from the cfg file + # checkpoint_epoch: int = 100, # read from the cfg file + # patch_size: int = 256, # read from the cfg file +): + """ + Entry Point + """ + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # read the configuration file + config = rcf("configs/iota2/externalfeatures_with_userfeatures.cfg") + # return a list with a tuple of functions and arguments + mmdc_parametrs = config.getParam("external_features", "functions") + padding_size_x = config.getParam("python_data_managing", "padding_size_x") + padding_size_y = config.getParam("python_data_managing", "padding_size_y") + + checkpoint_path = parametrs[0][1] # TODO fix this parameters + + # Get the acquisition dates + # dict_dates = { + # 'Sentinel1_DES_vv': ['20151231'], + # 'Sentinel1_DES_vh': ['20151231'], + # 'Sentinel1_ASC_vv': ['20170518', '20170519'], + # 'Sentinel1_ASC_vh': ['20170518', '20170519'], + # 'Sentinel2': ['20151130', '20151203']} + + # model + mmdc_full_model = get_mmdc_full_model( + checkpoint=model_checkpoint_path, + mmdc_full_config_path=model_config_path, + inference_tile=False, + ) + scales = get_scales() + mmdc_full_model.set_scales(scales) + mmdc_full_model.to(device) + + # process class + mmdc_process = MMDCProcess( + count=latent_space_size, + nb_lines=nb_lines, + patch_size=patch_size, + process=predict_mmdc_model, + model=mmdc_full_model, + ) + # create export filename + export_path = f"{export_path}/latent_singledate_{'_'.join(sensors)}.tif" + + # # How manage the S1 acquisition dates for construct the mask validity + # # in time + + # # TODO Read SRTM data + # # TODO Read Worldclim Data + + # # with torch.no_grad(): + # # Permute dimensions to fit patchify + # # Shape before permutation is C,H,W,D. D being the dates + # # Shape after permutation is D,C,H,W + # bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2) + # bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2) + # bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2) + + # # TODO Masks contruction for S1 + # # build_s1_image_and_masks function for datamodules components datamodule components ? 
+ + # # Replace nan by 0 + # bands_s1 = bands_s1.nan_to_num() + # # These dimensions are useful for unpatchify + # # Keep the dimensions of height and width found in chunk + # band_h, band_w = bands_s2.shape[-2:] + # # Keep the number of patches of patch_size + # # in rows and cols found in chunk + # h, w = patches.patchify( + # bands_s2[0, ...], patch_size=patch_size, margin=patch_margin + # ).shape[:2] + + # # TODO Apply patchify + + # # Get the model + # mmdc_full_model = get_mmdc_full_model( + # os.path.join(checkpoint_path, checkpoint_filename) + # ) + + # # apply model + # # latent_variable = mmdc_full_model.predict( + # # s2_x, + # # s2_m, + # # s2_angles_x, + # # s1_x, + # # s1_vm, + # # s1_asc_angles_x, + # # s1_desc_angles_x, + # # worldclim_x, + # # srtm_x, + # # ) + + # # TODO unpatchify + + # # TODO crop padding + + # # TODO Depending of the configuration return a unique latent variable + # # or a stack of laten variables + # # coef = pass + # # labels = pass + # # return coef, labels + + +def test_mnt(self, arg1): + """ + compute the Soil Composition Index + """ + print(dir(self)) + print("raw dates :", self.get_raw_dates()) + print("interpolated dates :",self.get_interpolated_dates()) + print("interpolated dates :",self.interpolated_dates) + print("exogeneous_data_name:",self.exogeneous_data_name) + print("fill_missing_dates:",self.fill_missing_dates) + print(":",None) + + mnt = self.get_Sentinel2_binary_masks() #self.get_raw_userFeatures_mnt_band_0()#self.get_raw_userFeatures_mnt2() + # mnt = self.get_raw_userFeatures_mnt2_band_0() + # mnt = self.get_raw_userFeatures_slope() + print(f"type mnt : {type(mnt)}") + print(f"size mnt : {mnt.shape}") + labels = [f"mnt_{cpt}" for cpt in range(mnt.shape[-1])] + print(f"labels mnt : {labels}") + + # s2_b5 = self.get_interpolated_Sentinel2_B5() + # print(f"type s2_b5 : {type(s2_b5)}") # 4, 3, 2 # 26,50,2 + # print(f"size s2_b5 : {s2_b5.shape}") # 4, 3, 1 # 26,50,2 + + return mnt, labels diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference.py b/src/mmdc_singledate/inference/mmdc_tile_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..307e3007405badd2b7fddfc05a5e3fbbb159e300 --- /dev/null +++ b/src/mmdc_singledate/inference/mmdc_tile_inference.py @@ -0,0 +1,442 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Infereces API with Rasterio +""" + +# imports +import logging +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import pandas as pd +import rasterio as rio +import torch +from torchutils import patches +from tqdm import tqdm + +from mmdc_singledate.datamodules.components.datamodule_components import prepare_data_df + +from .components.inference_components import get_mmdc_full_model, predict_mmdc_model +from .components.inference_utils import ( + GeoTiffDataset, + concat_worldclim_components, + generate_chunks, + read_img_tile, + read_s1_img_tile, + read_s2_img_tile, + read_srtm_img_tile, +) +from .utils import get_scales + +# Configure the logger +NUMERIC_LEVEL = getattr(logging, "INFO", None) +logging.basicConfig( + level=NUMERIC_LEVEL, format="%(asctime)-15s %(levelname)s: %(message)s" +) +logger = logging.getLogger(__name__) + + +def inference_dataframe( + samples_dir: str, + input_tile_list: list[str] | None, + days_gap: int, +): + """ + Read the input directory and create + a dataframe with the occurrences + for be infered + """ + + # Create the Tile time serie + 
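+    # prepare_data_df provides, for each acquisition date, the patch file paths
+    # used below and in mmdc_tile_inference (patch_s2, patchasc_s1, patchdesc_s1,
+    # srtm_filename, worldclim_filename); the *_availability columns added here
+    # simply record whether each file exists on disk.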
+    tile_time_serie_df = prepare_data_df(
+        samples_dir=samples_dir, input_tile_list=input_tile_list, days_gap=days_gap
+    )
+
+    tile_time_serie_df["patch_s2_availability"] = tile_time_serie_df["patch_s2"].apply(
+        lambda x: Path(x).exists()
+    )
+
+    tile_time_serie_df["patchasc_s1_availability"] = tile_time_serie_df[
+        "patchasc_s1"
+    ].apply(lambda x: Path(x).exists())
+
+    tile_time_serie_df["patchdesc_s1_availability"] = tile_time_serie_df[
+        "patchdesc_s1"
+    ].apply(lambda x: Path(x).exists())
+
+    return tile_time_serie_df
+
+
+# functions and classes
+@dataclass
+class MMDCProcess:
+    """
+    Holds the model and the parameters needed to process a tile:
+    number of output bands, chunk and patch sizes, and the prediction function
+    """
+
+    count: int
+    nb_lines: int
+    patch_size: int
+    model: torch.nn.Module
+    process: Callable[
+        [
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+        ],
+        torch.Tensor,
+    ]
+
+
+def predict_single_date_tile(
+    input_data: GeoTiffDataset,
+    export_path: Path,
+    sensors: list[str],
+    process: MMDCProcess,
+):
+    """
+    Predict a single date of a tile and write the result to export_path
+    """
+    # check that at least one dataset exists
+    if (
+        Path(input_data.s2_filename).exists()
+        or Path(input_data.s1_asc_filename).exists()
+        or Path(input_data.s1_desc_filename).exists()
+    ):
+        # retrieve the metadata
+        meta = input_data.metadata.copy()
+        # get nb outputs
+        meta.update({"count": process.count})
+        # calculate rois
+        chunks = generate_chunks(meta["width"], meta["height"], process.nb_lines)
+        # get the windows from the chunks
+        rois = [rio.windows.Window(*chunk) for chunk in chunks]
+        logger.debug(f"chunk size : ({rois[0].width}, {rois[0].height}) ")
+
+        # init the dataset
+        logger.debug("Reading S2 data")
+        s2_data = read_img_tile(
+            filename=input_data.s2_filename,
+            rois=rois,
+            availabitity=input_data.s2_availabitity,
+            sensor_func=read_s2_img_tile,
+        )
+        logger.debug("Reading S1 ASC data")
+        s1_asc_data = read_img_tile(
+            filename=input_data.s1_asc_filename,
+            rois=rois,
+            availabitity=input_data.s1_asc_availability,
+            sensor_func=read_s1_img_tile,
+        )
+        logger.debug("Reading S1 DESC data")
+        s1_desc_data = read_img_tile(
+            filename=input_data.s1_desc_filename,
+            rois=rois,
+            availabitity=input_data.s1_desc_availability,
+            sensor_func=read_s1_img_tile,
+        )
+        logger.debug("Reading SRTM data")
+        srtm_data = read_img_tile(
+            filename=input_data.srtm_filename,
+            rois=rois,
+            availabitity=True,
+            sensor_func=read_srtm_img_tile,
+        )
+        logger.debug("Reading WorldClim data")
+        worldclim_data = concat_worldclim_components(
+            wc_filename=input_data.wc_filename, rois=rois, availabitity=True
+        )
+        logger.debug("input_data.s2_availabitity=", input_data.s2_availabitity)
+        logger.debug("s2_data=", s2_data)
+        logger.debug("s1_asc_data=", s1_asc_data)
+        logger.debug("s1_desc_data=", s1_desc_data)
+        logger.debug("srtm_data=", srtm_data)
+        logger.debug("worldclim_data=", worldclim_data)
+        logger.debug("Export Init")
+
+        with rio.open(export_path, "w", **meta) as prediction:
+            # iterate over the windows
+            for roi, s2, s1_asc, s1_desc, srtm, wc in tqdm(
+                zip(
+                    rois,
+                    s2_data,
+                    s1_asc_data,
+                    s1_desc_data,
+                    srtm_data,
+                    worldclim_data,
+                ),
+                total=len(rois),
+                position=1,
+            ):
+                # export one tile
+                logger.debug(" original size : ", s2.s2_reflectances.shape)
+                logger.debug(" original size : ", s2.s2_angles.shape)
+                logger.debug(" original size : ", s2.s2_mask.shape)
+                logger.debug(" original size : ", s1_asc.s1_backscatter.shape)
+                logger.debug(" original size : ", s1_asc.s1_valmask.shape)
+                logger.debug(" original size : ", s1_asc.s1_lia_angles.shape)
+                logger.debug(" original size : ",
s1_desc.s1_backscatter.shape) + logger.debug(" original size : ", s1_desc.s1_valmask.shape) + logger.debug(" original size : ", s1_desc.s1_lia_angles.shape) + logger.debug(" original size : ", srtm.srtm.shape) + logger.debug(" original size : ", wc.worldclim.shape) + + # Concat S1 Data + s1_backscatter = torch.cat( + (s1_asc.s1_backscatter, s1_desc.s1_backscatter), 0 + ) + # The validity mask of S1 is unique + # Multiply the masks + s1_valmask = torch.mul(s1_asc.s1_valmask, s1_desc.s1_valmask) + + # keep the sizes for recover the original size + s2_mask_patch_size = patches.patchify( + s2.s2_mask, process.patch_size + ).shape + s2_s2_mask_shape = s2.s2_mask.shape # [1, 1024, 10980] + + logger.debug("s2_mask_patch_size :", s2_mask_patch_size) + logger.debug("s2_s2_mask_shape:", s2_s2_mask_shape) + # reshape the data + s2_refl_patch = patches.flatten2d( + patches.patchify(s2.s2_reflectances, process.patch_size) + ) + s2_ang_patch = patches.flatten2d( + patches.patchify(s2.s2_angles, process.patch_size) + ) + s2_mask_patch = patches.flatten2d( + patches.patchify(s2.s2_mask, process.patch_size) + ) + s1_patch = patches.flatten2d( + patches.patchify(s1_backscatter, process.patch_size) + ) + s1_valmask_patch = patches.flatten2d( + patches.patchify(s1_valmask, process.patch_size) + ) + srtm_patch = patches.flatten2d( + patches.patchify(srtm.srtm, process.patch_size) + ) + wc_patch = patches.flatten2d( + patches.patchify(wc.worldclim, process.patch_size) + ) + # Expand the angles to fit the sizes + s1asc_lia_patch = s1_asc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) + # torch.flatten(start_dim=0, end_dim=1) + s1desc_lia_patch = s1_desc.s1_lia_angles.unsqueeze(0).repeat( + s2_mask_patch_size[0] * s2_mask_patch_size[1], 1 + ) + + logger.debug("s2_patches", s2_refl_patch.shape) + logger.debug("s2_angles_patches", s2_ang_patch.shape) + logger.debug("s2_mask_patch", s2_mask_patch.shape) + logger.debug("s1_patch", s1_patch.shape) + logger.debug("s1_valmask_patch", s1_valmask_patch.shape) + logger.debug( + "s1_lia_asc patches", + s1_asc.s1_lia_angles.shape, + s1asc_lia_patch.shape, + ) + logger.debug( + "s1_lia_desc patches", + s1_desc.s1_lia_angles.shape, + s1desc_lia_patch.shape, + ) + logger.debug("srtm patches", srtm_patch.shape) + logger.debug("wc patches", wc_patch.shape) + + # apply predict function + # should return a s1s2vaelatentspace + # object + pred_vaelatentspace = process.process( + process.model, + sensors, + s2_refl_patch, + s2_mask_patch, + s2_ang_patch, + s1_patch, + s1_valmask_patch, + s1asc_lia_patch, + s1desc_lia_patch, + wc_patch, + srtm_patch, + ) + logger.debug(type(pred_vaelatentspace)) + logger.debug("latent space sizes : ", pred_vaelatentspace.shape) + + # unpatchify + pred_vaelatentspace_unpatchify = patches.unpatchify( + patches.unflatten2d( + pred_vaelatentspace, + s2_mask_patch_size[0], + s2_mask_patch_size[1], + ) + )[:, : s2_s2_mask_shape[1], : s2_s2_mask_shape[2]] + + logger.debug("pred_tensor :", pred_vaelatentspace_unpatchify.shape) + logger.debug("process.count : ", process.count) + # check the pred and the ask are the same + assert process.count == pred_vaelatentspace_unpatchify.shape[0] + + prediction.write( + np.array(pred_vaelatentspace_unpatchify[0, ...]), + window=roi, + indexes=process.count, + ) + logger.debug(("Export tile", f"filename :{export_path}")) + + +def estimate_nb_latent_spaces( + patch_s2_availability: bool, + patchasc_s1_availability: bool, + patchdesc_s1_availability: bool, +) -> int: + """ + Given 
the availability of the input data + estimate the number of latent variables + """ + + if (patch_s2_availability and patchasc_s1_availability) or ( + patch_s2_availability and patchdesc_s1_availability + ): + nb_latent_spaces = 8 * 2 + + else: + nb_latent_spaces = 4 * 2 + + return nb_latent_spaces + + +def determinate_sensor_list( + patch_s2_availability: bool, + patchasc_s1_availability: bool, + patchdesc_s1_availability: bool, +) -> list: + """ + Convert availabitiy dataframe to list of sensors + """ + sensors_availability = [ + patch_s2_availability, + patchasc_s1_availability, + patchdesc_s1_availability, + ] + if sensors_availability == [True, True, True]: + sensors = ["S2L2A", "S1FULL"] + elif sensors_availability == [True, True, False]: + sensors = ["S2L2A", "S1ASC"] + elif sensors_availability == [True, False, True]: + sensors = ["S2L2A", "S1DESC"] + elif sensors_availability == [False, True, True]: + sensors = ["S1FULL"] + elif sensors_availability == [True, False, False]: + sensors = ["S2L2A"] + elif sensors_availability == [False, True, False]: + sensors = ["S1ASC"] + elif sensors_availability == [False, False, True]: + sensors = ["S1DESC"] + else: + raise Exception( + (f"The sensor {sensors_availability}", " is not in the available list") + ) + return sensors + + +def mmdc_tile_inference( + inference_dataframe: pd.DataFrame, + export_path: Path, + model_checkpoint_path: Path, + model_config_path: Path, + patch_size: int = 256, + nb_lines: int = 1024, +) -> None: + """ + Entry point + + :args: input_path : full path to the export + :args: export_path : full path to the export + :args: sensor: list of sensors data input + :args: tile_list: list of tiles, + :args: days_gap : days between s1 and s2 acquisitions, + :args: latent_space_size : latent space output par sensor, + :args: nb_lines : number of lines to read every at time 1024, + + :return : None + + """ + # device + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # TODO Uncomment for test purpuse + # inference_dataframe = inference_dataframe.head() + # logger.debug(inference_dataframe.shape) + + # instance the model and get the pred func + model = get_mmdc_full_model( + checkpoint=model_checkpoint_path, + mmdc_full_config_path=model_config_path, + inference_tile=True, + ) + scales = get_scales() + model.set_scales(scales) + # move to device + model.to(device) + + # iterate over the dates in the time serie + for tuile, df_row in tqdm( + inference_dataframe.iterrows(), total=inference_dataframe.shape[0], position=0 + ): + # estimate nb latent spaces + latent_space_size = estimate_nb_latent_spaces( + df_row["patch_s2_availability"], + df_row["patchasc_s1_availability"], + df_row["patchdesc_s1_availability"], + ) + sensor = determinate_sensor_list( + df_row["patch_s2_availability"], + df_row["patchasc_s1_availability"], + df_row["patchdesc_s1_availability"], + ) + # export path + date_ = df_row["date"].strftime("%Y-%m-%d") + export_file = export_path / f"latent_tile_infer_{date_}_{'_'.join(sensor)}.tif" + logger.debug(tuile, df_row) + logger.debug(latent_space_size) + # define process + mmdc_process = MMDCProcess( + count=latent_space_size, + nb_lines=nb_lines, + patch_size=patch_size, + model=model, + process=predict_mmdc_model, + ) + # get the input data + mmdc_input_data = GeoTiffDataset( + s2_filename=df_row["patch_s2"], + s1_asc_filename=df_row["patchasc_s1"], + s1_desc_filename=df_row["patchdesc_s1"], + srtm_filename=df_row["srtm_filename"], + wc_filename=df_row["worldclim_filename"], + 
s2_availabitity=df_row["patch_s2_availability"], + s1_asc_availability=df_row["patchasc_s1_availability"], + s1_desc_availability=df_row["patchdesc_s1_availability"], + ) + # predict tile + predict_single_date_tile( + input_data=mmdc_input_data, + export_path=export_file, + sensors=sensor, + process=mmdc_process, + ) + + logger.debug("Export Finish !!!") diff --git a/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py new file mode 100644 index 0000000000000000000000000000000000000000..1110eccfc79e703032a87b3c673e992252594453 --- /dev/null +++ b/src/mmdc_singledate/inference/mmdc_tile_inference_iota2.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# copyright: (c) 2023 cesbio / centre national d'Etudes Spatiales + +""" +Infereces API with Iota-2 +Inspired by: +https://src.koda.cnrs.fr/mmdc/mmdc-singledate/-/blob/01ff5139a9eb22785930964d181e2a0b7b7af0d1/iota2/external_iota2_code.py +""" + +from typing import Any + +import torch +from torchutils import patches + +from mmdc_singledate.datamodules.components.datamodule_utils import ( + apply_log_to_s1, + srtm_height_aspect, + join_even_odd_s2_angles +) + + +def apply_mmdc_full_mode( + self, + checkpoint_path: str, + checkpoint_epoch: int = 100, + patch_size: int = 256, +): + """ + Apply MMDC with Iota-2.py + """ + # How manage the S1 acquisition dates for construct the mask validity + # in time + + # DONE Get the data in the same order as + # sentinel2.Sentinel2.GROUP_10M + + # sensorsio.sentinel2.GROUP_20M + # This data correspond to (Chunk_Size, Img_Size, nb_dates) + list_bands_s2 = [ + self.get_interpolated_Sentinel2_B2(), + self.get_interpolated_Sentinel2_B3(), + self.get_interpolated_Sentinel2_B4(), + self.get_interpolated_Sentinel2_B8(), + self.get_interpolated_Sentinel2_B5(), + self.get_interpolated_Sentinel2_B6(), + self.get_interpolated_Sentinel2_B7(), + self.get_interpolated_Sentinel2_B8A(), + self.get_interpolated_Sentinel2_B11(), + self.get_interpolated_Sentinel2_B12(), + ] + + # TODO Masks contruction for S2 + list_s2_mask = self.get_Sentinel2_binary_masks() + + # TODO Manage S1 ASC and S1 DESC ? + list_bands_s1 = [ + self.get_interpolated_Sentinel1_ASC_vh(), + self.get_interpolated_Sentinel1_ASC_vv(), + self.get_interpolated_Sentinel1_ASC_vh() + / (self.get_interpolated_Sentinel1_ASC_vv() + 1e-4), + self.get_interpolated_Sentinel1_DES_vh(), + self.get_interpolated_Sentinel1_DES_vv(), + self.get_interpolated_Sentinel1_DES_vh() + / (self.get_interpolated_Sentinel1_DES_vv() + 1e-4), + ] + + with torch.no_grad(): + # Permute dimensions to fit patchify + # Shape before permutation is C,H,W,D. D being the dates + # Shape after permutation is D,C,H,W + bands_s2 = torch.Tensor(list_bands_s2).permute(-1, 0, 1, 2) + bands_s2_mask = torch.Tensor(list_s2_mask).permute(-1, 0, 1, 2) + bands_s1 = torch.Tensor(list_bands_s1).permute(-1, 0, 1, 2) + bands_s1 = apply_log_to_s1(bands_s1).permute(-1, 0, 1, 2) + + + # TODO Masks contruction for S1 + # build_s1_image_and_masks function for datamodules components datamodule components ? 
+ + # Replace nan by 0 + bands_s1 = bands_s1.nan_to_num() + # These dimensions are useful for unpatchify + # Keep the dimensions of height and width found in chunk + band_h, band_w = bands_s2.shape[-2:] + # Keep the number of patches of patch_size + # in rows and cols found in chunk + h, w = patches.patchify(bands_s2[0, ...], + patch_size=patch_size, + margin=patch_margin).shape[:2] + + + # TODO Apply patchify + + + # Get the model + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + + # apply model + latent_variable = mmdc_full_model.predict( + s2_x, + s2_m, + s2_angles_x, + s1_x, + s1_vm, + s1_asc_angles_x, + s1_desc_angles_x, + worldclim_x, + srtm_x, + ) + + # TODO unpatchify + + # TODO crop padding + + # TODO Depending of the configuration return a unique latent variable + # or a stack of laten variables + coef = pass + labels = pass + return coef, labels diff --git a/src/mmdc_singledate/inference/utils.py b/src/mmdc_singledate/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c24cb634838ce955c8dec096e4557113993e4bb --- /dev/null +++ b/src/mmdc_singledate/inference/utils.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 + +from dataclasses import fields + +import pydantic +import torch +import yaml +from config import Config + +from mmdc_singledate.datamodules.types import ( + MMDCDataChannels, + MMDCDataStats, + MMDCShiftScales, + ShiftScale, +) +from mmdc_singledate.models.types import ( # S1S2VAELatentSpace, + ConvnetParams, + MLPConfig, + MMDCDataUse, + ModularEmbeddingConfig, + MultiVAEConfig, + MultiVAELossWeights, + TranslationLossWeights, + UnetParams, +) + + +# TODO Complexify the scales used for the inference +def get_scales(): + """ + Read Scales for Inference + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset_dir = "/work/CESBIO/projects/MAESTRIA/training_dataset2/tftwmwc3" + tile = "T31TEM" + patch_size = "256x256" + roi = 1 + stats = MMDCDataStats( + torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s2_{patch_size}_{roi}.pth"), + torch.load(f"{dataset_dir}/{tile}/{tile}_stats_s1_{patch_size}_{roi}.pth"), + torch.load( + f"{dataset_dir}/{tile}/{tile}_stats_worldclim_{patch_size}_{roi}.pth" + ), + torch.load(f"{dataset_dir}/{tile}/{tile}_stats_srtm_{patch_size}_{roi}.pth"), + ) + scale_regul = torch.nn.Threshold(1e-10, 1.0) + shift_scale_s2 = ShiftScale( + stats.sen2.median.to(device), + scale_regul((stats.sen2.qmax - stats.sen2.qmin) / 2.0).to(device), + ) + shift_scale_s1 = ShiftScale( + stats.sen1.median.to(device), + scale_regul((stats.sen1.qmax - stats.sen1.qmin) / 2.0).to(device), + ) + shift_scale_wc = ShiftScale( + stats.worldclim.median.to(device), + scale_regul((stats.worldclim.qmax - stats.worldclim.qmin) / 2.0).to(device), + ) + shift_scale_srtm = ShiftScale( + stats.srtm.median.to(device), + scale_regul((stats.srtm.qmax - stats.srtm.qmin) / 2.0).to(device), + ) + + return MMDCShiftScales( + shift_scale_s2, shift_scale_s1, shift_scale_wc, shift_scale_srtm + ) + + +# DONE add pydantic validation +def parse_from_yaml(path_to_yaml: str): + """ + Read a yaml file and return a dict + https://betterprogramming.pub/validating-yaml-configs-made-easy-with-pydantic-594522612db5 + """ + with open(path_to_yaml) as f: + config = yaml.safe_load(f) + + return config + + +def parse_multivaeconfig_to_pydantic(cfg: dict): + """ + Convert a MultiVAEConfig to a pydantic version + for validate the values + :args: cfg : configuration as dict + """ + # parse the field of 
the class to a dict + multivae_fields = { + _field.name: (_field.type, ...) for _field in fields(MultiVAEConfig) + } + # create pydantic model from a dataclass + pydantic_multivaefields = pydantic.create_model("MultiVAEConfig", **multivae_fields) + # apply the convertion + multivaeconfig_pydantic = pydantic_multivaefields(**cfg["model"]["config"]) + + return multivaeconfig_pydantic + + +# DONE add pydantic validation +def parse_from_cfg(path_to_cfg: str): + """ + Read a cfg as config file validate the values + and return a dict + """ + with open(path_to_cfg, encoding="UTF-8") as i2_config: + cfg = Config(i2_config) + cfg_dict = cfg.as_dict() + + return cfg_dict + + +def get_mmdc_full_config(path_to_cfg: str, inference_tile: bool) -> MultiVAEConfig: + """ " + Get the parameters for instantiate the mmdc_full module + based on the target inference mode + """ + + # read the yaml hydra config + if inference_tile: + cfg = parse_from_yaml(path_to_cfg) # "configs/model/mmdc_full.yaml") + cfg = parse_multivaeconfig_to_pydantic(cfg) + + else: + cfg = parse_from_cfg(path_to_cfg) + cfg = parse_multivaeconfig_to_pydantic(cfg) + + mmdc_full_config = MultiVAEConfig( + data_sizes=MMDCDataChannels( + sen1=cfg.data_sizes.sen1, + s1_angles=cfg.data_sizes.s1_angles, + sen2=cfg.data_sizes.sen2, + s2_angles=cfg.data_sizes.s2_angles, + srtm=cfg.data_sizes.srtm, + w_c=cfg.data_sizes.w_c, + ), + embeddings=ModularEmbeddingConfig( + s1_angles=MLPConfig( + hidden_layers=cfg.embeddings.s1_angles.hidden_layers, + out_channels=cfg.embeddings.s1_angles.out_channels, + ), + s1_srtm=UnetParams( + out_channels=cfg.embeddings.s1_srtm.out_channels, + encoder_sizes=cfg.embeddings.s1_srtm.encoder_sizes, + kernel_size=cfg.embeddings.s1_srtm.kernel_size, + tail_layers=cfg.embeddings.s1_srtm.tail_layers, + ), + s2_angles=ConvnetParams( + out_channels=cfg.embeddings.s2_angles.out_channels, + sizes=cfg.embeddings.s2_angles.sizes, + kernel_sizes=cfg.embeddings.s2_angles.kernel_sizes, + ), + s2_srtm=UnetParams( + out_channels=cfg.embeddings.s2_srtm.out_channels, + encoder_sizes=cfg.embeddings.s2_srtm.encoder_sizes, + kernel_size=cfg.embeddings.s2_srtm.kernel_size, + tail_layers=cfg.embeddings.s2_srtm.tail_layers, + ), + w_c=ConvnetParams( + out_channels=cfg.embeddings.w_c.out_channels, + sizes=cfg.embeddings.w_c.sizes, + kernel_sizes=cfg.embeddings.w_c.kernel_sizes, + ), + ), + s1_encoder=UnetParams( + out_channels=cfg.s1_encoder.out_channels, + encoder_sizes=cfg.s1_encoder.encoder_sizes, + kernel_size=cfg.s1_encoder.kernel_size, + tail_layers=cfg.s1_encoder.tail_layers, + ), + s2_encoder=UnetParams( + out_channels=cfg.s2_encoder.out_channels, + encoder_sizes=cfg.s2_encoder.encoder_sizes, + kernel_size=cfg.s2_encoder.kernel_size, + tail_layers=cfg.s2_encoder.tail_layers, + ), + s1_decoder=ConvnetParams( + out_channels=cfg.s1_decoder.out_channels, + sizes=cfg.s1_decoder.sizes, + kernel_sizes=cfg.s1_decoder.kernel_sizes, + ), + s2_decoder=ConvnetParams( + out_channels=cfg.s2_decoder.out_channels, + sizes=cfg.s2_decoder.sizes, + kernel_sizes=cfg.s2_decoder.kernel_sizes, + ), + s1_enc_use=MMDCDataUse( + s1_angles=cfg.s1_enc_use.s1_angles, + s2_angles=cfg.s1_enc_use.s2_angles, + srtm=cfg.s1_enc_use.srtm, + w_c=cfg.s1_enc_use.w_c, + ), + s2_enc_use=MMDCDataUse( + s1_angles=cfg.s2_enc_use.s1_angles, + s2_angles=cfg.s2_enc_use.s2_angles, + srtm=cfg.s2_enc_use.srtm, + w_c=cfg.s2_enc_use.w_c, + ), + s1_dec_use=MMDCDataUse( + s1_angles=cfg.s1_dec_use.s1_angles, + s2_angles=cfg.s1_dec_use.s2_angles, + srtm=cfg.s1_dec_use.srtm, + 
w_c=cfg.s1_dec_use.w_c, + ), + s2_dec_use=MMDCDataUse( + s1_angles=cfg.s2_dec_use.s1_angles, + s2_angles=cfg.s2_dec_use.s2_angles, + srtm=cfg.s2_dec_use.srtm, + w_c=cfg.s2_dec_use.w_c, + ), + loss_weights=MultiVAELossWeights( + sen1=TranslationLossWeights( + nll=cfg.loss_weights.sen1.nll, + lpips=cfg.loss_weights.sen1.lpips, + sam=cfg.loss_weights.sen1.sam, + gradients=cfg.loss_weights.sen1.gradients, + ), + sen2=TranslationLossWeights( + nll=cfg.loss_weights.sen2.nll, + lpips=cfg.loss_weights.sen2.lpips, + sam=cfg.loss_weights.sen2.sam, + gradients=cfg.loss_weights.sen2.gradients, + ), + s1_s2=TranslationLossWeights( + nll=cfg.loss_weights.s1_s2.nll, + lpips=cfg.loss_weights.s1_s2.lpips, + sam=cfg.loss_weights.s1_s2.sam, + gradients=cfg.loss_weights.s1_s2.gradients, + ), + s2_s1=TranslationLossWeights( + nll=cfg.loss_weights.s2_s1.nll, + lpips=cfg.loss_weights.s2_s1.lpips, + sam=cfg.loss_weights.s2_s1.sam, + gradients=cfg.loss_weights.s2_s1.gradients, + ), + forward=cfg.loss_weights.forward, + cross=cfg.loss_weights.forward, + latent=cfg.loss_weights.latent, + ), + ) + return mmdc_full_config diff --git a/test/full_module_config.json b/test/full_module_config.json new file mode 100644 index 0000000000000000000000000000000000000000..db4ed17229fcb4b9e5c9a88aa23e791825ea80dd --- /dev/null +++ b/test/full_module_config.json @@ -0,0 +1,197 @@ +{ + "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullLitModule", + "lr": 0.0001, + "model": { + "_target_": "mmdc_singledate.models.mmdc_full_module.MMDCFullModule", + "config": { + "_target_": "mmdc_singledate.models.types.MultiVAEConfig", + "data_sizes": { + "_target_": "mmdc_singledate.models.types.MMDCDataChannels", + "sen1": 6, + "s1_angles": 6, + "sen2": 10, + "s2_angles": 6, + "srtm": 4, + "w_c": 103 + }, + "embeddings": { + "_target_": "mmdc_singledate.models.types.ModularEmbeddingConfig", + "s1_angles": { + "_target_": "mmdc_singledate.models.types.MLPConfig", + "hidden_layers": [ + 9, + 12, + 9 + ], + "out_channels": 3 + }, + "s1_srtm": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 32, + 64, + 128 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "s2_angles": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 3, + "sizes": [ + 16, + 8 + ], + "kernel_sizes": [ + 1, + 1, + 1 + ] + }, + "s2_srtm": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 32, + 64, + 128 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "w_c": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 3, + "sizes": [ + 64, + 32, + 16 + ], + "kernel_sizes": [ + 1, + 1, + 1, + 1 + ] + } + }, + "s1_encoder": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 64, + 128, + 256, + 512 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "s2_encoder": { + "_target_": "mmdc_singledate.models.types.UnetParams", + "out_channels": 4, + "encoder_sizes": [ + 64, + 128, + 256, + 512 + ], + "kernel_size": 3, + "tail_layers": 3 + }, + "s1_decoder": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 0, + "sizes": [ + 64, + 32, + 16 + ], + "kernel_sizes": [ + 3, + 3, + 3, + 3 + ] + }, + "s2_decoder": { + "_target_": "mmdc_singledate.models.types.ConvnetParams", + "out_channels": 0, + "sizes": [ + 64, + 32, + 16 + ], + "kernel_sizes": [ + 3, + 3, + 3, + 3 + ] + }, + "s1_enc_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + 
"s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "s2_enc_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "s1_dec_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "s2_dec_use": { + "_target_": "mmdc_singledate.models.types.MMDCDataUse", + "s1_angles": true, + "s2_angles": true, + "srtm": true, + "w_c": true + }, + "loss_weights": { + "_target_": "mmdc_singledate.models.types.MultiVAELossWeights", + "sen1": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "sen2": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "s1_s2": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "s2_s1": { + "_target_": "mmdc_singledate.models.types.TranslationLossWeights", + "nll": 1, + "lpips": 1, + "sam": 1, + "gradients": 1 + }, + "forward": 1, + "cross": 1, + "latent": 1 + } + } + } +} diff --git a/test/test_mmdc_inference.py b/test/test_mmdc_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8a5f4fcb26dd62acc4162b8e3a9f0dc045c845 --- /dev/null +++ b/test/test_mmdc_inference.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +# Copyright: (c) 2023 CESBIO / Centre National d'Etudes Spatiales + +import os +from datetime import datetime +from pathlib import Path + +import pandas as pd +import pytest +import rasterio as rio +import torch + +from mmdc_singledate.inference.components.inference_components import ( + get_mmdc_full_model, + predict_mmdc_model, +) +from mmdc_singledate.inference.components.inference_utils import ( + GeoTiffDataset, + generate_chunks, +) +from mmdc_singledate.inference.mmdc_tile_inference import ( # inference_dataframe, + MMDCProcess, + mmdc_tile_inference, + predict_single_date_tile, +) + +from .utils import get_scales, setup_data + +# from mmdc_singledate.models.types import VAELatentSpace + + +# usefull variables +dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" +s2_filename_variable = "SENTINEL2A_20180302-105023-464_L2A_T31TCJ_C_V2-2_roi_0.tif" +s1_asc_filename_variable = "S1A_IW_GRDH_1SDV_20180303T174722_20180303T174747_020854_023C45_4525_36.09887223808782_15.429479982324139_roi_0.tif" +s1_desc_filename_variable = "S1A_IW_GRDH_1SDV_20180302T060027_20180302T060052_020832_023B8F_C485_40.83835645911643_165.05888005216622_roi_0.tif" +srtm_filename_variable = "srtm_T31TCJ_roi_0.tif" +wc_filename_variable = "wc_clim_1_T31TCJ_roi_0.tif" + + +sensors_test = [ + (["S2L2A", "S1FULL"]), + (["S2L2A", "S1ASC"]), + (["S2L2A", "S1DESC"]), + (["S2L2A"]), + (["S1FULL"]), + (["S1ASC"]), + (["S1DESC"]), +] + + +datasets = [ + ( + s2_filename_variable, + s1_asc_filename_variable, + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, + ), + ( + "", + s1_asc_filename_variable, + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, + ), + ( + s2_filename_variable, + "", + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, + ), + ( + s2_filename_variable, + s1_asc_filename_variable, + "", + srtm_filename_variable, + wc_filename_variable, + ), +] + + +# @pytest.mark.skip(reason="Check Time Serie generation") +def 
test_GeoTiffDataset(): + """ """ + input_data = GeoTiffDataset( + s2_filename=os.path.join(dataset_dir, s2_filename_variable), + s1_asc_filename=os.path.join( + dataset_dir, + s1_asc_filename_variable, + ), + s1_desc_filename=os.path.join( + dataset_dir, + s1_desc_filename_variable, + ), + srtm_filename=os.path.join(dataset_dir, srtm_filename_variable), + wc_filename=os.path.join(dataset_dir, wc_filename_variable), + ) + # print GetTiffDataset + print(input_data) + # Check Data Existence + assert input_data.s1_asc_availability == True + assert input_data.s1_desc_availability == True + assert input_data.s2_availabitity == True + + +# @pytest.mark.skip(reason="Check Time Serie generation") +def test_generate_chunks(): + """ + test chunk functionality + """ + chunks = generate_chunks(width=10980, height=10980, nlines=1024) + + assert type(chunks) == list + assert chunks[0] == (0, 0, 10980, 1024) + + +@pytest.mark.skip(reason="Check Time Serie generation") +@pytest.mark.parametrize("sensors", sensors_test) +def test_predict_mmdc_model(sensors): + """ """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # checkpoints + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints" + checkpoint_filename = "last.ckpt" + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename) + ) + ( + s2_x, + s2_m, + s2_a, + s1_x, + s1_vm, + s1_a_asc, + s1_a_desc, + srtm_x, + wc_x, + device, + ) = setup_data() + scales = get_scales() + mmdc_full_model.set_scales(scales) + # move to device + mmdc_full_model.to(device) + + pred = predict_mmdc_model( + mmdc_full_model, + sensors, + s2_x, + s2_m, + s2_a, + s1_x, + s1_vm, + s1_a_asc, + s1_a_desc, + wc_x, + srtm_x, + ) + print(pred.shape) + + assert type(pred) == torch.Tensor + + +# TODO add cases with no data +# @pytest.mark.skip(reason="Dissable for faster testing") +@pytest.mark.parametrize("sensors", sensors_test) +def test_predict_single_date_tile(sensors): + """ """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif" + input_data = GeoTiffDataset( + s2_filename=os.path.join(dataset_dir, s2_filename_variable), + s1_asc_filename=os.path.join( + dataset_dir, + s1_asc_filename_variable, + ), + s1_desc_filename=os.path.join( + dataset_dir, + s1_desc_filename_variable, + ), + srtm_filename=os.path.join(dataset_dir, srtm_filename_variable), + wc_filename=os.path.join(dataset_dir, wc_filename_variable), + ) + + nb_bands = len(sensors) * 8 + print("nb_bands :", nb_bands) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # checkpoints + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints" + checkpoint_filename = "last.ckpt" # "last.ckpt" + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename), + mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.json", + inference_tile=True, + ) + scales = get_scales() + mmdc_full_model.set_scales(scales) + # move to device + mmdc_full_model.to(device) + assert mmdc_full_model.training == False + + # init the process + process = MMDCProcess( + count=nb_bands, # 4, + nb_lines=256, + patch_size=128, + model=mmdc_full_model, + process=predict_mmdc_model, + ) + + # entry point pred 
function + predict_single_date_tile( + input_data=input_data, + export_path=export_path, + sensors=sensors, + process=process, + ) + + assert Path(export_path).exists() == True + + with rio.open(export_path) as predited_raster: + predicted_metadata = predited_raster.meta.copy() + assert predicted_metadata["count"] == nb_bands + + +@pytest.mark.parametrize( + "s2_filename_variable, s1_asc_filename_variable, s1_desc_filename_variable, srtm_filename_variable, wc_filename_variable", + datasets, +) +def test_predict_single_date_tile_no_data( + s2_filename_variable, + s1_asc_filename_variable, + s1_desc_filename_variable, + srtm_filename_variable, + wc_filename_variable, +): + """ """ + sensors = ["S2L2A", "S1ASC"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + export_path = f"/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/test_latent_singledate_with_model_{'_'.join(sensors)}.tif" + dataset_dir = "/work/CESBIO/projects/MAESTRIA/test_onetile/total/T31TCJ/" + input_data = GeoTiffDataset( + s2_filename=os.path.join(dataset_dir, s2_filename_variable), + s1_asc_filename=os.path.join( + dataset_dir, + s1_asc_filename_variable, + ), + s1_desc_filename=os.path.join( + dataset_dir, + s1_desc_filename_variable, + ), + srtm_filename=os.path.join(dataset_dir, srtm_filename_variable), + wc_filename=os.path.join(dataset_dir, wc_filename_variable), + ) + + nb_bands = len(sensors) * 8 + print("nb_bands :", nb_bands) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # checkpoints + checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints" + checkpoint_filename = "last.ckpt" # "last.ckpt" + + mmdc_full_model = get_mmdc_full_model( + os.path.join(checkpoint_path, checkpoint_filename), + mmdc_full_config_path="/home/uz/vinascj/src/MMDC/mmdc-singledate/test/full_module_config.json", + inference_tile=True, + ) + # move to device + scales = get_scales() + mmdc_full_model.set_scales(scales) + mmdc_full_model.to(device) + assert mmdc_full_model.training == False + + # init the process + process = MMDCProcess( + count=nb_bands, # 4, + nb_lines=256, + patch_size=128, + model=mmdc_full_model, + process=predict_mmdc_model, + ) + + # entry point pred function + predict_single_date_tile( + input_data=input_data, + export_path=export_path, + sensors=sensors, + process=process, + ) + + assert Path(export_path).exists() == True + + with rio.open(export_path) as predited_raster: + predicted_metadata = predited_raster.meta.copy() + assert predicted_metadata["count"] == nb_bands + + +# @pytest.mark.skip(reason="Check Time Serie generation") +def test_mmdc_tile_inference(): + """ + Test the inference code in a tile + """ + # feed parameters + input_path = Path("/work/CESBIO/projects/MAESTRIA/training_dataset2/") + export_path = Path("/work/CESBIO/projects/MAESTRIA/test_onetile/total/export/") + model_checkpoint_path = "/home/uz/vinascj/scratch/MMDC/results/latent/logs/experiments/runs/mmdc_full/2023-03-16_14-00-10/checkpoints/last.ckpt" + model_config_path = ( + "/home/uz/vinascj/src/MMDC/mmdc-singledate/configs/model/mmdc_full.yaml" + ) + input_tile_list = ["/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL"] + + # dataframe with input data + tile_df = pd.DataFrame( + { + "date": [ + datetime.strptime("2018-01-11", "%Y-%m-%d"), + datetime.strptime("2018-01-14", "%Y-%m-%d"), + datetime.strptime("2018-01-19", "%Y-%m-%d"), + datetime.strptime("2018-01-21", "%Y-%m-%d"), + 
+                datetime.strptime("2018-01-24", "%Y-%m-%d"),
+            ],
+            "patch_s2": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180111-100351-463_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180114-101347-463_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2A_20180119-101331-457_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180121-100542-770_L2A_T33TUL_D_V1-4_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/SENTINEL2B_20180124-101352-753_L2A_T33TUL_D_V1-4_roi_0.tif",
+            ],
+            "patchasc_s1": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180108T165831_20180108T165856_020066_022322_3008_36.29396013100586_14.593286735562288_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180114T165747_20180114T165812_009170_0106A2_AFD8_36.27898252884488_14.533603849909781_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180120T165830_20180120T165855_020241_0228B0_95C7_36.29306272148375_14.591766979064303_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+            ],
+            "patchdesc_s1": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/nan",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1B_IW_GRDH_1SDV_20180113T051001_20180113T051026_009148_0105F4_218B_45.70533932879138_164.47774901324672_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180119T051046_20180119T051111_020219_0227FF_F7E9_45.71131905091333_164.664087191203_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/S1A_IW_GRDH_1SDV_20180124T051848_20180124T051913_020292_022A63_177C_35.81756575888757_165.03806323648553_roi_0.tif",
+            ],
+            "Tuiles": ["T33TUL", "T33TUL", "T33TUL", "T33TUL", "T33TUL"],
+            "patch": [0, 0, 0, 0, 0],
+            "roi": [0, 0, 0, 0, 0],
+            "worldclim_filename": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/wc_clim_1_T33TUL_roi_0.tif",
+            ],
+            "srtm_filename": [
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+                "/work/CESBIO/projects/MAESTRIA/training_dataset2/T33TUL/srtm_T33TUL_roi_0.tif",
+            ],
+            "patch_s2_availability": [True, True, True, True, True],
+            "patchasc_s1_availability": [True, True, False, True, False],
+            "patchdesc_s1_availability": [True, True, False, True, False],
+        }
+    )
+    # inference_dataframe(
+    #     samples_dir=input_path,
+    #     input_tile_list=input_tile_list,
+    #     days_gap=5,
+    # )
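+    # Dates whose S1 availability flags are both False are expected to be
+    # encoded from S2 alone (8 latent bands), while the other dates should
+    # produce the combined S2 + S1 product (16 bands); see the assertions on
+    # the exported rasters below.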
+    # just inspect the availability flags of the test dataframe
+    print(
+        tile_df.loc[
+            :,
+            [
+                "patch_s2_availability",
+                "patchasc_s1_availability",
+                "patchdesc_s1_availability",
+            ],
+        ]
+    )
+
+    # apply the prediction
+    mmdc_tile_inference(
+        inference_dataframe=tile_df,
+        export_path=export_path,
+        model_checkpoint_path=model_checkpoint_path,
+        model_config_path=model_config_path,
+        patch_size=128,
+        nb_lines=256,
+    )
+
+    assert Path(export_path / "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-14_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-19_S2L2A.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-21_S2L2A_S1FULL.tif").exists()
+    assert Path(export_path / "latent_tile_infer_2018-01-24_S2L2A.tif").exists()
+
+    with rio.open(
+        os.path.join(export_path, "latent_tile_infer_2018-01-19_S2L2A.tif")
+    ) as raster:
+        metadata = raster.meta.copy()
+        assert metadata["count"] == 8
+
+    with rio.open(
+        os.path.join(export_path, "latent_tile_infer_2018-01-11_S2L2A_S1FULL.tif")
+    ) as raster:
+        metadata = raster.meta.copy()
+        assert metadata["count"] == 16
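+
+
+# Optional guard (illustrative sketch, not applied to the tests above): every
+# test in this module reads data and checkpoints from hard-coded CNES cluster
+# paths, so a marker along these lines could be used to skip them on machines
+# where that storage is not mounted.
+requires_cluster_data = pytest.mark.skipif(
+    not Path("/work/CESBIO/projects/MAESTRIA").exists(),
+    reason="MAESTRIA test data not available on this machine",
+)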