[Mlir-commits] [mlir] [mlir] Add forall canonicalization to replace constant induction vars (PR #112764)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 17 13:44:09 PDT 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/112764

>From ab5c0725206d9bc44b753f91be69d15329735985 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 16 Oct 2024 13:29:10 -0500
Subject: [PATCH 1/3] [mlir] Add forall canonicalization to replace constant
 induction vars

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 29 ++++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir |  2 ++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2582d4e0df1920..7789f21af00780 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1767,6 +1767,32 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   }
 };
 
+struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replace all induction vars with a single trip count with their lower
+    // bound.
+    Location loc = op.getLoc();
+    bool replacedIv = false;
+    for (auto [lb, ub, step, iv] :
+         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+                   op.getMixedStep(), op.getInductionVars())) {
+      if (iv.use_empty())
+        continue;
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (!numIterations.has_value() || numIterations.value() != 1) {
+        continue;
+      }
+      rewriter.replaceAllUsesWith(
+          iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+      return success();
+    }
+    return failure();
+  }
+};
+
 struct FoldTensorCastOfOutputIntoForallOp
     : public OpRewritePattern<scf::ForallOp> {
   using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1877,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
               ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
-              ForallOpSingleOrZeroIterationDimsFolder>(context);
+              ForallOpSingleOrZeroIterationDimsFolder,
+              ForallOpReplaceConstantInductionVar>(context);
 }
 
 /// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..6f4703c04dc768 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1632,6 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
 }
 // CHECK-LABEL: @do_not_inline_distributed_forall_loop
 // CHECK: scf.forall
+// CHECK:   tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
 
 // -----
 

>From 75f86d5c4c9f3596d8c30c152a2eafb498c01f80 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 15:32:56 -0400
Subject: [PATCH 2/3] Add test with a nonzero lower bound

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 mlir/test/Dialect/SCF/canonicalize.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 6f4703c04dc768..8c4e7a41ee6bc4 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1617,7 +1617,7 @@ func.func @do_not_inline_distributed_forall_loop(
     %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<8x8xf32>
-  %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+  %1 = scf.forall (%i, %j) = (0, 4) to (1, 5) step (8, 8)
       shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
     %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
       : tensor<8x8xf32> to tensor<2x3xf32>
@@ -1632,8 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
 }
 // CHECK-LABEL: @do_not_inline_distributed_forall_loop
 // CHECK: scf.forall
-// CHECK:   tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
-// CHECK:   tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK:   tensor.extract_slice %{{.*}}[0, 4] [2, 3] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %{{.*}}[0, 4] [2, 3] [1, 1]
 
 // -----
 

>From ce389815649887a9807cd2632704c02de284d271 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 16:43:57 -0400
Subject: [PATCH 3/3] Address review comment: replace all constant induction vars in one rewrite

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 7789f21af00780..84fc27e49d2af9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1775,7 +1775,7 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
     // Replace all induction vars with a single trip count with their lower
     // bound.
     Location loc = op.getLoc();
-    bool replacedIv = false;
+    bool changed = false;
     for (auto [lb, ub, step, iv] :
          llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                    op.getMixedStep(), op.getInductionVars())) {
@@ -1787,9 +1787,9 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
       }
       rewriter.replaceAllUsesWith(
           iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
-      return success();
+      changed = true;
     }
-    return failure();
+    return success(changed);
   }
 };
 



More information about the Mlir-commits mailing list