[Mlir-commits] [mlir] 8dca38d - [mlir][bufferize] Support layout maps in bufferization.clone lowering

Matthias Springer llvmlistbot at llvm.org
Wed Mar 16 06:34:10 PDT 2022


Author: Matthias Springer
Date: 2022-03-16T22:29:22+09:00
New Revision: 8dca38d53480a6f1382bf0035f16a432dc7cc999

URL: https://github.com/llvm/llvm-project/commit/8dca38d53480a6f1382bf0035f16a432dc7cc999
DIFF: https://github.com/llvm/llvm-project/commit/8dca38d53480a6f1382bf0035f16a432dc7cc999.diff

LOG: [mlir][bufferize] Support layout maps in bufferization.clone lowering

Differential Revision: https://reviews.llvm.org/D121278

Added: 
    

Modified: 
    mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
    mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 66ed3474ff359..b28ca796f0397 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -39,10 +39,18 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
       return rewriter.notifyMatchFailure(
           op, "UnrankedMemRefType is not supported.");
     }
+    MemRefType memrefType = type.cast<MemRefType>();
+    MemRefLayoutAttrInterface layout;
+    auto allocType =
+        MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+                        layout, memrefType.getMemorySpace());
+    // Since this implementation always allocates, certain result types of the
+    // clone op cannot be lowered.
+    if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
+      return failure();
 
     // Transform a clone operation into alloc + copy operation and pay
     // attention to the shape dimensions.
-    MemRefType memrefType = type.cast<MemRefType>();
     Location loc = op->getLoc();
     SmallVector<Value, 4> dynamicOperands;
     for (int i = 0; i < memrefType.getRank(); ++i) {
@@ -52,8 +60,14 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
       Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.input(), size);
       dynamicOperands.push_back(dim);
     }
-    Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
-                                                               dynamicOperands);
+
+    // Allocate a memref with identity layout.
+    Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
+                                                   dynamicOperands);
+    // Cast the allocation to the specified type if needed.
+    if (memrefType != allocType)
+      alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
+    rewriter.replaceOp(op, alloc);
     rewriter.create<memref::CopyOp>(loc, op.input(), alloc);
     return success();
   }

diff  --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index 4b1d17742e919..8dbed159a2dab 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -2,9 +2,9 @@
 
 // CHECK-LABEL: @conversion_static
 func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
-    %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
-    memref.dealloc %arg0 : memref<2xf32>
-    return %0 : memref<2xf32>
+  %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
+  memref.dealloc %arg0 : memref<2xf32>
+  return %0 : memref<2xf32>
 }
 
 // CHECK:      %[[ALLOC:.*]] = memref.alloc
@@ -16,9 +16,9 @@ func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
 
 // CHECK-LABEL: @conversion_dynamic
 func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
-    %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
-    memref.dealloc %arg0 : memref<?xf32>
-    return %1 : memref<?xf32>
+  %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
+  memref.dealloc %arg0 : memref<?xf32>
+  return %1 : memref<?xf32>
 }
 
 // CHECK:      %[[CONST:.*]] = arith.constant
@@ -32,7 +32,40 @@ func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
 
 func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
// expected-error @+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
-    %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
-    memref.dealloc %arg0 : memref<*xf32>
-    return %1 : memref<*xf32>
+  %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %arg0 : memref<*xf32>
+  return %1 : memref<*xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-LABEL: func @conversion_with_layout_map(
+//  CHECK-SAME:     %[[ARG:.*]]: memref<?xf32, #[[$MAP]]>
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]]
+//       CHECK:   %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32>
+//       CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC]] : memref<?xf32> to memref<?xf32, #[[$MAP]]>
+//       CHECK:   memref.copy
+//       CHECK:   memref.dealloc
+//       CHECK:   return %[[CASTED]]
+func @conversion_with_layout_map(%arg0 : memref<?xf32, #map>) -> memref<?xf32, #map> {
+  %1 = bufferization.clone %arg0 : memref<?xf32, #map> to memref<?xf32, #map>
+  memref.dealloc %arg0 : memref<?xf32, #map>
+  return %1 : memref<?xf32, #map>
+}
+
+// -----
+
+// This bufferization.clone cannot be lowered because a buffer with this layout
+// map can neither be allocated nor cast to.
+
+#map2 = affine_map<(d0)[s0] -> (d0 * 10 + s0)>
+func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, #map2>)
+    -> memref<?xf32, #map2> {
+// expected-error @+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+  %1 = bufferization.clone %arg0 : memref<?xf32, #map2> to memref<?xf32, #map2>
+  memref.dealloc %arg0 : memref<?xf32, #map2>
+  return %1 : memref<?xf32, #map2>
 }


        


More information about the Mlir-commits mailing list