[Mlir-commits] [mlir] [mlir][bufferization] Implement BufferDeallocationOpInterface for scf.forall.in_parallel (PR #66351)

Martin Erhart llvmlistbot at llvm.org
Thu Sep 14 07:19:06 PDT 2023


https://github.com/maerhart updated https://github.com/llvm/llvm-project/pull/66351:

>From 3db42e80236638d434a9c8f6391d49179b0c478f Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Tue, 12 Sep 2023 15:21:53 +0000
Subject: [PATCH] [mlir][bufferization] Implement BufferDeallocationOpInterface
 for scf.forall.in_parallel

The scf.forall.in_parallel terminator operation has a nested graph region with
the NoTerminator trait. Such regions are not supported by the default
implementations. Therefore, this commit adds a specialized implementation for
this operation which only covers the case where the nested region is empty.
This is sufficient because, after bufferization, ops like
tensor.parallel_insert_slice have already been converted to memref operations
that reside only in the scf.forall, so the nested region of
scf.forall.in_parallel ends up empty.
---
 .../BufferDeallocationOpInterfaceImpl.h       | 22 +++++
 mlir/include/mlir/InitAllDialects.h           |  3 +
 .../BufferDeallocationOpInterfaceImpl.cpp     | 87 +++++++++++++++++++
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../test/Dialect/SCF/buffer-deallocation.mlir | 24 +++++
 5 files changed, 137 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/SCF/buffer-deallocation.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..cbfb490c7ab098c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace scf {
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0182ab93929cb8c..5b2b1ed24d5173d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -60,6 +60,8 @@
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -149,6 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
+  scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   scf::registerValueBoundsOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..88cb3e9b097147f
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,87 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// The `scf.forall.in_parallel` terminator is special in a few ways:
+/// * It does not implement the BranchOpInterface or
+///   RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
+///   which is not supported by BufferDeallocation.
+/// * It has a graph-like region which only allows one specific tensor op
+/// * After bufferization the nested region is always empty
+/// For these reasons we provide custom deallocation logic via this external
+/// model.
+///
+/// Example:
+/// ```mlir
+/// scf.forall (%arg1) in (%arg0) {
+///   %alloc = memref.alloc() : memref<2xf32>
+///   ...
+///   <implicit in_parallel terminator here>
+/// }
+/// ```
+/// gets transformed to
+/// ```mlir
+/// scf.forall (%arg1) in (%arg0) {
+///   %alloc = memref.alloc() : memref<2xf32>
+///   ...
+///   bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
+///   <implicit in_parallel terminator here>
+/// }
+/// ```
+struct InParallelOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
+                                                          scf::InParallelOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    auto inParallelOp = cast<scf::InParallelOp>(op);
+    OpBuilder builder(op);
+    if (!inParallelOp.getBody()->empty())
+      return op->emitError("only supported when nested region is empty");
+
+    // Collect the values to deallocate and retain and use them to create the
+    // dealloc operation.
+    Block *block = op->getBlock();
+    SmallVector<Value> memrefs, conditions, toRetain;
+    if (failed(state.getMemrefsAndConditionsToDeallocate(
+            builder, op->getLoc(), block, memrefs, conditions)))
+      return failure();
+
+    state.getMemrefsToRetain(block, /*toBlock=*/nullptr, {}, toRetain);
+    if (memrefs.empty() && toRetain.empty())
+      return op;
+
+    auto deallocOp = builder.create<bufferization::DeallocOp>(
+        op->getLoc(), memrefs, conditions, toRetain);
+
+    // We want to replace the current ownership of the retained values with the
+    // result values of the dealloc operation as they are always unique.
+    state.resetOwnerships(deallocOp.getRetained(), block);
+    for (auto [retained, ownership] :
+         llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
+      state.updateOwnership(retained, ownership, block);
+
+    return op;
+  }
+};
+
+} // namespace
+
+void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
+    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 20abf2b583bbf70..fdaeb2fad9afa4f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRSCFTransforms
+  BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ForToWhile.cpp
diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
new file mode 100644
index 000000000000000..0847b1f1183f9f8
--- /dev/null
+++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN:   -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+
+func.func @parallel_insert_slice(%arg0: index) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<2xf32>
+  scf.forall (%arg1) in (%arg0) {
+    %alloc0 = memref.alloc() : memref<2xf32>
+    %0 = memref.load %alloc[%c0] : memref<2xf32>
+    linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
+  }
+  return
+}
+
+// CHECK-LABEL: func @parallel_insert_slice
+//  CHECK-SAME: (%arg0: index)
+//       CHECK: [[ALLOC0:%.+]] = memref.alloc(
+//       CHECK: scf.forall
+//       CHECK:   [[ALLOC1:%.+]] = memref.alloc(
+//       CHECK:   bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
+//   CHECK-NOT: retain
+//       CHECK: }
+//       CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
+//   CHECK-NOT: retain



More information about the Mlir-commits mailing list