[Mlir-commits] [mlir] [mlir] Add forall canonicalization to replace constant induction vars (PR #112764)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 18 06:39:30 PDT 2024
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/112764
>From a00ac08a65b80f3c094d55ccf9cb90dc8d28dcfa Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 16 Oct 2024 13:29:10 -0500
Subject: [PATCH 1/4] [mlir] Add forall canonicalization to replace constant
induction vars
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 29 ++++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 2 ++
2 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2582d4e0df1920..7789f21af00780 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1767,6 +1767,32 @@ struct ForallOpSingleOrZeroIterationDimsFolder
}
};
+struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForallOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace all induction vars with a single trip count with their lower
+ // bound.
+ Location loc = op.getLoc();
+ bool replacedIv = false;
+ for (auto [lb, ub, step, iv] :
+ llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+ op.getMixedStep(), op.getInductionVars())) {
+ if (iv.getUses().begin() == iv.getUses().end())
+ continue;
+ auto numIterations = constantTripCount(lb, ub, step);
+ if (!numIterations.has_value() || numIterations.value() != 1) {
+ continue;
+ }
+ rewriter.replaceAllUsesWith(
+ iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+ return success();
+ }
+ return failure();
+ }
+};
+
struct FoldTensorCastOfOutputIntoForallOp
: public OpRewritePattern<scf::ForallOp> {
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1877,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
- ForallOpSingleOrZeroIterationDimsFolder>(context);
+ ForallOpSingleOrZeroIterationDimsFolder,
+ ForallOpReplaceConstantInductionVar>(context);
}
/// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..6f4703c04dc768 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1632,6 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
}
// CHECK-LABEL: @do_not_inline_distributed_forall_loop
// CHECK: scf.forall
+// CHECK: tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK: tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
// -----
>From 43ece6b2deea0c1c94abf4ec6b57f6a816262409 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 15:32:56 -0400
Subject: [PATCH 2/4] Add test with a non-zero induction variable lower bound
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
mlir/test/Dialect/SCF/canonicalize.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 6f4703c04dc768..8c4e7a41ee6bc4 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1617,7 +1617,7 @@ func.func @do_not_inline_distributed_forall_loop(
%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<8x8xf32>
- %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+ %1 = scf.forall (%i, %j) = (0, 4) to (1, 5) step (8, 8)
shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
%slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
: tensor<8x8xf32> to tensor<2x3xf32>
@@ -1632,8 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
}
// CHECK-LABEL: @do_not_inline_distributed_forall_loop
// CHECK: scf.forall
-// CHECK: tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
-// CHECK: tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK: tensor.extract_slice %{{.*}}[0, 4] [2, 3] [1, 1]
+// CHECK: tensor.parallel_insert_slice %{{.*}}[0, 4] [2, 3] [1, 1]
// -----
>From 8350c931689bf2d281aec7fb38cc4692bbef02db Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 16:43:57 -0400
Subject: [PATCH 3/4] Address review comment: replace every single-trip induction var in one rewrite
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 7789f21af00780..84fc27e49d2af9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1775,7 +1775,7 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
// Replace all induction vars with a single trip count with their lower
// bound.
Location loc = op.getLoc();
- bool replacedIv = false;
+ bool changed = false;
for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) {
@@ -1787,9 +1787,9 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
}
rewriter.replaceAllUsesWith(
iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
- return success();
+ changed = true;
}
- return failure();
+ return success(changed);
}
};
>From 731a481e90fbb43429b8c79a64a54b043e7d3b54 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 18 Oct 2024 09:38:43 -0400
Subject: [PATCH 4/4] Move pattern description into a doc comment on the struct
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 84fc27e49d2af9..6678878215c11f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1767,13 +1767,12 @@ struct ForallOpSingleOrZeroIterationDimsFolder
}
};
+/// Replace all induction vars with a single trip count with their lower bound.
struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
using OpRewritePattern<ForallOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ForallOp op,
PatternRewriter &rewriter) const override {
- // Replace all induction vars with a single trip count with their lower
- // bound.
Location loc = op.getLoc();
bool changed = false;
for (auto [lb, ub, step, iv] :
More information about the Mlir-commits mailing list