[Mlir-commits] [mlir] ae1ea0b - [mlir] Decompose Bufferization Clone operation into Memref Alloc and Copy.

Julian Gross llvmlistbot at llvm.org
Tue Nov 30 01:17:35 PST 2021


Author: Julian Gross
Date: 2021-11-30T10:15:56+01:00
New Revision: ae1ea0bead75f4c7a4c965dfa40b5f3b78b60364

URL: https://github.com/llvm/llvm-project/commit/ae1ea0bead75f4c7a4c965dfa40b5f3b78b60364
DIFF: https://github.com/llvm/llvm-project/commit/ae1ea0bead75f4c7a4c965dfa40b5f3b78b60364.diff

LOG: [mlir] Decompose Bufferization Clone operation into Memref Alloc and Copy.

This patch introduces a new conversion that rewrites bufferization.clone operations
into a memref.alloc followed by a memref.copy operation. This transformation is needed
to eliminate all remaining clones that "survive" all previous transformations before
a given program is lowered further (e.g., to LLVM). Otherwise, these operations
can no longer be handled and lead to compile errors.
See: https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665

Differential Revision: https://reviews.llvm.org/D114233

Added: 
    mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
    mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
    mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
    mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
new file mode 100644
index 0000000000000..d8cb1524cf222
--- /dev/null
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -0,0 +1,28 @@
+//===- BufferizationToMemRef.h - Bufferization to MemRef conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+#define MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+/// Collect a set of patterns to convert memory-related operations from the
+/// Bufferization dialect to the MemRef dialect.
+void populateBufferizationToMemRefConversionPatterns(
+    RewritePatternSet &patterns);
+
+/// Create a pass that applies the patterns above as a partial conversion and
+/// fails if any Bufferization dialect operation cannot be converted.
+std::unique_ptr<Pass> createBufferizationToMemRefPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 015023e083ec9..b39105aad3892 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8d806649d901a..31ef83579b2f9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -126,6 +126,31 @@ def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// BufferizationToMemRef
+//===----------------------------------------------------------------------===//
+
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
+  let summary = "Convert operations from the Bufferization dialect to the "
+                "MemRef dialect";
+  let description = [{
+    This pass converts Bufferization dialect operations into MemRef dialect
+    operations.
+
+    Currently, it rewrites each `bufferization.clone` into a `memref.alloc`
+    followed by a `memref.copy`; the sizes of dynamic dimensions are obtained
+    with `memref.dim`. Clones of unranked memrefs are not supported and make
+    the pass fail.
+
+    It is meant as a final clean-up step for clones that survive earlier
+    transformations, before lowering further (e.g., to LLVM). The pass performs
+    no memory analysis: it does not insert deallocations and therefore does not
+    resolve memory leaks.
+  }];
+  let constructor = "mlir::createBufferizationToMemRefPass()";
+  let dependentDialects = ["arith::ArithmeticDialect", "memref::MemRefDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
new file mode 100644
index 0000000000000..66ed3474ff359
--- /dev/null
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -0,0 +1,114 @@
+//===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert Bufferization dialect to MemRef
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// The CloneOpConversion transforms all bufferization clone operations into
+/// memref alloc and memref copy operations. In the dynamic-shape case, it also
+/// emits additional dim and constant operations to determine the shape. When
+/// possible, the buffer is allocated with the identity layout and then cast to
+/// the result type. This conversion does not resolve memory leaks if it is used
+/// alone.
+struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
+  using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check for unranked memref types which are currently not supported.
+    Type type = op.getType();
+    if (type.isa<UnrankedMemRefType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "UnrankedMemRefType is not supported.");
+    }
+
+    // memref.alloc is created without symbol operands below, so it cannot
+    // produce a layout with symbolic offsets or strides. Prefer allocating
+    // with the identity layout and casting to the requested type; fall back
+    // to allocating the requested type directly when no such cast exists.
+    MemRefType memrefType = type.cast<MemRefType>();
+    MemRefLayoutAttrInterface identityLayout;
+    MemRefType allocType =
+        MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+                        identityLayout, memrefType.getMemorySpace());
+    if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
+      allocType = memrefType;
+
+    // Transform a clone operation into alloc + copy operation and pay
+    // attention to the shape dimensions. The source is read through the
+    // adaptor so that operands already remapped by the conversion are used.
+    Value source = adaptor.input();
+    Location loc = op->getLoc();
+    SmallVector<Value, 4> dynamicOperands;
+    for (int i = 0; i < memrefType.getRank(); ++i) {
+      if (!memrefType.isDynamicDim(i))
+        continue;
+      Value size = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+      Value dim = rewriter.createOrFold<memref::DimOp>(loc, source, size);
+      dynamicOperands.push_back(dim);
+    }
+    Value alloc =
+        rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
+    // Cast the identity-layout buffer to the requested type if needed.
+    Value result = alloc;
+    if (allocType != memrefType)
+      result = rewriter.create<memref::CastOp>(loc, memrefType, alloc);
+    rewriter.replaceOp(op, result);
+    rewriter.create<memref::CopyOp>(loc, source, alloc);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateBufferizationToMemRefConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CloneOpConversion>(patterns.getContext());
+}
+
+namespace {
+/// Pass that applies the Bufferization-to-MemRef patterns as a partial
+/// conversion. The whole Bufferization dialect is marked illegal, so a clone
+/// that cannot be rewritten (e.g. of an unranked memref) makes the pass fail.
+struct BufferizationToMemRefPass
+    : public ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
+  BufferizationToMemRefPass() = default;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBufferizationToMemRefConversionPatterns(patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<memref::MemRefDialect>();
+    target.addLegalOp<arith::ConstantOp>();
+    target.addIllegalDialect<bufferization::BufferizationDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
+  return std::make_unique<BufferizationToMemRefPass>();
+}

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
new file mode 100644
index 0000000000000..9f1a12b149e6a
--- /dev/null
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRBufferizationToMemRef
+  BufferizationToMemRef.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/BufferizationToMemRef
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRBufferization
+  MLIRIR
+  MLIRMemRef
+  MLIRPass
+  MLIRSupport
+  MLIRTransformUtils
+  )

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9f5244b762a55..602b9e72b330a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(ArithmeticToLLVM)
 add_subdirectory(ArithmeticToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(AsyncToLLVM)
+add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToStandard)
 add_subdirectory(GPUCommon)

diff  --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
new file mode 100644
index 0000000000000..4b1d17742e919
--- /dev/null
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -verify-diagnostics -convert-bufferization-to-memref -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @conversion_static
+// CHECK-SAME:    %[[ARG:.*]]: memref<2xf32>
+func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
+    %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
+    memref.dealloc %arg0 : memref<2xf32>
+    return %0 : memref<2xf32>
+}
+
+// CHECK:      %[[ALLOC:.*]] = memref.alloc() : memref<2xf32>
+// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+// CHECK-LABEL: @conversion_dynamic
+// CHECK-SAME:    %[[ARG:.*]]: memref<?xf32>
+func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
+    %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
+    memref.dealloc %arg0 : memref<?xf32>
+    return %1 : memref<?xf32>
+}
+
+// CHECK:      %[[CONST:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG]], %[[CONST]]
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32>
+// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ARG]]
+// CHECK-NEXT: return %[[ALLOC]]
+
+// -----
+
+func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
+// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+    %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
+    memref.dealloc %arg0 : memref<*xf32>
+    return %1 : memref<*xf32>
+}


        


More information about the Mlir-commits mailing list