Flagging this for visibility:
The custom checkpoint helper in this repo re-runs the forward pass during backprop without restoring the RNG state. Every stochastic layer inside the checkpointed block (e.g. dropout) therefore sees a different random mask during recomputation, so the recomputed activations, and the gradients derived from them, no longer match the loss that was actually computed. As a result, enabling gradient checkpointing with non-zero dropout makes the loss diverge.
Code link: nn.py#L124
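Here is a minimal sketch of the failure mode, independent of this repo's helper (plain PyTorch, CPU RNG only): re-running a dropout layer without restoring the RNG state draws a different mask than the original forward pass, while restoring the saved state reproduces it exactly.

```python
import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.5)
x = torch.randn(8)

y_fwd = drop(x)          # mask used in the real forward pass
y_recomputed = drop(x)   # naive recomputation: fresh randomness, new mask
print(torch.equal(y_fwd, y_recomputed))  # almost surely False

# Saving and restoring the RNG state reproduces the exact mask:
state = torch.get_rng_state()
y1 = drop(x)
torch.set_rng_state(state)
y2 = drop(x)
print(torch.equal(y1, y2))  # True
```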
The Colab notebook below isolates the issue using only code from this repo:
Colab notebook
I wrote up more details after running into this while training a large model:
https://almutwakel.com/blog/divergence
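For reference, a minimal sketch of one possible fix, assuming a `torch.autograd.Function`-style helper (the name `RNGSafeCheckpoint` is hypothetical, not the repo's code): capture the RNG state in `forward` and restore it around the recomputation in `backward`, which is what PyTorch's own `torch.utils.checkpoint` does with `preserve_rng_state=True` (its default). This handles the CPU RNG only for brevity; a real fix would also save and restore the per-device CUDA RNG state.

```python
import torch

class RNGSafeCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_fn, x):
        ctx.run_fn = run_fn
        ctx.rng_state = torch.get_rng_state()  # snapshot RNG before the forward
        ctx.save_for_backward(x)
        with torch.no_grad():
            return run_fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Fork so restoring the old state doesn't disturb the global RNG,
        # then replay the forward pass under the saved state so dropout
        # applies the exact same mask as the original forward.
        with torch.random.fork_rng(devices=[]):
            torch.set_rng_state(ctx.rng_state)
            with torch.enable_grad():
                x = x.detach().requires_grad_(True)
                out = ctx.run_fn(x)
            (grad_in,) = torch.autograd.grad(out, x, grad_out)
        return None, grad_in
```

Usage would look like `out = RNGSafeCheckpoint.apply(block, x)`, where `block` is the module or function being checkpointed.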