We should reimplement our jagged tensor reductions using `scatter_reduce_` rather than the current implementation, which performs a naive reduction with atomic operations.
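For reference, here is a minimal sketch of a per-segment jagged reduction built on PyTorch's `Tensor.scatter_reduce_`. It assumes a jagged tensor stored as a flat `[total, D]` data tensor plus an `offsets` tensor of segment boundaries. The function name `jagged_reduce` and this layout are hypothetical and are not the existing `JaggedReduce` interface.

```python
import torch


def jagged_reduce(data: torch.Tensor, offsets: torch.Tensor, reduce: str = "sum") -> torch.Tensor:
    """Reduce each variable-length segment of a jagged tensor to one row.

    Hypothetical layout (an assumption, not the project's JaggedTensor API):
      data:    [total, D] values of all segments concatenated
      offsets: [B + 1] int64, segment i is data[offsets[i]:offsets[i + 1]]
      reduce:  any scatter_reduce_ mode: "sum", "prod", "mean", "amax", "amin"
    """
    num_segments = offsets.numel() - 1
    lengths = offsets[1:] - offsets[:-1]
    # Segment id for every row, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1]
    seg_ids = torch.repeat_interleave(
        torch.arange(num_segments, device=data.device), lengths
    )
    # scatter_reduce_ requires an index with the same number of dims as src
    index = seg_ids.unsqueeze(1).expand_as(data)
    out = torch.zeros(num_segments, data.shape[1], dtype=data.dtype, device=data.device)
    # include_self=False keeps the zero initial values out of the reduction;
    # rows of empty segments receive no values and therefore stay zero
    out.scatter_reduce_(0, index, data, reduce=reduce, include_self=False)
    return out


data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
offsets = torch.tensor([0, 2, 5])
print(jagged_reduce(data, offsets, "sum"))
```

Passing `include_self=False` means a mode such as `"amax"` or `"mean"` sees only the segment's own values, so the zero-initialized output buffer does not bias the result.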
I'm approving the PR to get rid of the dependency, but as a follow-up, could you also change the `JaggedReduce` implementation to use `scatter_reduce_`?
Originally posted by @fwilliams in #571 (review)