[Mlir-commits] [mlir] [mlir][bufferization] Unranked memref support for clone (PR #94757)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 10 05:44:59 PDT 2024
https://github.com/ryankima updated https://github.com/llvm/llvm-project/pull/94757
>From e0b88941207408fcc544b5ab2b389828f3ddf135 Mon Sep 17 00:00:00 2001
From: ryankim <ryankim at mathworks.com>
Date: Fri, 7 Jun 2024 09:42:41 -0400
Subject: [PATCH 1/2] [mlir][bufferization] Unranked memref support for clone
---
.../BufferizationToMemRef.cpp | 97 +++++++++++++------
.../bufferization-to-memref.mlir | 17 +++-
2 files changed, 83 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 3069f6e073240..636a2b3d7a81f 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -42,39 +42,78 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
LogicalResult
matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Check for unranked memref types which are currently not supported.
+ Location loc = op->getLoc();
+
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
Type type = op.getType();
+ Value alloc;
+
if (isa<UnrankedMemRefType>(type)) {
- return rewriter.notifyMatchFailure(
- op, "UnrankedMemRefType is not supported.");
- }
- MemRefType memrefType = cast<MemRefType>(type);
- MemRefLayoutAttrInterface layout;
- auto allocType =
- MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
- layout, memrefType.getMemorySpace());
- // Since this implementation always allocates, certain result types of the
- // clone op cannot be lowered.
- if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
- return failure();
-
- // Transform a clone operation into alloc + copy operation and pay
- // attention to the shape dimensions.
- Location loc = op->getLoc();
- SmallVector<Value, 4> dynamicOperands;
- for (int i = 0; i < memrefType.getRank(); ++i) {
- if (!memrefType.isDynamicDim(i))
- continue;
- Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
- dynamicOperands.push_back(dim);
+ // Dynamically evaluate the size and shape of the unranked memref
+ Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
+ MemRefType allocType =
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
+ Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
+
+ // Create a loop to query dimension sizes, store them as a shape, and
+ // compute the total size of the memref
+ auto size =
+ rewriter
+ .create<scf::ForOp>(
+ loc, zero, rank, one, ValueRange(one),
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange args) {
+ auto acc = args.front();
+
+ auto dim =
+ rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+
+ rewriter.create<memref::StoreOp>(loc, dim, shape, i);
+ acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+
+ rewriter.create<scf::YieldOp>(loc, acc);
+ })
+ .getResult(0);
+
+ UnrankedMemRefType unranked = cast<UnrankedMemRefType>(type);
+ MemRefType memrefType =
+ MemRefType::get({ShapedType::kDynamic}, unranked.getElementType());
+
+ // Allocate new memref with 1D dynamic shape, then reshape into the
+ // shape of the original unranked memref
+ alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
+ alloc = rewriter.create<memref::ReshapeOp>(loc, unranked, alloc, shape);
+ } else {
+ MemRefType memrefType = cast<MemRefType>(type);
+ MemRefLayoutAttrInterface layout;
+ auto allocType =
+ MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+ layout, memrefType.getMemorySpace());
+ // Since this implementation always allocates, certain result types of
+ // the clone op cannot be lowered.
+ if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
+ return failure();
+
+ // Transform a clone operation into alloc + copy operation and pay
+ // attention to the shape dimensions.
+ SmallVector<Value, 4> dynamicOperands;
+ for (int i = 0; i < memrefType.getRank(); ++i) {
+ if (!memrefType.isDynamicDim(i))
+ continue;
+ Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
+ dynamicOperands.push_back(dim);
+ }
+
+ // Allocate a memref with identity layout.
+ alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
+ // Cast the allocation to the specified type if needed.
+ if (memrefType != allocType)
+ alloc =
+ rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
}
- // Allocate a memref with identity layout.
- Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
- dynamicOperands);
- // Cast the allocation to the specified type if needed.
- if (memrefType != allocType)
- alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
rewriter.replaceOp(op, alloc);
rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
return success();
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index f58a2afa1a896..21d5f42158d09 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -22,7 +22,7 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
}
// CHECK: %[[CONST:.*]] = arith.constant
-// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
+// CHECK: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
// CHECK-NEXT: memref.dealloc %[[ARG]]
@@ -30,13 +30,26 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
// -----
+// CHECK-LABEL: @conversion_unknown
func.func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
-// expected-error at +1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
memref.dealloc %arg0 : memref<*xf32>
return %1 : memref<*xf32>
}
+// CHECK: %[[RANK:.*]] = memref.rank %[[ARG:.*]]
+// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
+// CHECK-NEXT: %[[FOR:.*]] = scf.for
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]] %[[ARG:.*]]
+// CHECK-NEXT: memref.store %[[DIM:.*]], %[[ALLOCA:.*]][%[[ARG:.*]]]
+// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG:.*]], %[[DIM:.*]]
+// CHECK-NEXT: scf.yield %[[MUL:.*]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[FOR:.*]])
+// CHECK-NEXT: %[[RESHAPE:.*]] = memref.reshape %[[ALLOC:.*]]
+// CHECK-NEXT: memref.copy
+// CHECK-NEXT: memref.dealloc
+// CHECK-NEXT: return %[[RESHAPE:.*]]
+
// -----
// CHECK-LABEL: func @conversion_with_layout_map(
>From 92c6dc409d57ac8d17e3ff909459784563d22982 Mon Sep 17 00:00:00 2001
From: ryankim <ryankim at mathworks.com>
Date: Mon, 10 Jun 2024 08:42:06 -0400
Subject: [PATCH 2/2] Move constant declarations, reformat loop body, and
declare unrankedType in if statement
---
.../BufferizationToMemRef.cpp | 44 +++++++++----------
1 file changed, 21 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 636a2b3d7a81f..810f82f6442ea 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -44,13 +44,14 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
Type type = op.getType();
Value alloc;
- if (isa<UnrankedMemRefType>(type)) {
+ if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
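+ // Index constants: 0 is the loop lower bound; 1 is the loop step and the initial value of the accumulated size.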
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
// Dynamically evaluate the size and shape of the unranked memref
Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
MemRefType allocType =
@@ -59,32 +60,29 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
// Create a loop to query dimension sizes, store them as a shape, and
// compute the total size of the memref
- auto size =
- rewriter
- .create<scf::ForOp>(
- loc, zero, rank, one, ValueRange(one),
- [&](OpBuilder &builder, Location loc, Value i,
- ValueRange args) {
- auto acc = args.front();
-
- auto dim =
- rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+ auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange args) {
+ auto acc = args.front();
+ auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
- rewriter.create<memref::StoreOp>(loc, dim, shape, i);
- acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+ rewriter.create<memref::StoreOp>(loc, dim, shape, i);
+ acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
- rewriter.create<scf::YieldOp>(loc, acc);
- })
- .getResult(0);
+ rewriter.create<scf::YieldOp>(loc, acc);
+ };
+ auto size = rewriter
+ .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
+ loopBody)
+ .getResult(0);
- UnrankedMemRefType unranked = cast<UnrankedMemRefType>(type);
- MemRefType memrefType =
- MemRefType::get({ShapedType::kDynamic}, unranked.getElementType());
+ MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
+ unrankedType.getElementType());
// Allocate new memref with 1D dynamic shape, then reshape into the
// shape of the original unranked memref
alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
- alloc = rewriter.create<memref::ReshapeOp>(loc, unranked, alloc, shape);
+ alloc =
+ rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
} else {
MemRefType memrefType = cast<MemRefType>(type);
MemRefLayoutAttrInterface layout;
More information about the Mlir-commits
mailing list