Issue in the trim function for MonoModalSITS
The current `trim` function does not behave correctly, as far as I understand its intended behavior. Currently, we observe the following:
import torch
from einops import parse_shape, rearrange
from mtan_s1s2_classif.models.torch.time_series import MonoModalSITS
# Generate dummy data
torch.manual_seed(0)
nb_doy = 10
batch_size = 4
batch_size_half = batch_size // 2
w = h = 2
nb_channels = 2
doy = torch.rand((batch_size, nb_doy))
# doy, _ = torch.sort(doy, stable=True)
data = torch.rand((batch_size, nb_doy, nb_channels, w, h))
mask = torch.full((batch_size, nb_doy, w, h), False)
# mask = torch.cat([torch.rand(batch_size, nb_doy-3, w, h) > 0.5, torch.full((batch_size, 3, w, h), False)], dim=1)
# Fully masked parts
mask[:batch_size_half, 0::2, ...] = True
mask[batch_size_half:, 1::2, ...] = True
mono_sits = MonoModalSITS(data, doy, mask)
If we print the first sample, so that we can follow it through the different transformations:
print(mono_sits.doy[0, :])
print(mono_sits.mask[0, :, 0, 0])
: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964, 0.4556,
: 0.6323])
: tensor([ True, False, True, False, True, False, True, False, True, False])
Now we go through the trim line by line (see function).
First issue: unless we pass `stable=True` to `torch.sort`, the original order of dates that share the same mask value is not guaranteed to be preserved. With `stable=True`:
data_shape = parse_shape(mono_sits.data, "b t c h w")
doy_masked = torch.all(rearrange(mono_sits.mask, "b t h w -> b t (h w)"), dim=-1)
_, ind = torch.sort(doy_masked, dim=-1, descending=False, stable=True)
print(mono_sits.doy[0, ind[0, :]])
: tensor([0.7682, 0.1320, 0.6341, 0.8964, 0.6323, 0.4963, 0.0885, 0.3074, 0.4901,
: 0.4556])
The last output above is what the following should produce, if I have correctly understood the objective of `trim`:
trim_doy = torch.full_like(mono_sits.doy, 0)
trim_doy.scatter_(1, ind, mono_sits.doy)
print(trim_doy[0, :])
: tensor([0.6341, 0.4963, 0.4901, 0.7682, 0.8964, 0.0885, 0.4556, 0.1320, 0.6323,
: 0.3074])
To recover the same values, we should use `gather` rather than `scatter`:
trim_doy = torch.gather(mono_sits.doy, 1, ind)
print(trim_doy[0, :])
: tensor([0.7682, 0.1320, 0.6341, 0.8964, 0.6323, 0.4963, 0.0885, 0.3074, 0.4901,
: 0.4556])
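The difference between the two calls comes down to which side of the assignment the index is applied to. The minimal sketch below uses a toy 1 x 4 tensor (illustrative values, not the data of this issue) and casts the mask to integers before sorting, which is an assumption made here for portability. It illustrates that `scatter_` writes `out[i, ind[i, j]] = src[i, j]`, whereas `gather` reads `out[i, j] = src[i, ind[i, j]]`, so scattering with a sort index applies the inverse permutation.

```python
import torch

# Toy values (illustrative only, not the data of the issue)
src = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
# True marks a fully masked date
masked = torch.tensor([[True, False, False, True]])

# Stable sort: unmasked dates (0) come first, original order kept within each group
_, ind = torch.sort(masked.int(), dim=-1, descending=False, stable=True)

# gather reads: gathered[i, j] = src[i, ind[i, j]]
gathered = torch.gather(src, 1, ind)

# scatter_ writes: scattered[i, ind[i, j]] = src[i, j]
scattered = torch.zeros_like(src).scatter_(1, ind, src)

# Scattering with ind is equivalent to gathering with the inverse permutation
inverse_ind = torch.argsort(ind, dim=-1)
same_as_scatter = torch.gather(src, 1, inverse_ind)

print(ind)
print(gathered)   # unmasked dates first
print(scattered)  # values sent to the positions given by ind
```

The two results only coincide when the permutation is its own inverse, which is not the case for the sort index in this issue.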
I have pushed a possible solution here (WIP): https://src.koda.cnrs.fr/mmdc/mtan_s1s2_classif/-/tree/test_trim?ref_type=heads
Also, I suggest two possible modifications:
- Rename the function to `trim_` to match the PyTorch convention for in-place functions (although it is not truly an in-place function).
- Add an option `sorted=False|True` to preserve the order w.r.t. the doys. Currently, the trimmed SITS has all doys with mask=False at the beginning of the array, and the order of the data is changed w.r.t. the initial data.
The latter should be implemented after the trim operation. I will do it accordingly.
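One way such an option could work is sketched below. This is only an illustration under assumptions: `trim_order` and `apply_order` are hypothetical helpers, not part of `MonoModalSITS`, the parameter `sorted_by_doy` stands in for the proposed `sorted` option (renamed here to avoid shadowing the Python builtin), and the goal of the trim is assumed to be moving fully masked dates to the end of the time axis. To sort the remaining dates by doy, the doy of each masked date is replaced with `+inf` before sorting.

```python
import torch


def trim_order(doy, mask, sorted_by_doy=False):
    """Hypothetical helper: index that moves fully masked dates to the end.

    doy:  (b, t) acquisition dates
    mask: (b, t, h, w) boolean, True where a pixel is masked
    """
    # A date counts as masked only if all of its pixels are masked
    doy_masked = mask.flatten(start_dim=2).all(dim=-1)
    if sorted_by_doy:
        # Masked dates get +inf so they sort last; valid dates sort by doy
        key = doy.masked_fill(doy_masked, float("inf"))
    else:
        # Keep the original order inside each group (needs a stable sort)
        key = doy_masked.int()
    _, ind = torch.sort(key, dim=-1, descending=False, stable=True)
    return ind, doy_masked


def apply_order(sits_data, doy, mask, ind):
    """Reorder the time axis of data, doy and mask with gather."""
    b, t, c, h, w = sits_data.shape
    new_doy = torch.gather(doy, 1, ind)
    new_mask = torch.gather(mask, 1, ind[:, :, None, None].expand(b, t, h, w))
    new_data = torch.gather(
        sits_data, 1, ind[:, :, None, None, None].expand(b, t, c, h, w)
    )
    return new_data, new_doy, new_mask


# Toy example (illustrative values only)
doy = torch.tensor([[0.5, 0.8, 0.1, 0.3]])
mask = torch.zeros((1, 4, 2, 2), dtype=torch.bool)
mask[0, 0] = True  # first date fully masked
data = torch.arange(4.0).reshape(1, 4, 1, 1, 1).expand(1, 4, 2, 2, 2).contiguous()

ind_plain, _ = trim_order(doy, mask)
ind_sorted, _ = trim_order(doy, mask, sorted_by_doy=True)
_, doy_plain, _ = apply_order(data, doy, mask, ind_plain)
new_data, doy_sorted, new_mask = apply_order(data, doy, mask, ind_sorted)
print(doy_plain)   # valid dates first, in their original order
print(doy_sorted)  # valid dates first, in increasing doy
```

Computing a single index and applying it with `gather` to all three tensors keeps doy, mask and data aligned; cutting the trailing masked dates can then be done afterwards.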