[Mlir-commits] [mlir] ae1ea0b - [mlir] Decompose Bufferization Clone operation into Memref Alloc and Copy.
Julian Gross
llvmlistbot at llvm.org
Tue Nov 30 01:17:35 PST 2021
Author: Julian Gross
Date: 2021-11-30T10:15:56+01:00
New Revision: ae1ea0bead75f4c7a4c965dfa40b5f3b78b60364
URL: https://github.com/llvm/llvm-project/commit/ae1ea0bead75f4c7a4c965dfa40b5f3b78b60364
DIFF: https://github.com/llvm/llvm-project/commit/ae1ea0bead75f4c7a4c965dfa40b5f3b78b60364.diff
LOG: [mlir] Decompose Bufferization Clone operation into Memref Alloc and Copy.
This patch introduces a new conversion that lowers bufferization.clone operations
into a memref.alloc and a memref.copy operation. This transformation is needed to
lower any remaining clones that "survive" all previous transformations before
a given program is lowered further (e.g., to LLVM). Otherwise, these operations
can no longer be handled and lead to compile errors.
See: https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665
Differential Revision: https://reviews.llvm.org/D114233
Added:
mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
new file mode 100644
index 0000000000000..d8cb1524cf222
--- /dev/null
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -0,0 +1,26 @@
+//===- BufferizationToMemRef.h - Bufferization to MemRef conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+#define MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+/// Collect a set of patterns to convert memory-related operations from the
+/// Bufferization dialect to the MemRef dialect. The patterns are added to
+/// `patterns`.
+void populateBufferizationToMemRefConversionPatterns(
+ RewritePatternSet &patterns);
+
+/// Create a pass that converts operations from the Bufferization dialect to
+/// the MemRef dialect (registered as `convert-bufferization-to-memref`).
+std::unique_ptr<Pass> createBufferizationToMemRefPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 015023e083ec9..b39105aad3892 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -14,6 +14,7 @@
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8d806649d901a..31ef83579b2f9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -126,6 +126,17 @@ def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// BufferizationToMemRef
+//===----------------------------------------------------------------------===//
+
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
+  let summary = "Convert operations from the Bufferization dialect to the "
+                "MemRef dialect";
+  let description = [{
+    Lowers `bufferization.clone` operations on ranked memrefs to a
+    `memref.alloc` followed by a `memref.copy`. For every dynamic dimension,
+    the allocation size is queried from the clone source with `memref.dim`.
+
+    Clones of unranked memrefs are not supported; since all Bufferization
+    operations are marked illegal, such a clone makes the pass fail. The pass
+    does not insert deallocations for the buffers it allocates.
+  }];
+  let constructor = "mlir::createBufferizationToMemRefPass()";
+  let dependentDialects = ["arith::ArithmeticDialect", "memref::MemRefDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ComplexToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
new file mode 100644
index 0000000000000..66ed3474ff359
--- /dev/null
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -0,0 +1,91 @@
+//===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert Bufferization dialect to MemRef
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers `bufferization.clone` on ranked memrefs to a `memref.alloc` of the
+/// result type followed by a `memref.copy` from the clone source into the new
+/// buffer. For every dynamic dimension of the result type, the allocation size
+/// is obtained from the source with `memref.dim`. Clones of unranked memrefs
+/// are rejected with a match failure. No deallocation is emitted, so this
+/// conversion does not resolve memory leaks if it is used alone.
+struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
+  using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
+
+  /// Replaces `op` with an alloc + copy pair. Returns a match failure before
+  /// modifying anything if the result type is an unranked memref.
+  LogicalResult
+  matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check for unranked memref types which are currently not supported.
+    Type type = op.getType();
+    if (type.isa<UnrankedMemRefType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "UnrankedMemRefType is not supported.");
+    }
+
+    // Inside a conversion pattern the operands are taken from the adaptor so
+    // that values already remapped by the conversion are used.
+    Value source = adaptor.input();
+
+    // Collect one size operand per dynamic dimension, in dimension order, to
+    // be passed to memref.alloc.
+    MemRefType memrefType = type.cast<MemRefType>();
+    Location loc = op->getLoc();
+    SmallVector<Value, 4> dynamicOperands;
+    for (int64_t i = 0, rank = memrefType.getRank(); i < rank; ++i) {
+      if (!memrefType.isDynamicDim(i))
+        continue;
+      Value index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+      Value dim = rewriter.createOrFold<memref::DimOp>(loc, source, index);
+      dynamicOperands.push_back(dim);
+    }
+
+    // The alloc replaces the clone result; the copy then fills the fresh
+    // buffer with the contents of the source.
+    Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
+                                                               dynamicOperands);
+    rewriter.create<memref::CopyOp>(loc, source, alloc);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateBufferizationToMemRefConversionPatterns(
+    RewritePatternSet &patterns) {
+  // Register the clone lowering with the context the pattern set belongs to.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<CloneOpConversion>(context);
+}
+
+namespace {
+/// Pass that applies the Bufferization-to-MemRef conversion patterns to the
+/// operation it runs on.
+struct BufferizationToMemRefPass
+    : public ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
+  BufferizationToMemRefPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+
+    // MemRef ops and arith.constant are legal; every Bufferization op is
+    // marked illegal and therefore has to be rewritten.
+    ConversionTarget conversionTarget(context);
+    conversionTarget.addIllegalDialect<bufferization::BufferizationDialect>();
+    conversionTarget.addLegalOp<arith::ConstantOp>();
+    conversionTarget.addLegalDialect<memref::MemRefDialect>();
+
+    RewritePatternSet conversionPatterns(&context);
+    populateBufferizationToMemRefConversionPatterns(conversionPatterns);
+
+    // Signal a pass failure if the conversion does not succeed.
+    LogicalResult status = applyPartialConversion(
+        getOperation(), conversionTarget, std::move(conversionPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory for the Bufferization-to-MemRef conversion pass.
+std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<BufferizationToMemRefPass>();
+  return pass;
+}
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
new file mode 100644
index 0000000000000..9f1a12b149e6a
--- /dev/null
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_conversion_library(MLIRBufferizationToMemRef
+  BufferizationToMemRef.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/BufferizationToMemRef
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  # List every library whose symbols BufferizationToMemRef.cpp uses directly
+  # (dialects, pass infrastructure, dialect conversion) so that the target
+  # also links when it is built as a shared library.
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRBufferization
+  MLIRMemRef
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9f5244b762a55..602b9e72b330a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(ArithmeticToLLVM)
add_subdirectory(ArithmeticToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(AsyncToLLVM)
+add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToStandard)
add_subdirectory(GPUCommon)
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
new file mode 100644
index 0000000000000..4b1d17742e919
--- /dev/null
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt -verify-diagnostics -convert-bufferization-to-memref -split-input-file %s | FileCheck %s
+
+// Static shape: the clone is expected to be replaced by a memref.alloc
+// followed by a memref.copy from the function argument into the new buffer.
+// CHECK-LABEL: @conversion_static
+func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
+ %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
+ memref.dealloc %arg0 : memref<2xf32>
+ return %0 : memref<2xf32>
+}
+
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-NEXT: memref.copy %[[ARG:.*]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+// Dynamic shape: the size of the dynamic dimension is expected to be read
+// from the clone source with memref.dim and passed to memref.alloc.
+// CHECK-LABEL: @conversion_dynamic
+func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
+ %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
+ memref.dealloc %arg0 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
+// CHECK: %[[CONST:.*]] = arith.constant
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
+// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+// Unranked memrefs are not supported by the clone conversion, so the illegal
+// op remains and the pass is expected to report a legalization failure.
+func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
+// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+  %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %arg0 : memref<*xf32>
+  return %1 : memref<*xf32>
+}
More information about the Mlir-commits
mailing list