[Mlir-commits] [mlir] [MLIR][SCF] Fold dim ops of iter_args to respective init_args (PR #109973)

Prashant Kumar llvmlistbot at llvm.org
Wed Sep 25 06:15:57 PDT 2024


https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/109973

>From f73b26ad8687287722e687fd884acf912f3a3f16 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:05:31 +0530
Subject: [PATCH 1/2] [MLIR][SCF] Remove trailing whitespace.

---
 mlir/test/Dialect/SCF/canonicalize.mlir | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..2eab04363d38f3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1787,7 +1787,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1830,7 +1830,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1854,7 +1854,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   %switch_cst_2 = arith.constant 2: index
   %1 = scf.index_switch %switch_cst_2 -> f32
   case 0 {
@@ -1865,7 +1865,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   return %0, %1 : f32, f32
 }
 
@@ -1891,3 +1891,5 @@ func.func @index_switch_fold_no_res() {
 
 // CHECK-LABEL: func.func @index_switch_fold_no_res()
 //  CHECK-NEXT: "test.op"() : () -> ()
+
+// -----

>From 25480550ebad8b29e2f38637ce22a7e2c4ddb297 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:07:10 +0530
Subject: [PATCH 2/2] [MLIR][SCF] Fold dim ops of iter_args to respective
 init_args

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 46 +++++++++++++++++++++++--
 mlir/test/Dialect/SCF/canonicalize.mlir | 27 +++++++++++++++
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..9705a9cd1741b9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1478,6 +1478,45 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
   }
 };
 
+/// Fold `tensor.dim` of an `scf.forall` shared_outs block argument to a
+/// `tensor.dim` of its tied init operand. This relies on a shared_outs block
+/// argument always having the same shape as its init operand. E.g.:
+///
+/// ```
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+/// }
+/// ```
+struct DimOfForallIterArg : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    // The owning block may be detached, in which case it has no parent op.
+    auto forallOp =
+        dyn_cast_if_present<ForallOp>(blockArg.getOwner()->getParentOp());
+    // Only shared_outs arguments have a tied init; induction variables don't.
+    if (!forallOp ||
+        !llvm::is_contained(forallOp.getRegionIterArgs(), blockArg))
+      return failure();
+    Value initArg = forallOp.getTiedLoopInit(blockArg)->get();
+    rewriter.modifyOpInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+    return success();
+  }
+};
+
 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
 public:
   using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1851,9 +1890,10 @@ struct FoldTensorCastOfOutputIntoForallOp
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
               ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
               ForallOpSingleOrZeroIterationDimsFolder>(context);
+  results.add<DimOfForallIterArg>(context);
 }
 
 /// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2eab04363d38f3..e203a517932237 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1893,3 +1893,30 @@ func.func @index_switch_fold_no_res() {
 //  CHECK-NEXT: "test.op"() : () -> ()
 
 // -----
+
+// tensor.dim of a shared_outs block argument folds to tensor.dim of its init.
+func.func @forall_iter_to_init_arg(
+  %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %result = scf.forall (%i) = (%c0) to (%dim0)
+      step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+    %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+      : tensor<?x?xf32> to tensor<1x?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+        : tensor<1x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %result : tensor<?x?xf32>
+}
+// CHECK-LABEL: @forall_iter_to_init_arg
+//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+//       CHECK:    %[[RESULT:.*]] = scf.forall
+//  CHECK-SAME:                       shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+//  CHECK-NEXT:       %[[DIM:.*]] = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
+//   CHECK-NOT:       tensor.dim %[[OUTS]]
+//       CHECK:       tensor.parallel_insert_slice %{{.*}} into %[[OUTS]][%{{.*}}, 0] [1, %[[DIM]]] [1, 1]
+//       CHECK:    return %[[RESULT]]



More information about the Mlir-commits mailing list