[Mlir-commits] [mlir] [MLIR][SCF] Fold dim ops of iter_args to respective init_args (PR #109973)

Matthias Springer llvmlistbot at llvm.org
Thu Sep 26 12:15:39 PDT 2024


================
@@ -103,6 +103,44 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    auto loopLikeOp =
+        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
----------------
matthias-springer wrote:

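I think this canonicalization is incorrect in the case of `scf.for`. An iter_arg is not always equal to its init arg: from the second iteration on, it holds the value yielded by the previous iteration, and that value's dynamic dimensions may differ. Example: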
```
%0 = tensor.empty(%c1) : tensor<?xf32>
%r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> tensor<?xf32> {
  %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
  %2 = arith.addi %c1, %1 : index
  %3 = tensor.empty(%2) : tensor<?xf32>
  scf.yield %3 : tensor<?xf32>
}
```

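With this new folding, `%1` always evaluates to 1 (the size of the init arg `%0`). Every iteration therefore yields a tensor of size 2, and `%r` always has dynamic size 2. Without the folding, the tensor grows by one element per iteration, so after the 10 iterations `%r` has size 11.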

https://github.com/llvm/llvm-project/pull/109973


More information about the Mlir-commits mailing list