[Mlir-commits] [mlir] [MLIR] Enable pattern only for scf.forall op (PR #110230)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 27 02:46:32 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Prashant Kumar (pashu123)
<details>
<summary>Changes</summary>
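Inside a loop body, the shape of an iter_arg may differ from the shape of its init arg (for example, an scf.for body can yield a tensor of a different dynamic size on each iteration), so rewriting a tensor.dim of the iter_arg into a tensor.dim of the init arg is not valid for every loop-like op. This change restricts the pattern to scf.forall.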
---
Full diff: https://github.com/llvm/llvm-project/pull/110230.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+18-4)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..aea26602dfb7a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -131,11 +132,24 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
- auto loopLikeOp =
- dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
- if (!loopLikeOp)
+ // TODO: Enable this for LoopLikeOpInterface. Restricted to scf.forall for
+ // now because in other loops (e.g. scf.for) the shape of an iter_arg may
+ // change in the loop body, so it need not match the init arg. For example:
+ // ```
+ // %0 = tensor.empty(%c1) : tensor<?xf32>
+ // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
+ // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ // %2 = arith.addi %c1, %1 : index
+ // %3 = tensor.empty(%2) : tensor<?xf32>
+ // scf.yield %3 : tensor<?xf32>
+ // }
+ //
+ // ```
+ auto forAllOp =
+ dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+ if (!forAllOp)
return failure();
- Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+ Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();
``````````
</details>
https://github.com/llvm/llvm-project/pull/110230
More information about the Mlir-commits
mailing list