[Mlir-commits] [mlir] 3f050f6 - [mlir][transform] Add multi-buffering to the transform dialect
Stella Stamenova
llvmlistbot at llvm.org
Wed Sep 28 14:30:43 PDT 2022
Author: Kirsten Lee
Date: 2022-09-28T14:30:02-07:00
New Revision: 3f050f6ac4626c9bf99bfdf5e4d7162ba7ad5cdc
URL: https://github.com/llvm/llvm-project/commit/3f050f6ac4626c9bf99bfdf5e4d7162ba7ad5cdc
DIFF: https://github.com/llvm/llvm-project/commit/3f050f6ac4626c9bf99bfdf5e4d7162ba7ad5cdc.diff
LOG: [mlir][transform] Add multi-buffering to the transform dialect
Add the plumbing necessary to call the memref dialect's multiBuffer
function. This will allow separation between choosing which buffers
to multi-buffer and the actual transform.
Alter the multibuffer function to return the newly created
allocation if multi-buffering succeeds. This is necessary to
communicate with the transform dialect hooks what allocation
multi-buffering created.
Reviewed By: ftynse, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D133985
Added:
mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/test/Dialect/MemRef/transform-ops.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/MemRef/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
index 9f57627c321fb..660deb21479d2 100644
--- a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
+add_subdirectory(TransformOps)
add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8dbe988023594
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS MemRefTransformOps.td)
+mlir_tablegen(MemRefTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(MemRefTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRMemRefTransformOpsIncGen)
+
+add_mlir_doc(MemRefTransformOps MemRefTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
new file mode 100644
index 0000000000000..dd33df99861c7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h
@@ -0,0 +1,33 @@
+//===- MemRefTransformOps.h - MemRef transformation ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace memref {
+class AllocOp;
+} // namespace memref
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
new file mode 100644
index 0000000000000..0a66f82f47e08
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -0,0 +1,52 @@
+//===- MemRefTransformOps.td - MemRef transformation ops --*- tablegen -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MEMREF_TRANSFORM_OPS
+#define MEMREF_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformEffects.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer", // spelled transform.memref.multibuffer in IR (see transform-ops.mlir)
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Multibuffers an allocation";
+ let description = [{
+ Transformation to do multi-buffering/array expansion to remove
+ dependencies on the temporary allocation between consecutive loop
+ iterations. This transform expands the size of an allocation by
+ a given multiplicative factor and fixes up any users of the
+ multibuffered allocation.
+
+ #### Return modes
+
+ This operation returns the new allocation if multi-buffering
+ succeeds, and failure otherwise.
+ }];
+
+ let arguments =
+ (ins PDL_Operation:$target, // handle to the payload to multi-buffer; applyToOne below takes it as memref::AllocOp
+ ConfinedAttr<I64Attr, [IntPositive]>:$factor); // 64-bit integer attribute constrained to be positive; forwarded to memref::multiBuffer
+
+ let results = (outs PDL_Operation:$transformed); // handle to the allocation created by multi-buffering
+
+ let assemblyFormat = "$target attr-dict"; // $factor is not named in the format, so it is written in the attr-dict, e.g. {factor = 2 : i64}
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ memref::AllocOp target,
+ ::llvm::SmallVector<::mlir::Operation *> &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+#endif // MEMREF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 3d518169ae231..2a7b5d82a4cdb 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -62,8 +62,8 @@ void populateSimplifyExtractStridedMetadataOpPatterns(
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
-/// It return success if the allocation was multi-buffered and returns failure()
-/// otherwise.
+/// It returns the new allocation if the original allocation was multi-buffered
+/// and returns failure() otherwise.
/// Example:
/// ```
/// %0 = memref.alloc() : memref<4x128xf32>
@@ -85,7 +85,8 @@ void populateSimplifyExtractStridedMetadataOpPatterns(
/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
/// }
/// ```
-LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
+FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
+ unsigned multiplier);
//===----------------------------------------------------------------------===//
// Passes
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c91eb56f44bf1..2cd15a655000e 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -41,6 +41,7 @@
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -112,6 +113,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
// Register all dialect extensions.
bufferization::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);
// Register all external models.
diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt
index 31167e6af908b..c47e4c5495c17 100644
--- a/mlir/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
+add_subdirectory(TransformOps)
add_subdirectory(Transforms)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..03de3afe39ca5
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRMemRefTransformOps
+ MemRefTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef/TransformOps
+
+ DEPENDS
+ MLIRMemRefTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAffineDialect
+ MLIRArithmeticDialect
+ MLIRIR
+ MLIRPDLDialect
+ MLIRMemRefDialect
+ MLIRMemRefTransforms
+ MLIRTransformDialect
+)
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
new file mode 100644
index 0000000000000..c91a3b9511a9f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -0,0 +1,69 @@
+//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// MemRefMultiBufferOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
+ SmallVector<Operation *> &results,
+ transform::TransformState &state) { // Multi-buffers `target` by this op's `factor`; appends the new alloc to `results` on success.
+ auto newBuffer = memref::multiBuffer(target, getFactor()); // FailureOr<memref::AllocOp>: holds the new allocation on success
+ if (failed(newBuffer)) { // the allocation was not multi-buffered
+ Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); // Note-severity diagnostic attached to the target's location
+ diag << "op failed to multibuffer";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); // reported as a silenceable (not definite) failure
+ }
+
+ results.push_back(newBuffer.value()); // the newly created allocation is the only result recorded for this target
+ return DiagnosedSilenceableFailure(success());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace { // file-local: the extension is only reachable through memref::registerTransformDialectExtension
+class MemRefTransformDialectExtension // Transform dialect extension that contributes the MemRef transform ops.
+ : public transform::TransformDialectExtension<
+ MemRefTransformDialectExtension> {
+public:
+ using Base::Base; // inherit the base class constructors
+
+ void init() { // declares the dialects this extension relies on, then registers its ops
+ declareDependentDialect<pdl::PDLDialect>(); // the op's operand/result are PDL_Operation handles
+ declareGeneratedDialect<AffineDialect>(); // NOTE(review): presumably because multi-buffering emits affine ops (the test expects affine.apply) - confirm
+ declareGeneratedDialect<arith::ArithmeticDialect>(); // NOTE(review): presumably for arith ops created by the rewrite - confirm
+
+ registerTransformOps< // template argument list is expanded from the TableGen-generated .cpp.inc under GET_OP_LIST
+#define GET_OP_LIST
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
+
+void mlir::memref::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<MemRefTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 354bc11a00c7b..75e3746dac8b7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -78,8 +78,8 @@ static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
// Returns the new allocation on success and failure otherwise.
// This is not a pattern as it requires propagating the new memref type to its
// uses and requires updating subview ops.
-LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
- unsigned multiplier) {
+FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp,
+ unsigned multiplier) {
DominanceInfo dom(allocOp->getParentOp());
LoopLikeOpInterface candidateLoop;
for (Operation *user : allocOp->getUsers()) {
@@ -142,5 +142,5 @@ LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
offsets, sizes, strides);
replaceUsesAndPropagateType(allocOp, subview, builder);
allocOp.erase();
- return success();
+ return newAlloc;
}
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
new file mode 100644
index 0000000000000..5b6f70c7be8ec
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer
+func.func @multi_buffer(%in: memref<16xf32>) {
+ // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+ // expected-remark @below {{transformed}}
+ %tmp = memref.alloc() : memref<4xf32>
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c16 = arith.constant 16 : index
+
+ // CHECK: scf.for %[[IV:.*]] = %[[C0]]
+ scf.for %i0 = %c0 to %c16 step %c4 {
+ // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]], %[[C0]], %[[C4]])
+ // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+ %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+ memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+ "some_use"(%tmp) : (memref<4xf32>) ->()
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.alloc"]} in %arg1
+ %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+ // Verify that the returned handle is usable.
+ transform.test_print_remark_at_operand %1, "transformed"
+}
More information about the Mlir-commits
mailing list