[Mlir-commits] [mlir] c1047ba - [MLIR] Enable pattern only for scf.forall op (#110230)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 17 06:02:08 PDT 2024


Author: Prashant Kumar
Date: 2024-10-17T18:32:03+05:30
New Revision: c1047ba8366a447b61f845048a5f287dae24d9d0

URL: https://github.com/llvm/llvm-project/commit/c1047ba8366a447b61f845048a5f287dae24d9d0
DIFF: https://github.com/llvm/llvm-project/commit/c1047ba8366a447b61f845048a5f287dae24d9d0.diff

LOG: [MLIR] Enable pattern only for scf.forall op (#110230)

For loops such as scf.for, the loop body may yield a value whose shape differs
from that of the init arg, so the pattern does not hold for every loop-like op.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79d..792e7229183064 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
-    auto loopLikeOp =
-        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
-    if (!loopLikeOp)
+    // TODO: Enable this for LoopLikeOpInterface. Restricted to scf.forall for
+    // now because in e.g. scf.for the loop body may change the iter arg shape.
+    // E.g.:
+    // ```
+    //  %0 = tensor.empty(%c1) : tensor<?xf32>
+    //  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
+    //  tensor<?xf32> {
+    //    %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+    //    %2 = arith.addi %c1, %1 : index
+    //    %3 = tensor.empty(%2) : tensor<?xf32>
+    //    scf.yield %3 : tensor<?xf32>
+    //  }
+    //
+    // ```
+    auto forAllOp =
+        dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forAllOp)
       return failure();
-    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
     rewriter.modifyOpInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
     return success();


        


More information about the Mlir-commits mailing list