[Mlir-commits] [mlir] [mlir][scf] Uplift `scf.while` to `scf.for` (PR #76108)
Ivan Butygin
llvmlistbot at llvm.org
Thu Jan 4 12:01:08 PST 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/76108
>From 4841b833c3b903d9bbc0a44eb01561df8cf0cbd0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 20 Dec 2023 22:18:29 +0100
Subject: [PATCH 1/6] [mlir][scf] Uplift `scf.while` to `scf.for`
Add uplifting from `scf.while` to `scf.for`.
This uplifting expects a very specific ops pattern:
* `before` body consisting of a single `arith.cmpi` op
* `after` body containing `arith.addi`
* Iter var must be of type `index` or integer of specified width
We also have a set of patterns to clean up `scf.while` loops to get them close to the desired form; they will be added in separate PRs.
This is part of upstreaming the `numba-mlir` scf uplifting pipeline:
`cf -> scf.while -> scf.for -> scf.parallel`
Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
---
.../mlir/Dialect/SCF/Transforms/Passes.td | 16 ++
.../mlir/Dialect/SCF/Transforms/Patterns.h | 4 +
.../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 +
.../SCF/Transforms/UpliftWhileToFor.cpp | 250 ++++++++++++++++++
mlir/test/Dialect/SCF/uplift-while.mlir | 162 ++++++++++++
5 files changed, 433 insertions(+)
create mode 100644 mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
create mode 100644 mlir/test/Dialect/SCF/uplift-while.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..ec28bb0b8b8aa8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}
+def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+ let summary = "Uplift scf.while ops to scf.for";
+ let description = [{
+ This pass tries to uplift `scf.while` ops to `scf.for` if they have a
+ compatible form. `scf.while` are left unchanged if uplifting is not
+ possible.
+ }];
+
+ let options = [
+ Option<"indexBitWidth", "index-bitwidth", "unsigned",
+ /*default=*/"64",
+ "Bitwidth of index type.">,
+ ];
+ }
+
+
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..9f3cdd93071ea9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
/// loop bounds and loop steps are canonicalized.
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+ unsigned indexBitwidth);
+
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa4..7643bab80a1308 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopTiling.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
+ UpliftWhileToFor.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..cd16b622504953
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,250 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
+ auto type = op.getLhs().getType();
+ if (isa<mlir::IndexType>(type))
+ return true;
+
+ if (type.isSignlessInteger(indexBitWidth))
+ return true;
+
+ return false;
+}
+
+namespace {
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+ UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
+ : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
+ }
+
+ LogicalResult matchAndRewrite(scf::WhileOp loop,
+ PatternRewriter &rewriter) const override {
+ Block *beforeBody = loop.getBeforeBody();
+ if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+ return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+ auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+ if (!cmp)
+ return rewriter.notifyMatchFailure(loop,
+ "Loop body must have single cmp op");
+
+ auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
+ if (!llvm::hasSingleElement(cmp->getUses()) &&
+ beforeTerm.getCondition() == cmp.getResult())
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Expected single condiditon use: " << *cmp;
+ });
+
+ if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+ return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+ using Pred = arith::CmpIPredicate;
+ auto predicate = cmp.getPredicate();
+ if (predicate != Pred::slt && predicate != Pred::sgt)
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+ });
+
+ if (!checkIndexType(cmp, indexBitWidth))
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Expected index-like type: " << *cmp;
+ });
+
+ BlockArgument iterVar;
+ Value end;
+ DominanceInfo dom;
+ for (bool reverse : {false, true}) {
+ auto expectedPred = reverse ? Pred::sgt : Pred::slt;
+ if (cmp.getPredicate() != expectedPred)
+ continue;
+
+ auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+ auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+ auto blockArg = dyn_cast<BlockArgument>(arg1);
+ if (!blockArg || blockArg.getOwner() != beforeBody)
+ continue;
+
+ if (!dom.properlyDominates(arg2, loop))
+ continue;
+
+ iterVar = blockArg;
+ end = arg2;
+ break;
+ }
+
+ if (!iterVar)
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Unrecognized cmp form: " << *cmp;
+ });
+
+ if (!llvm::hasNItems(iterVar.getUses(), 2))
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Unrecognized iter var: " << iterVar;
+ });
+
+ Block *afterBody = loop.getAfterBody();
+ auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
+ auto argNumber = iterVar.getArgNumber();
+ auto afterTermIterArg = afterTerm.getResults()[argNumber];
+
+ auto iterVarAfter = afterBody->getArgument(argNumber);
+
+ Value step;
+ for (auto &use : iterVarAfter.getUses()) {
+ auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
+ if (!owner)
+ continue;
+
+ auto other =
+ (iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+ if (!dom.properlyDominates(other, loop))
+ continue;
+
+ if (afterTermIterArg != owner.getResult())
+ continue;
+
+ step = other;
+ break;
+ }
+
+ if (!step)
+ return rewriter.notifyMatchFailure(loop,
+ "Didn't found suitable 'add' op");
+
+ auto begin = loop.getInits()[argNumber];
+
+ auto loc = loop.getLoc();
+ auto indexType = rewriter.getIndexType();
+ auto toIndex = [&](Value val) -> Value {
+ if (val.getType() != indexType)
+ return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
+
+ return val;
+ };
+ begin = toIndex(begin);
+ end = toIndex(end);
+ step = toIndex(step);
+
+ llvm::SmallVector<Value> mapping;
+ mapping.reserve(loop.getInits().size());
+ for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+ if (i == argNumber)
+ continue;
+
+ mapping.emplace_back(init);
+ }
+
+ auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
+ auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
+ emptyBuidler);
+
+ Block *newBody = newLoop.getBody();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToStart(newBody);
+ Value newIterVar = newBody->getArgument(0);
+ if (newIterVar.getType() != iterVar.getType())
+ newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
+ newIterVar);
+
+ mapping.clear();
+ auto newArgs = newBody->getArguments();
+ for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+ if (i < argNumber) {
+ mapping.emplace_back(newArgs[i + 1]);
+ } else if (i == argNumber) {
+ Value arg = newArgs.front();
+ if (arg.getType() != iterVar.getType())
+ arg =
+ rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
+ mapping.emplace_back(arg);
+ } else {
+ mapping.emplace_back(newArgs[i]);
+ }
+ }
+
+ rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+ mapping);
+
+ auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+ mapping.clear();
+ for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+ if (i == argNumber)
+ continue;
+
+ mapping.emplace_back(arg);
+ }
+
+ rewriter.setInsertionPoint(term);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+
+ rewriter.setInsertionPointAfter(newLoop);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+ Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+ len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+ len = rewriter.create<arith::DivSIOp>(loc, len, step);
+ len = rewriter.create<arith::SubIOp>(loc, len, one);
+ Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+ res = rewriter.create<arith::AddIOp>(loc, begin, res);
+ if (res.getType() != iterVar.getType())
+ res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
+
+ mapping.clear();
+ llvm::append_range(mapping, newLoop.getResults());
+ mapping.insert(mapping.begin() + argNumber, res);
+ rewriter.replaceOp(loop, mapping);
+ return success();
+ }
+
+private:
+ unsigned indexBitWidth = 0;
+};
+
+struct SCFUpliftWhileToFor final
+ : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
+ using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+ RewritePatternSet patterns(ctx);
+ mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+ unsigned indexBitwidth) {
+ patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..52a5c0f3cd6347
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi sgt, %arg1, %arg3 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg2, %arg3 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2.0 : f32
+ %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+ } do {
+ ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+ %1 = "test.test1"(%arg4) : (i32) -> i32
+ %added = arith.addi %arg3, %arg2 : index
+ %2 = "test.test2"(%arg5) : (f32) -> f32
+ scf.yield %1, %added, %2 : i32, index, f32
+ }
+ return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+ %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : i64
+ scf.condition(%1) %arg3 : i64
+ } do {
+ ^bb0(%arg3: i64):
+ "test.test1"(%arg3) : (i64) -> ()
+ %added = arith.addi %arg3, %arg2 : i64
+ "test.test2"(%added) : (i64) -> ()
+ scf.yield %added : i64
+ }
+ return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
+// CHECK: %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
+// CHECK: %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+// CHECK: "test.test1"(%[[II]]) : (i64) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+// CHECK: "test.test2"(%[[INC]]) : (i64) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
+// CHECK: return %[[RES]] : i64
>From f3b000ae9a90d373233bc4db1c4e95aba01a72e0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 15:54:39 +0100
Subject: [PATCH 2/6] Support non-index types
---
.../mlir/Dialect/SCF/Transforms/Passes.td | 6 --
.../mlir/Dialect/SCF/Transforms/Patterns.h | 3 +-
.../SCF/Transforms/UpliftWhileToFor.cpp | 64 +++++--------------
mlir/test/Dialect/SCF/uplift-while.mlir | 33 ++++------
4 files changed, 31 insertions(+), 75 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index ec28bb0b8b8aa8..907cc77554a8e0 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -161,12 +161,6 @@ def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
compatible form. `scf.while` are left unchanged if uplifting is not
possible.
}];
-
- let options = [
- Option<"indexBitWidth", "index-bitwidth", "unsigned",
- /*default=*/"64",
- "Bitwidth of index type.">,
- ];
}
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 9f3cdd93071ea9..3320019cd5a9d4 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -80,8 +80,7 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
/// Populate patterns to uplift `scf.while` ops to `scf.for`.
-void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
- unsigned indexBitwidth);
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index cd16b622504953..ccb4535f3c8d8d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -26,22 +26,9 @@ namespace mlir {
using namespace mlir;
-static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
- auto type = op.getLhs().getType();
- if (isa<mlir::IndexType>(type))
- return true;
-
- if (type.isSignlessInteger(indexBitWidth))
- return true;
-
- return false;
-}
-
namespace {
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
- UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
- : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
- }
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(scf::WhileOp loop,
PatternRewriter &rewriter) const override {
@@ -71,11 +58,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
});
- if (!checkIndexType(cmp, indexBitWidth))
- return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
- diag << "Expected index-like type: " << *cmp;
- });
-
BlockArgument iterVar;
Value end;
DominanceInfo dom;
@@ -140,17 +122,9 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
auto begin = loop.getInits()[argNumber];
- auto loc = loop.getLoc();
- auto indexType = rewriter.getIndexType();
- auto toIndex = [&](Value val) -> Value {
- if (val.getType() != indexType)
- return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
-
- return val;
- };
- begin = toIndex(begin);
- end = toIndex(end);
- step = toIndex(step);
+ assert(begin.getType().isIntOrIndex());
+ assert(begin.getType() == end.getType());
+ assert(begin.getType() == step.getType());
llvm::SmallVector<Value> mapping;
mapping.reserve(loop.getInits().size());
@@ -161,6 +135,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
mapping.emplace_back(init);
}
+ auto loc = loop.getLoc();
auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
emptyBuidler);
@@ -170,9 +145,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(newBody);
Value newIterVar = newBody->getArgument(0);
- if (newIterVar.getType() != iterVar.getType())
- newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
- newIterVar);
mapping.clear();
auto newArgs = newBody->getArguments();
@@ -180,11 +152,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
if (i < argNumber) {
mapping.emplace_back(newArgs[i + 1]);
} else if (i == argNumber) {
- Value arg = newArgs.front();
- if (arg.getType() != iterVar.getType())
- arg =
- rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
- mapping.emplace_back(arg);
+ mapping.emplace_back(newArgs.front());
} else {
mapping.emplace_back(newArgs[i]);
}
@@ -207,7 +175,13 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
rewriter.setInsertionPointAfter(newLoop);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value one;
+ if (isa<IndexType>(step.getType())) {
+ one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ } else {
+ one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+ }
+
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
@@ -215,8 +189,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
len = rewriter.create<arith::SubIOp>(loc, len, one);
Value res = rewriter.create<arith::MulIOp>(loc, len, step);
res = rewriter.create<arith::AddIOp>(loc, begin, res);
- if (res.getType() != iterVar.getType())
- res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
mapping.clear();
llvm::append_range(mapping, newLoop.getResults());
@@ -224,9 +196,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
rewriter.replaceOp(loop, mapping);
return success();
}
-
-private:
- unsigned indexBitWidth = 0;
};
struct SCFUpliftWhileToFor final
@@ -237,14 +206,13 @@ struct SCFUpliftWhileToFor final
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
RewritePatternSet patterns(ctx);
- mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
+ mlir::scf::populateUpliftWhileToForPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
-void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
- unsigned indexBitwidth) {
- patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
+ patterns.add<UpliftWhileOp>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 52a5c0f3cd6347..57394fd3528a62 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
@@ -141,22 +141,17 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
}
// CHECK-LABEL: func @uplift_while
-// CHECK-SAME: (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
-// CHECK: %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
-// CHECK: %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
-// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
-// CHECK: %[[II:.*]] = arith.index_cast %[[I]] : index to i64
-// CHECK: "test.test1"(%[[II]]) : (i64) -> ()
-// CHECK: %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+// CHECK-SAME: (%[[BEGIN:.*]]: i64, %[[END:.*]]: i64, %[[STEP:.*]]: i64) -> i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : i64
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] : i64 {
+// CHECK: "test.test1"(%[[I]]) : (i64) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : i64
// CHECK: "test.test2"(%[[INC]]) : (i64) -> ()
-// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
-// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
-// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
-// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
-// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
-// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
-// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
-// CHECK: %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
-// CHECK: return %[[RES]] : i64
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : i64
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : i64
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : i64
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : i64
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : i64
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
+// CHECK: return %[[R7]] : i64
>From 0e7bb3a241bc491e8b8f3438497d98b0c7695eba Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 16:18:47 +0100
Subject: [PATCH 3/6] cleanup
---
.../mlir/Dialect/SCF/Transforms/Passes.td | 4 ++
.../SCF/Transforms/UpliftWhileToFor.cpp | 50 +++++++++++--------
2 files changed, 33 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 907cc77554a8e0..42b02bb7ac6c4f 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -160,6 +160,10 @@ def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
This pass tries to uplift `scf.while` ops to `scf.for` if they have a
compatible form. `scf.while` are left unchanged if uplifting is not
possible.
+
+ This pass expects a specific ops pattern:
+ * `before` block consisting of single arith.cmp op
+ * `after` block containing arith.addi
}];
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ccb4535f3c8d8d..e1f9d25dee4d0e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -41,26 +41,31 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
return rewriter.notifyMatchFailure(loop,
"Loop body must have single cmp op");
- auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
- if (!llvm::hasSingleElement(cmp->getUses()) &&
- beforeTerm.getCondition() == cmp.getResult())
+ scf::ConditionOp beforeTerm = loop.getConditionOp();
+ if (!cmp->hasOneUse() && beforeTerm.getCondition() == cmp.getResult())
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
diag << "Expected single condiditon use: " << *cmp;
});
+ // All `before` block args must be directly forwarded to ConditionOp.
+ // They will be converted to `scf.for` `iter_vars` except induction var.
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
return rewriter.notifyMatchFailure(loop, "Invalid args order");
using Pred = arith::CmpIPredicate;
- auto predicate = cmp.getPredicate();
+ Pred predicate = cmp.getPredicate();
if (predicate != Pred::slt && predicate != Pred::sgt)
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
});
- BlockArgument iterVar;
+ BlockArgument indVar;
Value end;
DominanceInfo dom;
+
+ // Check if cmp has a suitable form. One of the arguments must be a `before`
+ // block arg, other must be defined outside `scf.while` and will be treated
+ // as upper bound.
for (bool reverse : {false, true}) {
auto expectedPred = reverse ? Pred::sgt : Pred::slt;
if (cmp.getPredicate() != expectedPred)
@@ -76,36 +81,42 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
if (!dom.properlyDominates(arg2, loop))
continue;
- iterVar = blockArg;
+ indVar = blockArg;
end = arg2;
break;
}
- if (!iterVar)
+ if (!indVar)
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
diag << "Unrecognized cmp form: " << *cmp;
});
- if (!llvm::hasNItems(iterVar.getUses(), 2))
+ // indVar must have 2 uses: one is in `cmp` and other is `condition` arg.
+ if (!llvm::hasNItems(indVar.getUses(), 2))
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
- diag << "Unrecognized iter var: " << iterVar;
+ diag << "Unrecognized induction var: " << indVar;
});
Block *afterBody = loop.getAfterBody();
- auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
- auto argNumber = iterVar.getArgNumber();
+ scf::YieldOp afterTerm = loop.getYieldOp();
+ auto argNumber = indVar.getArgNumber();
auto afterTermIterArg = afterTerm.getResults()[argNumber];
- auto iterVarAfter = afterBody->getArgument(argNumber);
+ auto indVarAfter = afterBody->getArgument(argNumber);
Value step;
- for (auto &use : iterVarAfter.getUses()) {
+
+ // Find suitable `addi` op inside `after` block, one of the args must be an
+ // Induction var passed from `before` block and second arg must be defined
+ // outside of the loop and will be considered step value.
+ // TODO: Add `subi` support?
+ for (auto &use : indVarAfter.getUses()) {
auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
if (!owner)
continue;
auto other =
- (iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+ (indVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
if (!dom.properlyDominates(other, loop))
continue;
@@ -118,7 +129,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
if (!step)
return rewriter.notifyMatchFailure(loop,
- "Didn't found suitable 'add' op");
+ "Didn't found suitable 'addi' op");
auto begin = loop.getInits()[argNumber];
@@ -136,16 +147,12 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
}
auto loc = loop.getLoc();
- auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
+ auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
- emptyBuidler);
+ emptyBuilder);
Block *newBody = newLoop.getBody();
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointToStart(newBody);
- Value newIterVar = newBody->getArgument(0);
-
mapping.clear();
auto newArgs = newBody->getArguments();
for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
@@ -171,6 +178,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
mapping.emplace_back(arg);
}
+ OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
>From 68b52cda39cf81a9839632cd8bd983f771c31fd1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 16:37:14 +0100
Subject: [PATCH 4/6] Renamings and comments
---
.../SCF/Transforms/UpliftWhileToFor.cpp | 67 +++++++++++--------
1 file changed, 39 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index e1f9d25dee4d0e..0e4dd60217a72f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -60,7 +60,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
});
BlockArgument indVar;
- Value end;
+ Value ub;
DominanceInfo dom;
// Check if cmp has a suitable form. One of the arguments must be a `before`
@@ -82,7 +82,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
continue;
indVar = blockArg;
- end = arg2;
+ ub = arg2;
break;
}
@@ -131,57 +131,66 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
return rewriter.notifyMatchFailure(loop,
"Didn't found suitable 'addi' op");
- auto begin = loop.getInits()[argNumber];
+ auto lb = loop.getInits()[argNumber];
- assert(begin.getType().isIntOrIndex());
- assert(begin.getType() == end.getType());
- assert(begin.getType() == step.getType());
+ assert(lb.getType().isIntOrIndex());
+ assert(lb.getType() == ub.getType());
+ assert(lb.getType() == step.getType());
- llvm::SmallVector<Value> mapping;
- mapping.reserve(loop.getInits().size());
+ llvm::SmallVector<Value> newArgs;
+
+ // Populate inits for new `scf.for`
+ newArgs.reserve(loop.getInits().size());
for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
if (i == argNumber)
continue;
- mapping.emplace_back(init);
+ newArgs.emplace_back(init);
}
auto loc = loop.getLoc();
+
+ // With `builder == nullptr`, ForOp::build will try to insert a terminator
+ // at the end of the newly created block, which we don't want. Provide an
+ // empty dummy builder instead.
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
- auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
- emptyBuilder);
+ auto newLoop =
+ rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
Block *newBody = newLoop.getBody();
- mapping.clear();
- auto newArgs = newBody->getArguments();
- for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+ // Populate block args for `scf.for` body, move induction var to the front.
+ newArgs.clear();
+ ValueRange newBodyArgs = newBody->getArguments();
+ for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
if (i < argNumber) {
- mapping.emplace_back(newArgs[i + 1]);
+ newArgs.emplace_back(newBodyArgs[i + 1]);
} else if (i == argNumber) {
- mapping.emplace_back(newArgs.front());
+ newArgs.emplace_back(newBodyArgs.front());
} else {
- mapping.emplace_back(newArgs[i]);
+ newArgs.emplace_back(newBodyArgs[i]);
}
}
rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
- mapping);
+ newArgs);
auto term = cast<scf::YieldOp>(newBody->getTerminator());
- mapping.clear();
+ // Populate new yield args, skipping the induction var.
+ newArgs.clear();
for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
if (i == argNumber)
continue;
- mapping.emplace_back(arg);
+ newArgs.emplace_back(arg);
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(term);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
+ // Compute induction var value after loop execution.
rewriter.setInsertionPointAfter(newLoop);
Value one;
if (isa<IndexType>(step.getType())) {
@@ -191,17 +200,19 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
}
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
- Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+ Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
len = rewriter.create<arith::DivSIOp>(loc, len, step);
len = rewriter.create<arith::SubIOp>(loc, len, one);
Value res = rewriter.create<arith::MulIOp>(loc, len, step);
- res = rewriter.create<arith::AddIOp>(loc, begin, res);
-
- mapping.clear();
- llvm::append_range(mapping, newLoop.getResults());
- mapping.insert(mapping.begin() + argNumber, res);
- rewriter.replaceOp(loop, mapping);
+ res = rewriter.create<arith::AddIOp>(loc, lb, res);
+
+ // Reconstruct `scf.while` results, inserting final induction var value
+ // into proper place.
+ newArgs.clear();
+ llvm::append_range(newArgs, newLoop.getResults());
+ newArgs.insert(newArgs.begin() + argNumber, res);
+ rewriter.replaceOp(loop, newArgs);
return success();
}
};
>From 30901ee1912d27073c0bffb25c1375baa8c42604 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 16:47:07 +0100
Subject: [PATCH 5/6] renaming
---
.../SCF/Transforms/UpliftWhileToFor.cpp | 29 ++++++++++---------
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 0e4dd60217a72f..a14c726fcd0537 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -59,7 +59,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
});
- BlockArgument indVar;
+ BlockArgument inductionVar;
Value ub;
DominanceInfo dom;
@@ -81,28 +81,29 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
if (!dom.properlyDominates(arg2, loop))
continue;
- indVar = blockArg;
+ inductionVar = blockArg;
ub = arg2;
break;
}
- if (!indVar)
+ if (!inductionVar)
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
diag << "Unrecognized cmp form: " << *cmp;
});
- // indVar must have 2 uses: one is in `cmp` and other is `condition` arg.
- if (!llvm::hasNItems(indVar.getUses(), 2))
+ // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
+ // arg.
+ if (!llvm::hasNItems(inductionVar.getUses(), 2))
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
- diag << "Unrecognized induction var: " << indVar;
+ diag << "Unrecognized induction var: " << inductionVar;
});
Block *afterBody = loop.getAfterBody();
scf::YieldOp afterTerm = loop.getYieldOp();
- auto argNumber = indVar.getArgNumber();
- auto afterTermIterArg = afterTerm.getResults()[argNumber];
+ auto argNumber = inductionVar.getArgNumber();
+ auto afterTermIndArg = afterTerm.getResults()[argNumber];
- auto indVarAfter = afterBody->getArgument(argNumber);
+ auto inductionVarAfter = afterBody->getArgument(argNumber);
Value step;
@@ -110,17 +111,17 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
// Induction var passed from `before` block and second arg must be defined
// outside of the loop and will be considered step value.
// TODO: Add `subi` support?
- for (auto &use : indVarAfter.getUses()) {
+ for (auto &use : inductionVarAfter.getUses()) {
auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
if (!owner)
continue;
- auto other =
- (indVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+ auto other = (inductionVarAfter == owner.getLhs() ? owner.getRhs()
+ : owner.getLhs());
if (!dom.properlyDominates(other, loop))
continue;
- if (afterTermIterArg != owner.getResult())
+ if (afterTermIndArg != owner.getResult())
continue;
step = other;
@@ -139,7 +140,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
llvm::SmallVector<Value> newArgs;
- // Populate inits for new `scf.for`
+ // Populate inits for new `scf.for`, skip induction var.
newArgs.reserve(loop.getInits().size());
for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
if (i == argNumber)
>From f16cc6163be13c9dcc389045022fb029f301465f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 4 Jan 2024 21:00:36 +0100
Subject: [PATCH 6/6] Renamed to test pass
---
mlir/include/mlir/Dialect/SCF/Transforms/Passes.td | 2 +-
mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h | 3 +++
mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp | 8 ++++----
mlir/test/Dialect/SCF/uplift-while.mlir | 2 +-
4 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 42b02bb7ac6c4f..f9fc4a0358c9ec 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,7 +154,7 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}
-def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+def TestSCFUpliftWhileToFor : Pass<"test-scf-uplift-while-to-for"> {
let summary = "Uplift scf.while ops to scf.for";
let description = [{
This pass tries to uplift `scf.while` ops to `scf.for` if they have a
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 3320019cd5a9d4..fdf25706269803 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -80,6 +80,9 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+/// Uplifting expects a specific ops pattern:
+/// * `before` block consisting of a single `arith.cmpi` op
+/// * `after` block containing an `arith.addi` op
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
} // namespace scf
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index a14c726fcd0537..74152f7a5304ce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -20,7 +20,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
-#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#define GEN_PASS_DEF_TESTSCFUPLIFTWHILETOFOR
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
@@ -218,9 +218,9 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
}
};
-struct SCFUpliftWhileToFor final
- : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
- using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+struct TestSCFUpliftWhileToFor final
+ : impl::TestSCFUpliftWhileToForBase<TestSCFUpliftWhileToFor> {
+ using TestSCFUpliftWhileToForBase::TestSCFUpliftWhileToForBase;
void runOnOperation() override {
Operation *op = getOperation();
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 57394fd3528a62..25ea6142a332dc 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
More information about the Mlir-commits
mailing list