[Mlir-commits] [mlir] [mlir][bufferization] Unranked memref support for clone (PR #94757)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 10 05:44:59 PDT 2024


https://github.com/ryankima updated https://github.com/llvm/llvm-project/pull/94757

>From e0b88941207408fcc544b5ab2b389828f3ddf135 Mon Sep 17 00:00:00 2001
From: ryankim <ryankim at mathworks.com>
Date: Fri, 7 Jun 2024 09:42:41 -0400
Subject: [PATCH 1/2] [mlir][bufferization] Unranked memref support for clone

---
 .../BufferizationToMemRef.cpp                 | 97 +++++++++++++------
 .../bufferization-to-memref.mlir              | 17 +++-
 2 files changed, 83 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 3069f6e073240..636a2b3d7a81f 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -42,39 +42,78 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
   LogicalResult
   matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Check for unranked memref types which are currently not supported.
+    Location loc = op->getLoc();
+
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
     Type type = op.getType();
+    Value alloc;
+
     if (isa<UnrankedMemRefType>(type)) {
-      return rewriter.notifyMatchFailure(
-          op, "UnrankedMemRefType is not supported.");
-    }
-    MemRefType memrefType = cast<MemRefType>(type);
-    MemRefLayoutAttrInterface layout;
-    auto allocType =
-        MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
-                        layout, memrefType.getMemorySpace());
-    // Since this implementation always allocates, certain result types of the
-    // clone op cannot be lowered.
-    if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
-      return failure();
-
-    // Transform a clone operation into alloc + copy operation and pay
-    // attention to the shape dimensions.
-    Location loc = op->getLoc();
-    SmallVector<Value, 4> dynamicOperands;
-    for (int i = 0; i < memrefType.getRank(); ++i) {
-      if (!memrefType.isDynamicDim(i))
-        continue;
-      Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
-      dynamicOperands.push_back(dim);
+      // Dynamically evaluate the size and shape of the unranked memref
+      Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
+      MemRefType allocType =
+          MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
+      Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
+
+      // Create a loop to query dimension sizes, store them as a shape, and
+      // compute the total size of the memref
+      auto size =
+          rewriter
+              .create<scf::ForOp>(
+                  loc, zero, rank, one, ValueRange(one),
+                  [&](OpBuilder &builder, Location loc, Value i,
+                      ValueRange args) {
+                    auto acc = args.front();
+
+                    auto dim =
+                        rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+
+                    rewriter.create<memref::StoreOp>(loc, dim, shape, i);
+                    acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+
+                    rewriter.create<scf::YieldOp>(loc, acc);
+                  })
+              .getResult(0);
+
+      UnrankedMemRefType unranked = cast<UnrankedMemRefType>(type);
+      MemRefType memrefType =
+          MemRefType::get({ShapedType::kDynamic}, unranked.getElementType());
+
+      // Allocate new memref with 1D dynamic shape, then reshape into the
+      // shape of the original unranked memref
+      alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
+      alloc = rewriter.create<memref::ReshapeOp>(loc, unranked, alloc, shape);
+    } else {
+      MemRefType memrefType = cast<MemRefType>(type);
+      MemRefLayoutAttrInterface layout;
+      auto allocType =
+          MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+                          layout, memrefType.getMemorySpace());
+      // Since this implementation always allocates, certain result types of
+      // the clone op cannot be lowered.
+      if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
+        return failure();
+
+      // Transform a clone operation into alloc + copy operation and pay
+      // attention to the shape dimensions.
+      SmallVector<Value, 4> dynamicOperands;
+      for (int i = 0; i < memrefType.getRank(); ++i) {
+        if (!memrefType.isDynamicDim(i))
+          continue;
+        Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
+        dynamicOperands.push_back(dim);
+      }
+
+      // Allocate a memref with identity layout.
+      alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
+      // Cast the allocation to the specified type if needed.
+      if (memrefType != allocType)
+        alloc =
+            rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
     }
 
