[Mlir-commits] [mlir] [MLIR] Enable pattern only for scf.forall op (PR #110230)

Prashant Kumar llvmlistbot at llvm.org
Fri Sep 27 02:45:59 PDT 2024


https://github.com/pashu123 created https://github.com/llvm/llvm-project/pull/110230

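The shape of a loop's iter_arg can change from one iteration to the next (e.g. in scf.for), so replacing `tensor.dim` of an iter_arg with `tensor.dim` of its init arg is not valid in general. Restrict the pattern to scf.forall.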

>From 7931acb52c7230b5ea56195bde4f3ea3285a228e Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Fri, 27 Sep 2024 15:14:01 +0530
Subject: [PATCH] [MLIR] Enable pattern only for scf.forall op

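The shape of a loop's iter_arg can change from one iteration to the
next (e.g. in scf.for), so replacing `tensor.dim` of an iter_arg with
`tensor.dim` of its init arg is not valid in general. Restrict the
pattern to scf.forall.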
---
 .../ResolveShapedTypeResultDims.cpp           | 22 +++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..aea26602dfb7a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -131,11 +132,24 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
-    auto loopLikeOp =
-        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
-    if (!loopLikeOp)
+    // TODO: Enable this for LoopLikeOpInterface. Restricted to scf.forall
+    // because in other loops (e.g. scf.for) an iter_arg's shape can differ
+    // from its init arg's shape after the first iteration. For example:
+    // ```
+    //  %0 = tensor.empty(%c1) : tensor<?xf32>
+    //  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
+    //    %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+    //    %2 = arith.addi %c1, %1 : index
+    //    %3 = tensor.empty(%2) : tensor<?xf32>
+    //    scf.yield %3 : tensor<?xf32>
+    //  }
+    // ```
+    // Here %1 grows each iteration, so it must not be replaced by dim(%0).
+    auto forAllOp =
+        dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forAllOp)
       return failure();
-    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
     rewriter.modifyOpInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
     return success();



More information about the Mlir-commits mailing list