For Triangle Multiplication (`def _to_jax(self) -> jax.lax.Precision:`), the JAX precision mapping is incorrect:
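`Precision.TF32x3` is mapped to `jax.lax.Precision.HIGHEST` rather than to `jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3`, so the matmul does not request the 3-pass TF32 algorithm that the setting names. I would recommend mapping each setting to its exact `DotAlgorithmPreset` instead.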