[Mlir-commits] [mlir] [mlir][bufferization][scf] Implement BufferDeallocationOpInterface for scf.reduce.return (PR #66886)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 20 03:38:54 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

Implementing this interface for `scf.reduce.return` is necessary to run the new buffer deallocation pipeline as part of the sparse compiler pipeline.

---
Full diff: https://github.com/llvm/llvm-project/pull/66886.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp (+16) 
- (modified) mlir/test/Dialect/SCF/buffer-deallocation.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index 4ded8ba55013dc6..24fbc1dca83618c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -56,11 +56,27 @@ struct InParallelOpInterface
   }
 };
 
+struct ReduceReturnOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<
+          ReduceReturnOpInterface, scf::ReduceReturnOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
+    if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
+      return op->emitError("only supported when operand is not a MemRef");
+
+    SmallVector<Value> updatedOperandOwnership;
+    return deallocation_impl::insertDeallocOpForReturnLike(
+        state, op, {}, updatedOperandOwnership);
+  }
+};
+
 } // namespace
 
 void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
     InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+    ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
index 0847b1f1183f9f8..99cfed99c02d1a9 100644
--- a/mlir/test/Dialect/SCF/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
@@ -22,3 +22,31 @@ func.func @parallel_insert_slice(%arg0: index) {
 //       CHECK: }
 //       CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
 //   CHECK-NOT: retain
+
+// -----
+
+func.func @reduce(%buffer: memref<100xf32>) {
+  %init = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 {
+    %elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32>
+    scf.reduce(%elem_to_reduce) : f32 {
+      ^bb0(%lhs : f32, %rhs: f32):
+        %alloc = memref.alloc() : memref<2xf32>
+        memref.store %lhs, %alloc [%c0] : memref<2xf32>
+        memref.store %rhs, %alloc [%c1] : memref<2xf32>
+        %0 = memref.load %alloc[%c0] : memref<2xf32>
+        %1 = memref.load %alloc[%c1] : memref<2xf32>
+        %res = arith.addf %0, %1 : f32
+        scf.reduce.return %res : f32
+    }
+  }
+  func.return
+}
+
+// CHECK-LABEL: func @reduce
+//       CHECK: scf.reduce
+//       CHECK:   [[ALLOC:%.+]] = memref.alloc(
+//       CHECK:   bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
+//       CHECK:   scf.reduce.return

``````````

</details>


https://github.com/llvm/llvm-project/pull/66886


More information about the Mlir-commits mailing list