[Mlir-commits] [mlir] [mlir] Add forall canonicalization to replace constant induction vars (PR #112764)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 17 12:06:44 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (Max191)
<details>
<summary>Changes</summary>
Adds a canonicalization pattern for scf.forall that replaces uses of induction variables that can only take a single constant value (i.e., dimensions with a constant trip count of 1) with that constant lower bound. A similar existing canonicalization removes such induction variables from the loop entirely, but it does not apply to foralls with mappings, so this pattern is needed to cover those cases.
---
Full diff: https://github.com/llvm/llvm-project/pull/112764.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+28-1)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+2)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2582d4e0df1920..7789f21af00780 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1767,6 +1767,32 @@ struct ForallOpSingleOrZeroIterationDimsFolder
}
};
+/// Replaces every used induction variable of an `scf.forall` whose dimension
+/// has a constant trip count of exactly one with that dimension's lower bound.
+///
+/// Unlike folders that drop such dimensions from the loop, this pattern only
+/// rewrites uses of the induction variable and leaves the loop bounds (and any
+/// mapping) untouched, so it also applies to foralls that carry a mapping.
+struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    bool replacedIv = false;
+    for (auto [lb, ub, step, iv] :
+         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+                   op.getMixedStep(), op.getInductionVars())) {
+      // Nothing to rewrite; skipping also avoids materializing a dead
+      // constant and reporting a change that did not happen.
+      if (iv.use_empty())
+        continue;
+      // Only a trip count of exactly one pins the IV to its lower bound.
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (!numIterations.has_value() || *numIterations != 1)
+        continue;
+      // `lb` is an OpFoldResult; an attribute bound is materialized as an
+      // index constant at the rewriter's current insertion point.
+      rewriter.replaceAllUsesWith(
+          iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+      replacedIv = true;
+    }
+    // Report success only if the IR was actually modified, as the pattern
+    // rewriter contract requires.
+    return success(replacedIv);
+  }
+};
+
struct FoldTensorCastOfOutputIntoForallOp
: public OpRewritePattern<scf::ForallOp> {
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1877,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
- ForallOpSingleOrZeroIterationDimsFolder>(context);
+ ForallOpSingleOrZeroIterationDimsFolder,
+ ForallOpReplaceConstantInductionVar>(context);
}
/// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..6f4703c04dc768 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1632,6 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
}
// CHECK-LABEL: @do_not_inline_distributed_forall_loop
// CHECK: scf.forall
+// CHECK: tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK: tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/112764
More information about the Mlir-commits
mailing list