[Mlir-commits] [mlir] 8dca38d - [mlir][bufferize] Support layout maps in bufferization.clone lowering
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 16 06:34:10 PDT 2022
Author: Matthias Springer
Date: 2022-03-16T22:29:22+09:00
New Revision: 8dca38d53480a6f1382bf0035f16a432dc7cc999
URL: https://github.com/llvm/llvm-project/commit/8dca38d53480a6f1382bf0035f16a432dc7cc999
DIFF: https://github.com/llvm/llvm-project/commit/8dca38d53480a6f1382bf0035f16a432dc7cc999.diff
LOG: [mlir][bufferize] Support layout maps in bufferization.clone lowering
Differential Revision: https://reviews.llvm.org/D121278
Added:
Modified:
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 66ed3474ff359..b28ca796f0397 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -39,10 +39,18 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
return rewriter.notifyMatchFailure(
op, "UnrankedMemRefType is not supported.");
}
+ MemRefType memrefType = type.cast<MemRefType>();
+ MemRefLayoutAttrInterface layout;
+ auto allocType =
+ MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+ layout, memrefType.getMemorySpace());
+ // Since this implementation always allocates, certain result types of the
+ // clone op cannot be lowered.
+ if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
+ return failure();
// Transform a clone operation into alloc + copy operation and pay
// attention to the shape dimensions.
- MemRefType memrefType = type.cast<MemRefType>();
Location loc = op->getLoc();
SmallVector<Value, 4> dynamicOperands;
for (int i = 0; i < memrefType.getRank(); ++i) {
@@ -52,8 +60,14 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.input(), size);
dynamicOperands.push_back(dim);
}
- Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
- dynamicOperands);
+
+ // Allocate a memref with identity layout.
+ Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
+ dynamicOperands);
+ // Cast the allocation to the specified type if needed.
+ if (memrefType != allocType)
+ alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
+ rewriter.replaceOp(op, alloc);
rewriter.create<memref::CopyOp>(loc, op.input(), alloc);
return success();
}
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index 4b1d17742e919..8dbed159a2dab 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -2,9 +2,9 @@
// CHECK-LABEL: @conversion_static
func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
- %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
- memref.dealloc %arg0 : memref<2xf32>
- return %0 : memref<2xf32>
+ %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
+ memref.dealloc %arg0 : memref<2xf32>
+ return %0 : memref<2xf32>
}
// CHECK: %[[ALLOC:.*]] = memref.alloc
@@ -16,9 +16,9 @@ func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
// CHECK-LABEL: @conversion_dynamic
func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
- %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
- memref.dealloc %arg0 : memref<?xf32>
- return %1 : memref<?xf32>
+ %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
+ memref.dealloc %arg0 : memref<?xf32>
+ return %1 : memref<?xf32>
}
// CHECK: %[[CONST:.*]] = arith.constant
@@ -32,7 +32,40 @@ func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
- %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
- memref.dealloc %arg0 : memref<*xf32>
- return %1 : memref<*xf32>
+ %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
+ memref.dealloc %arg0 : memref<*xf32>
+ return %1 : memref<*xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-LABEL: func @conversion_with_layout_map(
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32, #[[$MAP]]>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32>
+// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC]] : memref<?xf32> to memref<?xf32, #[[$MAP]]>
+// CHECK: memref.copy
+// CHECK: memref.dealloc
+// CHECK: return %[[CASTED]]
+func @conversion_with_layout_map(%arg0 : memref<?xf32, #map>) -> memref<?xf32, #map> {
+ %1 = bufferization.clone %arg0 : memref<?xf32, #map> to memref<?xf32, #map>
+ memref.dealloc %arg0 : memref<?xf32, #map>
+ return %1 : memref<?xf32, #map>
+}
+
+// -----
+
+// This bufferization.clone cannot be lowered because a buffer with this layout
+// map cannot be allocated (or casted to).
+
+#map2 = affine_map<(d0)[s0] -> (d0 * 10 + s0)>
+func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, #map2>)
+ -> memref<?xf32, #map2> {
+// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+ %1 = bufferization.clone %arg0 : memref<?xf32, #map2> to memref<?xf32, #map2>
+ memref.dealloc %arg0 : memref<?xf32, #map2>
+ return %1 : memref<?xf32, #map2>
}
More information about the Mlir-commits
mailing list