[Mlir-commits] [mlir] [mlir][bufferization] Unranked memref support for clone (PR #94757)

Matthias Springer llvmlistbot at llvm.org
Sun Jun 9 01:38:45 PDT 2024


================
@@ -42,39 +42,78 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
   LogicalResult
   matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Check for unranked memref types which are currently not supported.
+    Location loc = op->getLoc();
+
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
     Type type = op.getType();
+    Value alloc;
+
     if (isa<UnrankedMemRefType>(type)) {
-      return rewriter.notifyMatchFailure(
-          op, "UnrankedMemRefType is not supported.");
-    }
-    MemRefType memrefType = cast<MemRefType>(type);
-    MemRefLayoutAttrInterface layout;
-    auto allocType =
-        MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
-                        layout, memrefType.getMemorySpace());
-    // Since this implementation always allocates, certain result types of the
-    // clone op cannot be lowered.
-    if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
-      return failure();
-
-    // Transform a clone operation into alloc + copy operation and pay
-    // attention to the shape dimensions.
-    Location loc = op->getLoc();
-    SmallVector<Value, 4> dynamicOperands;
-    for (int i = 0; i < memrefType.getRank(); ++i) {
-      if (!memrefType.isDynamicDim(i))
-        continue;
-      Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
-      dynamicOperands.push_back(dim);
+      // Dynamically evaluate the size and shape of the unranked memref
+      Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
+      MemRefType allocType =
+          MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
+      Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
+
+      // Create a loop to query dimension sizes, store them as a shape, and
+      // compute the total size of the memref
+      auto size =
+          rewriter
+              .create<scf::ForOp>(
+                  loc, zero, rank, one, ValueRange(one),
+                  [&](OpBuilder &builder, Location loc, Value i,
----------------
matthias-springer wrote:

Nit: could you move this lambda into a local variable to avoid the excessive indentation?

https://github.com/llvm/llvm-project/pull/94757


More information about the Mlir-commits mailing list