In `convert_seq_to_patch_view`, the return value of `select_segments` is discarded, which suggests the scoring module may have no effect on the resulting patch mask.
```python
@staticmethod
def convert_seq_to_patch_view(
    mask: torch.Tensor,
    scores: torch.Tensor,
    patch_len: int = 8,
    stride: Optional[int] = None,
):
    """
    Input:
        mask : torch.Tensor of shape [batch_size x seq_len]
    Output
        mask : torch.Tensor of shape [batch_size x n_patches]
    """
    stride = patch_len if stride is None else stride
    # sm.forward(mask)
    if hasattr(scores, "shape"):
        select_segments(scores, patch_len, mask=mask)  # <- return value is discarded
    mask = mask.unfold(dimension=-1, size=patch_len, step=stride)
    # mask : [batch_size x n_patches x patch_len]
    return (mask.sum(dim=-1) == patch_len).long()
```
Source: `Samay/src/samay/models/lptm/model/masktrain.py`, lines 19 to 38 at commit `6a549ae`.
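To make the concern concrete, here is a minimal sketch that assumes `select_segments` returns a new mask instead of editing `mask` in place. `select_segments_stub` and its "keep only the highest-scoring patch" rule are hypothetical stand-ins, not the actual Samay behavior; `convert_seq_to_patch_view_sketch` and its `keep_selection` flag are likewise illustrative. With `keep_selection=False` the sketch mirrors the current code, and the scores never influence the output.

```python
from typing import Optional

import torch


def select_segments_stub(scores: torch.Tensor, patch_len: int, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for select_segments (assumed semantics).

    Returns a NEW mask in which, per sequence, only the non-overlapping patch
    with the highest mean score stays observed. The input mask is not modified.
    """
    batch_size, seq_len = mask.shape
    n_patches = seq_len // patch_len
    # Mean score of each non-overlapping patch: [batch_size, n_patches]
    patch_scores = (
        scores[:, : n_patches * patch_len]
        .reshape(batch_size, n_patches, patch_len)
        .mean(dim=-1)
    )
    best = patch_scores.argmax(dim=-1)  # index of the best patch per sequence
    new_mask = torch.zeros_like(mask)
    for b in range(batch_size):
        start = int(best[b]) * patch_len
        new_mask[b, start : start + patch_len] = mask[b, start : start + patch_len]
    return new_mask


def convert_seq_to_patch_view_sketch(
    mask: torch.Tensor,
    scores: torch.Tensor,
    patch_len: int = 8,
    stride: Optional[int] = None,
    keep_selection: bool = True,  # hypothetical flag: False reproduces the reported behavior
) -> torch.Tensor:
    stride = patch_len if stride is None else stride
    if hasattr(scores, "shape"):
        selected = select_segments_stub(scores, patch_len, mask=mask)
        if keep_selection:
            mask = selected  # use the returned mask instead of dropping it
    # [batch_size, n_patches, patch_len]
    mask = mask.unfold(dimension=-1, size=patch_len, step=stride)
    # A patch counts as observed only if every position in it is observed
    return (mask.sum(dim=-1) == patch_len).long()


mask = torch.ones(1, 16, dtype=torch.long)
scores = torch.tensor([[0.0] * 8 + [1.0] * 8])  # second patch scores higher

print(convert_seq_to_patch_view_sketch(mask, scores, keep_selection=False).tolist())  # -> [[1, 1]]
print(convert_seq_to_patch_view_sketch(mask, scores, keep_selection=True).tolist())   # -> [[0, 1]]
```

If `select_segments` does in fact modify `mask` in place, the current code works, but the unused return value is still misleading; assigning the result (or documenting the in-place contract) would make the intent explicit.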