diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 0a60414fd0d4..7f4564879a12 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -220,6 +220,33 @@ struct AllocOpInterfaceReverse } }; +struct SubViewOpInterfaceReverse + : public ReverseAutoDiffOpInterface::ExternalModel< + SubViewOpInterfaceReverse, memref::SubViewOp> { + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const { + auto subviewOp = cast(op); + auto newSubviewOp = cast(gutils->getNewFromOriginal(op)); + if (gutils->hasInvertPointer(subviewOp.getSource())) { + Value shadow = builder.create( + op->getLoc(), newSubviewOp.getType(), + gutils->invertPointerM(subviewOp.getSource(), builder), + newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(), + newSubviewOp.getMixedStrides()); + gutils->mapShadowValue(subviewOp, shadow, builder); + } + } +}; + struct AllocOpInterface : public AutoDiffOpInterface::ExternalModel { @@ -268,5 +295,6 @@ void mlir::enzyme::registerMemRefDialectAutoDiffInterface( memref::LoadOp::attachInterface(*context); memref::StoreOp::attachInterface(*context); memref::AllocOp::attachInterface(*context); + memref::SubViewOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 2c0cabe3f035..100186241153 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -58,6 +58,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert();