[Mlir-commits] [mlir] 3f050f6 - [mlir][transform] Add multi-buffering to the transform dialect

Stella Stamenova llvmlistbot at llvm.org
Wed Sep 28 14:30:43 PDT 2022


Author: Kirsten Lee
Date: 2022-09-28T14:30:02-07:00
New Revision: 3f050f6ac4626c9bf99bfdf5e4d7162ba7ad5cdc

URL: https://github.com/llvm/llvm-project/commit/3f050f6ac4626c9bf99bfdf5e4d7162ba7ad5cdc
DIFF: https://github.com/llvm/llvm-project/commit/3f050f6ac4626c9bf99bfdf5e4d7162ba7ad5cdc.diff

LOG: [mlir][transform] Add multi-buffering to the transform dialect

Add the plumbing necessary to call the memref dialect's multiBuffer
function from the transform dialect. This separates the choice of which
buffers to multi-buffer from the transformation itself.

Change the multiBuffer function to return the newly created
allocation if multi-buffering succeeds. This is necessary so that
the transform dialect hooks can tell which allocation
multi-buffering created.

Reviewed By: ftynse, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133985

Added: 
    mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
    mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
    mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
    mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/test/Dialect/MemRef/transform-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/MemRef/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
index 9f57627c321fb..660deb21479d2 100644
--- a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8dbe988023594
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS MemRefTransformOps.td)
+mlir_tablegen(MemRefTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(MemRefTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRMemRefTransformOpsIncGen)
+
+add_mlir_doc(MemRefTransformOps MemRefTransformOps Dialects/ -gen-op-doc)

diff  --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
new file mode 100644
index 0000000000000..dd33df99861c7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
@@ -0,0 +1,33 @@
+//===- MemRefTransformOps.h - MemRef transformation ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace memref {
+// Forward declaration: the generated op classes included below refer to
+// memref::AllocOp (MemRefMultiBufferOp::applyToOne takes one as its target).
+class AllocOp;
+} // namespace memref
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Adds the MemRef extension of the Transform dialect, which provides the
+/// `transform.memref.*` ops (e.g. `transform.memref.multibuffer`), to
+/// `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H

diff  --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
new file mode 100644
index 0000000000000..0a66f82f47e08
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -0,0 +1,52 @@
+//===- MemRefTransformOps.td - MemRef transformation ops --*- tablegen -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MEMREF_TRANSFORM_OPS
+#define MEMREF_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Multibuffers an allocation";
+  let description = [{
+     Transformation to do multi-buffering/array expansion to remove
+     dependencies on the temporary allocation between consecutive loop
+     iterations. This transform expands the size of an allocation by
+     a given multiplicative factor and fixes up any users of the
+     multibuffered allocation.
+
+     The `target` handle is expected to point to `memref.alloc` operations.
+     `factor` is the strictly positive multiplicative factor, i.e. the number
+     of copies of the original buffer held by the new allocation.
+
+     #### Return modes
+
+     This operation returns a handle to the new allocation if
+     multi-buffering succeeds, and produces a silenceable failure otherwise.
+  }];
+
+  let arguments =
+      (ins PDL_Operation:$target,
+           ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        memref::AllocOp target,
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // MEMREF_TRANSFORM_OPS

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 3d518169ae231..2a7b5d82a4cdb 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -62,8 +62,8 @@ void populateSimplifyExtractStridedMetadataOpPatterns(
 
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
-/// It return success if the allocation was multi-buffered and returns failure()
-/// otherwise.
+/// It returns the new allocation if the original allocation was multi-buffered
+/// and returns failure() otherwise.
 /// Example:
 /// ```
 /// %0 = memref.alloc() : memref<4x128xf32>
@@ -85,7 +85,8 @@ void populateSimplifyExtractStridedMetadataOpPatterns(
 ///   "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
 /// }
 /// ```
-LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
+FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
+                                       unsigned multiplier);
 
 //===----------------------------------------------------------------------===//
 // Passes

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c91eb56f44bf1..2cd15a655000e 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -41,6 +41,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -112,6 +113,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   // Register all dialect extensions.
   bufferization::registerTransformDialectExtension(registry);
   linalg::registerTransformDialectExtension(registry);
+  memref::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
 
   // Register all external models.

diff  --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt
index 31167e6af908b..c47e4c5495c17 100644
--- a/mlir/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..03de3afe39ca5
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRMemRefTransformOps
+  MemRefTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef/TransformOps
+
+  DEPENDS
+  MLIRMemRefTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithmeticDialect
+  MLIRIR
+  MLIRPDLDialect
+  MLIRMemRefDialect
+  MLIRMemRefTransforms
+  MLIRTransformDialect
+)

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
new file mode 100644
index 0000000000000..c91a3b9511a9f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -0,0 +1,69 @@
+//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// MemRefMultiBufferOp
+//===----------------------------------------------------------------------===//
+
+/// Multi-buffers a single `memref.alloc` payload op by `getFactor()`.
+///
+/// On success, the newly created (larger) allocation is appended to `results`
+/// so that the `transformed` handle points to it. If `memref::multiBuffer`
+/// fails, a silenceable failure carrying a note at the target is returned.
+/// In that case `results` holds a single null entry, so the number of values
+/// produced still matches the op's single declared result.
+DiagnosedSilenceableFailure
+transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
+                                           SmallVector<Operation *> &results,
+                                           transform::TransformState &state) {
+  FailureOr<memref::AllocOp> newBuffer =
+      memref::multiBuffer(target, getFactor());
+  if (failed(newBuffer)) {
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
+    diag << "op failed to multibuffer";
+    // TransformEachOpTrait checks that every application yields as many
+    // results as the op declares (one here); a null handle satisfies that
+    // without turning this silenceable failure into a definite one.
+    results.assign(1, nullptr);
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+
+  results.push_back(*newBuffer);
+  return DiagnosedSilenceableFailure(success());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the MemRef transform ops (the
+/// ones in the generated `MemRefTransformOps.cpp.inc` op list).
+class MemRefTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          MemRefTransformDialectExtension> {
+public:
+  using Base::Base;
+
+  /// Declares the dialects this extension relies on and registers its ops.
+  void init() {
+    // The ops consume and produce `!pdl.operation` handles (PDL_Operation in
+    // the .td), presumably why PDL is declared as a dependent dialect.
+    declareDependentDialect<pdl::PDLDialect>();
+    // Dialects whose ops may be created in the payload IR by the transforms
+    // (multi-buffering emits e.g. `affine.apply` to compute the buffer index).
+    declareGeneratedDialect<AffineDialect>();
+    declareGeneratedDialect<arith::ArithmeticDialect>();
+
+    // Register every op listed by the generated GET_OP_LIST section.
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
+
+/// Makes the MemRef transform ops (e.g. `transform.memref.multibuffer`)
+/// available by adding the file-local extension class to `registry`.
+void mlir::memref::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  using Extension = MemRefTransformDialectExtension;
+  registry.addExtensions<Extension>();
+}

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 354bc11a00c7b..75e3746dac8b7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -78,8 +78,8 @@ static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
 // Returns success if the transformation happened and failure otherwise.
 // This is not a pattern as it requires propagating the new memref type to its
 // uses and requires updating subview ops.
-LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
-                                        unsigned multiplier) {
+FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp,
+                                                     unsigned multiplier) {
   DominanceInfo dom(allocOp->getParentOp());
   LoopLikeOpInterface candidateLoop;
   for (Operation *user : allocOp->getUsers()) {
@@ -142,5 +142,5 @@ LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
                                                     offsets, sizes, strides);
   replaceUsesAndPropagateType(allocOp, subview, builder);
   allocOp.erase();
-  return success();
+  return newAlloc;
 }

diff  --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
new file mode 100644
index 0000000000000..5b6f70c7be8ec
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer
+// Payload: every loop iteration copies a 4-element slice of %in into the same
+// temporary %tmp. After multi-buffering by 2, consecutive iterations should
+// use alternating rows of a 2x4 buffer instead.
+func.func @multi_buffer(%in: memref<16xf32>) {
+  // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+  // expected-remark @below {{transformed}}
+  %tmp = memref.alloc() : memref<4xf32>
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c16 = arith.constant 16 : index
+
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]]
+  scf.for %i0 = %c0 to %c16 step %c4 {
+    // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]], %[[C0]], %[[C4]])
+    // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+    // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+    memref.copy %1, %tmp :  memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+    "some_use"(%tmp) : (memref<4xf32>) ->()
+  }
+  return
+}
+
+// Multi-buffer every `memref.alloc` in the payload by a factor of 2, then
+// print a remark at the op the returned handle points to; the remark is
+// matched by the `expected-remark` in @multi_buffer.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.alloc"]} in %arg1
+  %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+  // Verify that the returned handle is usable.
+  transform.test_print_remark_at_operand %1, "transformed"
+}


        


More information about the Mlir-commits mailing list