-    // Allocate a memref with identity layout.
-    Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
-                                                   dynamicOperands);
-    // Cast the allocation to the specified type if needed.
-    if (memrefType != allocType)
-      alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
     rewriter.replaceOp(op, alloc);
     rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
     return success();
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index f58a2afa1a896..21d5f42158d09 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -22,7 +22,7 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
 }
 
 // CHECK:      %[[CONST:.*]] = arith.constant
-// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
+// CHECK:      %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
 // CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
 // CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
 // CHECK-NEXT: memref.dealloc %[[ARG]]
@@ -30,13 +30,26 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
 
 // -----
 
+// CHECK-LABEL: @conversion_unknown
 func.func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
-// expected-error at +1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
   %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
   memref.dealloc %arg0 : memref<*xf32>
   return %1 : memref<*xf32>
 }
 
+// CHECK:      %[[RANK:.*]] = memref.rank %[[ARG:.*]]
+// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
+// CHECK-NEXT: %[[FOR:.*]] = scf.for
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]] %[[ARG:.*]]
+// CHECK-NEXT: memref.store %[[DIM:.*]], %[[ALLOCA:.*]][%[[ARG:.*]]]
+// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG:.*]], %[[DIM:.*]]
+// CHECK-NEXT: scf.yield %[[MUL:.*]]
+// CHECK:      %[[ALLOC:.*]] = memref.alloc(%[[FOR:.*]])
+// CHECK-NEXT: %[[RESHAPE:.*]] = memref.reshape %[[ALLOC:.*]]
+// CHECK-NEXT: memref.copy
+// CHECK-NEXT: memref.dealloc
+// CHECK-NEXT: return %[[RESHAPE:.*]]
+
 // -----
 
 // CHECK-LABEL: func @conversion_with_layout_map(

>From 92c6dc409d57ac8d17e3ff909459784563d22982 Mon Sep 17 00:00:00 2001
From: ryankim <ryankim at mathworks.com>
Date: Mon, 10 Jun 2024 08:42:06 -0400
Subject: [PATCH 2/2] Move constant declarations, reformat loop body, and
 declare unrankedType in if statement

---
 .../BufferizationToMemRef.cpp                 | 44 +++++++++----------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 636a2b3d7a81f..810f82f6442ea 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -44,13 +44,14 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
 
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
     Type type = op.getType();
     Value alloc;
 
-    if (isa<UnrankedMemRefType>(type)) {
+    if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
+      // Constants
+      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
       // Dynamically evaluate the size and shape of the unranked memref
       Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
       MemRefType allocType =
@@ -59,32 +60,29 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
 
       // Create a loop to query dimension sizes, store them as a shape, and
       // compute the total size of the memref
-      auto size =
-          rewriter
-              .create<scf::ForOp>(
-                  loc, zero, rank, one, ValueRange(one),
-                  [&](OpBuilder &builder, Location loc, Value i,
-                      ValueRange args) {
-                    auto acc = args.front();
-
-                    auto dim =
-                        rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+      auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
+                          ValueRange args) {
+        auto acc = args.front();
+        auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
 
-                    rewriter.create<memref::StoreOp>(loc, dim, shape, i);
-                    acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+        rewriter.create<memref::StoreOp>(loc, dim, shape, i);
+        acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
 
-                    rewriter.create<scf::YieldOp>(loc, acc);
-                  })
-              .getResult(0);
+        rewriter.create<scf::YieldOp>(loc, acc);
+      };
+      auto size = rewriter
+                      .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
+                                          loopBody)
+                      .getResult(0);
 
-      UnrankedMemRefType unranked = cast<UnrankedMemRefType>(type);
-      MemRefType memrefType =
-          MemRefType::get({ShapedType::kDynamic}, unranked.getElementType());
+      MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
+                                              unrankedType.getElementType());
 
       // Allocate new memref with 1D dynamic shape, then reshape into the
       // shape of the original unranked memref
       alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
-      alloc = rewriter.create<memref::ReshapeOp>(loc, unranked, alloc, shape);
+      alloc =
+          rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
     } else {
       MemRefType memrefType = cast<MemRefType>(type);
       MemRefLayoutAttrInterface layout;



More information about the Mlir-commits mailing list