[Mlir-commits] [mlir] c1047ba - [MLIR] Enable pattern only for scf.forall op (#110230)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 17 06:02:08 PDT 2024
Author: Prashant Kumar
Date: 2024-10-17T18:32:03+05:30
New Revision: c1047ba8366a447b61f845048a5f287dae24d9d0
URL: https://github.com/llvm/llvm-project/commit/c1047ba8366a447b61f845048a5f287dae24d9d0
DIFF: https://github.com/llvm/llvm-project/commit/c1047ba8366a447b61f845048a5f287dae24d9d0.diff
LOG: [MLIR] Enable pattern only for scf.forall op (#110230)
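The shape of a loop's iter arg may change in the loop body (e.g. in scf.for),
so rewriting a tensor.dim of the iter arg to use the init arg is not valid for
every loop-like op. Restrict the pattern to scf.forall.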
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..792e7229183064 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
- auto loopLikeOp =
- dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
- if (!loopLikeOp)
+ // TODO: Enable this for LoopLikeOpInterface. Restricted to scf.forall since
+ // in other loops (e.g. scf.for) the iter_arg shape may change in the body.
+ // For e.g.:
+ // ```
+ // %0 = tensor.empty(%c1) : tensor<?xf32>
+ // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
+ // tensor<?xf32> {
+ // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ // %2 = arith.addi %c1, %1 : index
+ // %3 = tensor.empty(%2) : tensor<?xf32>
+ // scf.yield %3 : tensor<?xf32>
+ // }
+ //
+ // ```
+ auto forAllOp =
+ dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+ if (!forAllOp)
return failure();
- Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+ Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();
More information about the Mlir-commits
mailing list