[Mlir-commits] [mlir] [mlir][Bufferization] castOrReallocMemRefValue: Use BufferizationOptions (PR #89175)

Matthias Gehre llvmlistbot at llvm.org
Thu Apr 18 00:39:12 PDT 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/89175

>From 90a464a7bdae6969758fcb96b45857d645bfa9f5 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Sun, 7 Apr 2024 20:13:01 +0200
Subject: [PATCH] [mlir][Bufferization] castOrReallocMemRefValue: Use
 BufferizationOptions

---
 .../Dialect/Bufferization/IR/Bufferization.h  |  6 ++--
 .../Bufferization/IR/BufferizationOps.cpp     | 31 +++++++++++--------
 .../Bufferization/Transforms/Bufferize.cpp    |  6 ++--
 .../Transforms/finalizing-bufferize.mlir      |  6 ++--
 .../one-shot-module-bufferize-out-params.mlir |  2 +-
 .../Transforms/one-shot-module-bufferize.mlir |  2 +-
 .../Dialect/Bufferization/canonicalize.mlir   |  2 +-
 7 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index e98b5728b38ef8..6f19dca2e82224 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -53,12 +53,14 @@ void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue,
 /// This function returns `failure()` in case of unsupported casts. E.g., casts
 /// with differing element types or memory spaces.
 FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
-                                          MemRefType type);
+                                          MemRefType type,
+                                          const BufferizationOptions &options);
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref);
+                                       ToMemrefOp toMemref,
+                                       const BufferizationOptions &options);
 
 /// Add the canonicalization patterns for bufferization.dealloc to the given
 /// pattern set to make them available to other passes (such as
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index a656c812a59feb..5a8dd3bbca41ca 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,9 +23,9 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
-FailureOr<Value>
-mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
-                                              MemRefType destType) {
+FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
+    OpBuilder &b, Value value, MemRefType destType,
+    const BufferizationOptions &options) {
   auto srcType = llvm::cast<MemRefType>(value.getType());
 
   // Element type, rank and memory space must match.
@@ -73,18 +73,23 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
     Value size = b.create<memref::DimOp>(loc, value, i);
     dynamicOperands.push_back(size);
   }
-  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
-  // BufferizableOpInterface impl of ToMemrefOp.
-  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
-  b.create<memref::CopyOp>(loc, value, copy);
+
+  FailureOr<Value> copy =
+      options.createAlloc(b, loc, destType, dynamicOperands);
+  if (failed(copy)) {
+    return failure();
+  }
+  if (failed(options.createMemCpy(b, loc, value, *copy))) {
+    return failure();
+  }
   return copy;
 }
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult
-mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                              ToMemrefOp toMemref) {
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+    RewriterBase &rewriter, ToMemrefOp toMemref,
+    const BufferizationOptions &options) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -105,7 +110,7 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
   // Ranked memref -> Ranked memref cast.
   if (rankedSrcType && rankedDestType) {
     FailureOr<Value> replacement = castOrReallocMemRefValue(
-        rewriter, memrefToTensor.getMemref(), rankedDestType);
+        rewriter, memrefToTensor.getMemref(), rankedDestType, options);
     if (failed(replacement))
       return failure();
 
@@ -795,7 +800,7 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    return foldToMemrefToTensorPair(rewriter, toMemref);
+    return foldToMemrefToTensorPair(rewriter, toMemref, {});
   }
 };
 
@@ -843,7 +848,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                     const BufferizationOptions &options) {
   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
-  (void)foldToMemrefToTensorPair(rewriter, *this);
+  (void)foldToMemrefToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
   // or not. (And not whether the pattern matched or not.)
   return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 32f4e6a0fe8901..786a071dccfe67 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -75,7 +75,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
       if (!rankedDestType)
         return nullptr;
       FailureOr<Value> replacement =
-          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
+          castOrReallocMemRefValue(builder, inputs[0], rankedDestType, {});
       if (failed(replacement))
         return nullptr;
       return *replacement;
@@ -512,8 +512,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   // Fold all to_memref(to_tensor(x)) pairs.
   for (Operation *op : toMemrefOps) {
     rewriter.setInsertionPoint(op);
-    (void)bufferization::foldToMemrefToTensorPair(rewriter,
-                                                  cast<ToMemrefOp>(op));
+    (void)bufferization::foldToMemrefToTensorPair(
+        rewriter, cast<ToMemrefOp>(op), options);
   }
 
   // Remove all dead to_tensor ops.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d928..7f1e009c303a68 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -33,7 +33,7 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, strided<[1], offset: ?>>)
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
@@ -48,7 +48,7 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, strided<[100], offset: ?>>)
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
@@ -63,7 +63,7 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, strided<[1], offset: 25>>)
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
index de75b288855f94..9cf44c335d551e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -84,7 +84,7 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
 // Note: This alloc is not needed, but it is inserted before the returned buffer
 // is promoted to an out param to reconcile mismatching layout maps on return
 // value and function signature.
-//       CHECK-NO-LAYOUT:   %[[alloc2:.*]] = memref.alloc() : memref<2x5xf32>
+//       CHECK-NO-LAYOUT:   %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<2x5xf32>
 //       CHECK-NO-LAYOUT:   memref.copy %[[subview]], %[[alloc2]]
 //       CHECK-NO-LAYOUT:   memref.copy %[[alloc2]], %[[r]]
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 429c9e4dea9e93..0248afb11f1672 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -52,7 +52,7 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
 // CHECK-NO-LAYOUT-MAP-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32>
 //       CHECK-NO-LAYOUT-MAP:   %[[alloc:.*]] = memref.alloc() {{.*}} : memref<20x10xf32>
 //       CHECK-NO-LAYOUT-MAP:   %[[subview:.*]] = memref.subview {{.*}} : memref<20x10xf32> to memref<2x?xf32, strided<[10, 1], offset: ?>>
-//       CHECK-NO-LAYOUT-MAP:   %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) : memref<2x?xf32>
+//       CHECK-NO-LAYOUT-MAP:   %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<2x?xf32>
 //       CHECK-NO-LAYOUT-MAP:   memref.copy %[[subview]], %[[alloc_no_layout]]
 // TODO: %alloc should be deallocated here, but we currently do not dealloc
 // buffers that are inserted due to to_tensor/to_memref canonicalization (when
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0e..113aad67985d70 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -84,7 +84,7 @@ func.func @canonicalize_buffer_cast_of_tensor_load_to_copy(
 //  CHECK-NOT: bufferization.to_memref
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
 //      CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, strided<[1], offset: ?>>
-//      CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, strided<[1], offset: 3>>
+//      CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref<?xf32, strided<[1], offset: 3>>
 //      CHECK: memref.copy %[[M]], %[[ALLOC]]
 // CHECK-SAME:   memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: 3>>
 //      CHECK: return %[[ALLOC]]



More information about the Mlir-commits mailing list