@@ -30,11 +30,22 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
3030 first dimension).
3131 """
3232
33-    contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
34-    indices_source = list(range(t.ndim - contracted_dims))
35-    indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
36-    transposed = t.movedim(indices_source, indices_dest)
37-    gramian = torch.tensordot(t, transposed, dims=contracted_dims)
33+ # Optimization: it's faster to do that than moving dims and using tensordot, and this case
34+ # happens very often, sometimes hundreds of times for a single jac_to_grad.
35+    if contracted_dims == -1:
36+        if t.ndim == 1:
37+            matrix = t.unsqueeze(1)
38+        else:
39+            matrix = t.flatten(start_dim=1)
40+
41+        gramian = matrix @ matrix.T
42+
43+    else:
44+        contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
45+        indices_source = list(range(t.ndim - contracted_dims))
46+        indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
47+        transposed = t.movedim(indices_source, indices_dest)
48+        gramian = torch.tensordot(t, transposed, dims=contracted_dims)
3849    return cast(PSDTensor, gramian)
3950
4051
0 commit comments