[Mlir-commits] [mlir] [MLIR][SCF] Fold dim ops of iter_args to respective init_args (PR #109973)

Prashant Kumar llvmlistbot at llvm.org
Thu Sep 26 02:54:40 PDT 2024


https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/109973

>From f73b26ad8687287722e687fd884acf912f3a3f16 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:05:31 +0530
Subject: [PATCH 1/3] [MLIR][SCF] Remove whitespaces.

---
 mlir/test/Dialect/SCF/canonicalize.mlir | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..2eab04363d38f3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1787,7 +1787,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1830,7 +1830,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1854,7 +1854,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   %switch_cst_2 = arith.constant 2: index
   %1 = scf.index_switch %switch_cst_2 -> f32
   case 0 {
@@ -1865,7 +1865,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   return %0, %1 : f32, f32
 }
 
@@ -1891,3 +1891,5 @@ func.func @index_switch_fold_no_res() {
 
 // CHECK-LABEL: func.func @index_switch_fold_no_res()
 //  CHECK-NEXT: "test.op"() : () -> ()
+
+// -----

>From 25480550ebad8b29e2f38637ce22a7e2c4ddb297 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:07:10 +0530
Subject: [PATCH 2/3] [MLIR][SCF] Fold dim ops of iter_args to respective
 init_args

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 46 +++++++++++++++++++++++--
 mlir/test/Dialect/SCF/canonicalize.mlir | 27 +++++++++++++++
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..9705a9cd1741b9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1478,6 +1478,45 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
   }
 };
 
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+struct DimOfForallIterArg : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    auto forallOp =
+        dyn_cast<ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forallOp)
+      return failure();
+    Value initArg = forallOp.getTiedLoopInit(blockArg)->get();
+    rewriter.modifyOpInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+
+    return success();
+  }
+};
+
 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
 public:
   using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1851,9 +1890,10 @@ struct FoldTensorCastOfOutputIntoForallOp
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
-              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
-              ForallOpSingleOrZeroIterationDimsFolder>(context);
+  results.add<DimOfForallOp, DimOfForallIterArg,
+              FoldTensorCastOfOutputIntoForallOp, ForallOpControlOperandsFolder,
+              ForallOpIterArgsFolder, ForallOpSingleOrZeroIterationDimsFolder>(
+      context);
 }
 
 /// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2eab04363d38f3..e203a517932237 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1893,3 +1893,30 @@ func.func @index_switch_fold_no_res() {
 //  CHECK-NEXT: "test.op"() : () -> ()
 
 // -----
+
+func.func @forall_iter_to_init_arg(
+  %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+
+  %result = scf.forall (%i) = (%c0) to (%dim0)
+      step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+
+    %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+      : tensor<?x?xf32> to tensor<1x?xf32>
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+        : tensor<1x?xf32> into tensor<?x?xf32>
+    }
+  }
+
+  return %result : tensor<?x?xf32>
+}
+// CHECK-LABEL: @forall_iter_to_init_arg
+//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+//       CHECK:    %[[RESULT:.*]] = scf.forall
+//  CHECK-SAME:                       shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+//  CHECK-NEXT:       %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>

>From a6ca06fc6aec6ea8df32e3b5d247cfb1740e1968 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Thu, 26 Sep 2024 15:24:16 +0530
Subject: [PATCH 3/3] Move pattern to resolveShapedTypeResultDims

---
 .../ResolveShapedTypeResultDims.cpp           | 44 ++++++++++++++++++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 39 ----------------
 mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 28 ++++++++++++
 mlir/test/Dialect/SCF/canonicalize.mlir       | 29 ------------
 4 files changed, 70 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 0cb5931ce6bf9b..79466c26b5cab6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -103,6 +103,45 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+template <typename OpTy>
+struct IterArgsToInitArgs : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    auto loopLikeOp =
+        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
+    if (!loopLikeOp)
+      return failure();
+    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    rewriter.modifyOpInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -127,8 +166,9 @@ struct ResolveShapedTypeResultDimsPass final
 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
     RewritePatternSet &patterns) {
   patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
-               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
-      patterns.getContext());
+               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
+               IterArgsToInitArgs<memref::DimOp>,
+               IterArgsToInitArgs<tensor::DimOp>>(patterns.getContext());
 }
 
 void memref::populateResolveShapedTypeResultDimsPatterns(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9705a9cd1741b9..2041ce88152f33 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1478,45 +1478,6 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
   }
 };
 
-/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
-///
-/// ```
-/// %0 = ... : tensor<?x?xf32>
-/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
-///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-///   ...
-/// }
-/// ```
-///
-/// is folded to:
-///
-/// ```
-/// %0 = ... : tensor<?x?xf32>
-/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
-///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
-///   ...
-/// }
-/// ```
-struct DimOfForallIterArg : public OpRewritePattern<tensor::DimOp> {
-  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
-                                PatternRewriter &rewriter) const final {
-    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
-    if (!blockArg)
-      return failure();
-    auto forallOp =
-        dyn_cast<ForallOp>(blockArg.getParentBlock()->getParentOp());
-    if (!forallOp)
-      return failure();
-    Value initArg = forallOp.getTiedLoopInit(blockArg)->get();
-    rewriter.modifyOpInPlace(
-        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
-
-    return success();
-  }
-};
-
 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
 public:
   using OpRewritePattern<ForallOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 85a4853972457c..ef8b80f6b5c22a 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -71,3 +71,31 @@ func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
   %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: @iter_to_init_arg_loop_like
+//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+//       CHECK:    %[[RESULT:.*]] = scf.forall
+//  CHECK-SAME:                       shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+//  CHECK-NEXT:       %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
+func.func @iter_to_init_arg_loop_like(
+  %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+
+  %result = scf.forall (%i) = (%c0) to (%dim0)
+      step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+
+    %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+      : tensor<?x?xf32> to tensor<1x?xf32>
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+        : tensor<1x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %result : tensor<?x?xf32>
+}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e203a517932237..284e141d87bcc3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1891,32 +1891,3 @@ func.func @index_switch_fold_no_res() {
 
 // CHECK-LABEL: func.func @index_switch_fold_no_res()
 //  CHECK-NEXT: "test.op"() : () -> ()
-
-// -----
-
-func.func @forall_iter_to_init_arg(
-  %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-
-  %result = scf.forall (%i) = (%c0) to (%dim0)
-      step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
-
-    %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
-    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
-      : tensor<?x?xf32> to tensor<1x?xf32>
-
-    scf.forall.in_parallel {
-      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
-        : tensor<1x?xf32> into tensor<?x?xf32>
-    }
-  }
-
-  return %result : tensor<?x?xf32>
-}
-// CHECK-LABEL: @forall_iter_to_init_arg
-//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall
-//  CHECK-SAME:                       shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
-//  CHECK-NEXT:       %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>



More information about the Mlir-commits mailing list