[Mlir-commits] [mlir] [mlir][bufferization][scf] Implement BufferDeallocationOpInterface for scf.reduce.return (PR #66886)

Martin Erhart llvmlistbot at llvm.org
Wed Sep 20 03:37:50 PDT 2023


https://github.com/maerhart created https://github.com/llvm/llvm-project/pull/66886

This is necessary to run the new buffer deallocation pipeline as part of the sparse compiler pipeline.

From af08ee5e287da1810250da5e8980aee85e14ed8b Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Wed, 20 Sep 2023 10:33:55 +0000
Subject: [PATCH] [mlir][bufferization][scf] Implement
 BufferDeallocationOpInterface for scf.reduce.return

---
 .../BufferDeallocationOpInterfaceImpl.cpp     | 16 +++++++++++
 .../test/Dialect/SCF/buffer-deallocation.mlir | 28 +++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index 4ded8ba55013dc6..24fbc1dca83618c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -56,11 +56,27 @@ struct InParallelOpInterface
   }
 };
 
+struct ReduceReturnOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<
+          ReduceReturnOpInterface, scf::ReduceReturnOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    // Only non-memref reduction values are supported; reject memrefs.
+    Value reduced = cast<scf::ReduceReturnOp>(op).getOperand();
+    if (isa<BaseMemRefType>(reduced.getType()))
+      return op->emitError("only supported when operand is not a MemRef");
+    SmallVector<Value> ownership; // updated operand ownership (out)
+    return deallocation_impl::insertDeallocOpForReturnLike(state, op, {},
+                                                           ownership);
+  }
+};
+
 } // namespace
 
 void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
     InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+    ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx); // scf.reduce.return
   });
 }
diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
index 0847b1f1183f9f8..99cfed99c02d1a9 100644
--- a/mlir/test/Dialect/SCF/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
@@ -22,3 +22,31 @@ func.func @parallel_insert_slice(%arg0: index) {
 //       CHECK: }
 //       CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
 //   CHECK-NOT: retain
+
+// -----
+
+func.func @reduce(%buffer: memref<100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %zero = arith.constant 0.0 : f32
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) init (%zero) -> f32 {
+    %elem = memref.load %buffer[%i] : memref<100xf32>
+    scf.reduce(%elem) : f32 {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %scratch = memref.alloc() : memref<2xf32> // freed inside the region
+        memref.store %lhs, %scratch[%c0] : memref<2xf32>
+        memref.store %rhs, %scratch[%c1] : memref<2xf32>
+        %first = memref.load %scratch[%c0] : memref<2xf32>
+        %second = memref.load %scratch[%c1] : memref<2xf32>
+        %sum = arith.addf %first, %second : f32
+        scf.reduce.return %sum : f32
+    }
+  }
+  func.return
+}
+
+// CHECK-LABEL: func @reduce
+//       CHECK: scf.reduce
+//       CHECK:   [[ALLOC:%.+]] = memref.alloc(
+//       CHECK:   bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
+//       CHECK:   scf.reduce.return



More information about the Mlir-commits mailing list