[Mlir-commits] [mlir] [MLIR] Enable pattern only for scf.forall op (PR #110230)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 27 02:46:32 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Prashant Kumar (pashu123)

<details>
<summary>Changes</summary>

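For general loop-like ops such as `scf.for`, the shape of an iter arg might change in the loop body and differ from the shape of its init arg, so the pattern does not hold true in general. This change restricts the pattern to `scf.forall`.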

---
Full diff: https://github.com/llvm/llvm-project/pull/110230.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+18-4) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..aea26602dfb7a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -131,11 +132,24 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
-    auto loopLikeOp =
-        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
-    if (!loopLikeOp)
+    // TODO: Enable this for LoopLikeOpInterface. Restricting to scf.forall
+    // because in scf.for the init args shape might change in the loop body.
+    // For e.g.:
+    // ```
+    //  %0 = tensor.empty(%c1) : tensor<?xf32>
+    //  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
+    //    %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+    //    %2 = arith.addi %c1, %1 : index
+    //    %3 = tensor.empty(%2) : tensor<?xf32>
+    //    scf.yield %3 : tensor<?xf32>
+    //  }
+    //
+    // ```
+    auto forAllOp =
+        dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forAllOp)
       return failure();
-    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
     rewriter.modifyOpInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
     return success();

``````````

</details>


https://github.com/llvm/llvm-project/pull/110230


More information about the Mlir-commits mailing list