[Mlir-commits] [mlir] [mlir][bufferization][scf] Implement BufferDeallocationOpInterface for scf.reduce.return (PR #66886)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 20 03:38:54 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This is necessary to run the new buffer deallocation pipeline as part of the sparse compiler pipeline.
---
Full diff: https://github.com/llvm/llvm-project/pull/66886.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp (+16)
- (modified) mlir/test/Dialect/SCF/buffer-deallocation.mlir (+28)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index 4ded8ba55013dc6..24fbc1dca83618c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -56,11 +56,27 @@ struct InParallelOpInterface
}
};
+/// BufferDeallocationOpInterface external model for `scf.reduce.return`
+/// (scf::ReduceReturnOp). Memref-typed reduction operands are rejected;
+/// otherwise the op is handled by the generic return-like dealloc helper.
+struct ReduceReturnOpInterface
+ : public BufferDeallocationOpInterface::ExternalModel<
+ ReduceReturnOpInterface, scf::ReduceReturnOp> {
+ /// Processes one `scf.reduce.return` op. Emits an error (and fails) if the
+ /// returned operand is a memref; otherwise returns the result of
+ /// deallocation_impl::insertDeallocOpForReturnLike. Per the accompanying
+ /// test, this yields a bufferization.dealloc placed before the terminator.
+ /// `options` is not used by this model.
+ FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+ const DeallocationOptions &options) const {
+ auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
+ // Memref-typed reduction results are not supported by this model.
+ if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
+ return op->emitError("only supported when operand is not a MemRef");
+
+ // NOTE(review): `{}` is presumably the (empty) list of extra values to
+ // retain, and updatedOperandOwnership an out-parameter that is not read
+ // afterwards here. Confirm against insertDeallocOpForReturnLike.
+ SmallVector<Value> updatedOperandOwnership;
+ return deallocation_impl::insertDeallocOpForReturnLike(
+ state, op, {}, updatedOperandOwnership);
+ }
+};
+
} // namespace
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+ ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
index 0847b1f1183f9f8..99cfed99c02d1a9 100644
--- a/mlir/test/Dialect/SCF/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
@@ -22,3 +22,31 @@ func.func @parallel_insert_slice(%arg0: index) {
// CHECK: }
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
// CHECK-NOT: retain
+
+// -----
+
+// A buffer is allocated inside the scf.reduce region, and the region's
+// scf.reduce.return yields a plain f32 (not a memref). The deallocation
+// pipeline is expected to free that region-local buffer with a
+// bufferization.dealloc placed before the terminator.
+func.func @reduce(%buffer: memref<100xf32>) {
+ %init = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 {
+ %elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32>
+ scf.reduce(%elem_to_reduce) : f32 {
+ ^bb0(%lhs : f32, %rhs: f32):
+ // Region-local buffer: only stored to and loaded from in this block.
+ %alloc = memref.alloc() : memref<2xf32>
+ memref.store %lhs, %alloc [%c0] : memref<2xf32>
+ memref.store %rhs, %alloc [%c1] : memref<2xf32>
+ %0 = memref.load %alloc[%c0] : memref<2xf32>
+ %1 = memref.load %alloc[%c1] : memref<2xf32>
+ %res = arith.addf %0, %1 : f32
+ scf.reduce.return %res : f32
+ }
+ }
+ func.return
+}
+
+// The alloc must be followed by an unconditional dealloc (condition %true)
+// ahead of the reduction terminator.
+// CHECK-LABEL: func @reduce
+// CHECK: scf.reduce
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
+// CHECK: scf.reduce.return
``````````
</details>
https://github.com/llvm/llvm-project/pull/66886
More information about the Mlir-commits mailing list