[Mlir-commits] [mlir] [mlir][scf] Uplift `scf.while` to `scf.for` (PR #76108)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 30 06:51:24 PDT 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/76108

>From fa0320efd0c48f26fd467fecd53c7fe27f1b1e83 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 20 Dec 2023 22:18:29 +0100
Subject: [PATCH 1/6] [mlir][scf] Uplift `scf.while` to `scf.for`

Add uplifting from `scf.while` to `scf.for`.

This uplifting expects a very specific op pattern:
* `before` body consisting of a single `arith.cmpi` op
* `after` body containing an `arith.addi`
* Induction variable must be of type `index` or an integer of the specified width

We also have a set of patterns to clean up `scf.while` loops and bring them closer to the desired form; they will be added in separate PRs.

This is part of upstreaming `numba-mlir` scf uplifting pipeline:
`cf -> scf.while -> scf.for -> scf.parallel`

Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
---
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  16 ++
 .../mlir/Dialect/SCF/Transforms/Patterns.h    |   4 +
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |   1 +
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 250 ++++++++++++++++++
 mlir/test/Dialect/SCF/uplift-while.mlir       | 162 ++++++++++++
 5 files changed, 433 insertions(+)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
 create mode 100644 mlir/test/Dialect/SCF/uplift-while.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..ec28bb0b8b8aa8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
+def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+  let summary = "Uplift scf.while ops to scf.for";
+  let description = [{
+    This pass tries to uplift `scf.while` ops to `scf.for` if they have a
+    compatible form. `scf.while` are left unchanged if uplifting is not
+    possible.
+  }];
+
+  let options = [
+    Option<"indexBitWidth", "index-bitwidth", "unsigned",
+           /*default=*/"64",
+           "Bitwidth of index type.">,
+  ];
+ }
+
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..9f3cdd93071ea9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                      unsigned indexBitwidth);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e5494205e086ac..a2925aef17ca78 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp
+  UpliftWhileToFor.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..cd16b622504953
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,250 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
+  auto type = op.getLhs().getType();
+  if (isa<mlir::IndexType>(type))
+    return true;
+
+  if (type.isSignlessInteger(indexBitWidth))
+    return true;
+
+  return false;
+}
+
+namespace {
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
+      : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
+  }
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Block *beforeBody = loop.getBeforeBody();
+    if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+      return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+    auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+    if (!cmp)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Loop body must have single cmp op");
+
+    auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
+    if (!llvm::hasSingleElement(cmp->getUses()) &&
+        beforeTerm.getCondition() == cmp.getResult())
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected single condiditon use: " << *cmp;
+      });
+
+    if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+    using Pred = arith::CmpIPredicate;
+    auto predicate = cmp.getPredicate();
+    if (predicate != Pred::slt && predicate != Pred::sgt)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      });
+
+    if (!checkIndexType(cmp, indexBitWidth))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected index-like type: " << *cmp;
+      });
+
+    BlockArgument iterVar;
+    Value end;
+    DominanceInfo dom;
+    for (bool reverse : {false, true}) {
+      auto expectedPred = reverse ? Pred::sgt : Pred::slt;
+      if (cmp.getPredicate() != expectedPred)
+        continue;
+
+      auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+      auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+      auto blockArg = dyn_cast<BlockArgument>(arg1);
+      if (!blockArg || blockArg.getOwner() != beforeBody)
+        continue;
+
+      if (!dom.properlyDominates(arg2, loop))
+        continue;
+
+      iterVar = blockArg;
+      end = arg2;
+      break;
+    }
+
+    if (!iterVar)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized cmp form: " << *cmp;
+      });
+
+    if (!llvm::hasNItems(iterVar.getUses(), 2))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized iter var: " << iterVar;
+      });
+
+    Block *afterBody = loop.getAfterBody();
+    auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
+    auto argNumber = iterVar.getArgNumber();
+    auto afterTermIterArg = afterTerm.getResults()[argNumber];
+
+    auto iterVarAfter = afterBody->getArgument(argNumber);
+
+    Value step;
+    for (auto &use : iterVarAfter.getUses()) {
+      auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
+      if (!owner)
+        continue;
+
+      auto other =
+          (iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+      if (!dom.properlyDominates(other, loop))
+        continue;
+
+      if (afterTermIterArg != owner.getResult())
+        continue;
+
+      step = other;
+      break;
+    }
+
+    if (!step)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Didn't found suitable 'add' op");
+
+    auto begin = loop.getInits()[argNumber];
+
+    auto loc = loop.getLoc();
+    auto indexType = rewriter.getIndexType();
+    auto toIndex = [&](Value val) -> Value {
+      if (val.getType() != indexType)
+        return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
+
+      return val;
+    };
+    begin = toIndex(begin);
+    end = toIndex(end);
+    step = toIndex(step);
+
+    llvm::SmallVector<Value> mapping;
+    mapping.reserve(loop.getInits().size());
+    for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(init);
+    }
+
+    auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
+    auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
+                                               emptyBuidler);
+
+    Block *newBody = newLoop.getBody();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(newBody);
+    Value newIterVar = newBody->getArgument(0);
+    if (newIterVar.getType() != iterVar.getType())
+      newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
+                                                       newIterVar);
+
+    mapping.clear();
+    auto newArgs = newBody->getArguments();
+    for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+      if (i < argNumber) {
+        mapping.emplace_back(newArgs[i + 1]);
+      } else if (i == argNumber) {
+        Value arg = newArgs.front();
+        if (arg.getType() != iterVar.getType())
+          arg =
+              rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
+        mapping.emplace_back(arg);
+      } else {
+        mapping.emplace_back(newArgs[i]);
+      }
+    }
+
+    rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                               mapping);
+
+    auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+    mapping.clear();
+    for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(arg);
+    }
+
+    rewriter.setInsertionPoint(term);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+
+    rewriter.setInsertionPointAfter(newLoop);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+    Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+    len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+    len = rewriter.create<arith::DivSIOp>(loc, len, step);
+    len = rewriter.create<arith::SubIOp>(loc, len, one);
+    Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+    res = rewriter.create<arith::AddIOp>(loc, begin, res);
+    if (res.getType() != iterVar.getType())
+      res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
+
+    mapping.clear();
+    llvm::append_range(mapping, newLoop.getResults());
+    mapping.insert(mapping.begin() + argNumber, res);
+    rewriter.replaceOp(loop, mapping);
+    return success();
+  }
+
+private:
+  unsigned indexBitWidth = 0;
+};
+
+struct SCFUpliftWhileToFor final
+    : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
+  using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                                 unsigned indexBitwidth) {
+  patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..52a5c0f3cd6347
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi sgt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg2, %arg3 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+  } do {
+  ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+  %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3 : i64
+  } do {
+  ^bb0(%arg3: i64):
+    "test.test1"(%arg3) : (i64) -> ()
+    %added = arith.addi %arg3, %arg2 : i64
+    "test.test2"(%added) : (i64) -> ()
+    scf.yield %added : i64
+  }
+  return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
+//       CHECK:     %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
+//       CHECK:     %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+//       CHECK:     "test.test1"(%[[II]]) : (i64) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+//       CHECK:     "test.test2"(%[[INC]]) : (i64) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
+//       CHECK:     return %[[RES]] : i64

>From a4011b9918bd080a369fa58005fdea80aea1bfb9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 15:54:39 +0100
Subject: [PATCH 2/6] Support non-index types

---
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  6 --
 .../mlir/Dialect/SCF/Transforms/Patterns.h    |  3 +-
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 64 +++++--------------
 mlir/test/Dialect/SCF/uplift-while.mlir       | 33 ++++------
 4 files changed, 31 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index ec28bb0b8b8aa8..907cc77554a8e0 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -161,12 +161,6 @@ def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
     compatible form. `scf.while` are left unchanged if uplifting is not
     possible.
   }];
-
-  let options = [
-    Option<"indexBitWidth", "index-bitwidth", "unsigned",
-           /*default=*/"64",
-           "Bitwidth of index type.">,
-  ];
  }
 
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 9f3cdd93071ea9..3320019cd5a9d4 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -80,8 +80,7 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
 /// Populate patterns to uplift `scf.while` ops to `scf.for`.
-void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
-                                      unsigned indexBitwidth);
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
 } // namespace scf
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index cd16b622504953..ccb4535f3c8d8d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -26,22 +26,9 @@ namespace mlir {
 
 using namespace mlir;
 
-static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
-  auto type = op.getLhs().getType();
-  if (isa<mlir::IndexType>(type))
-    return true;
-
-  if (type.isSignlessInteger(indexBitWidth))
-    return true;
-
-  return false;
-}
-
 namespace {
 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
-  UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
-      : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
-  }
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(scf::WhileOp loop,
                                 PatternRewriter &rewriter) const override {
@@ -71,11 +58,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
         diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
       });
 
-    if (!checkIndexType(cmp, indexBitWidth))
-      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
-        diag << "Expected index-like type: " << *cmp;
-      });
-
     BlockArgument iterVar;
     Value end;
     DominanceInfo dom;
@@ -140,17 +122,9 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
 
     auto begin = loop.getInits()[argNumber];
 
-    auto loc = loop.getLoc();
-    auto indexType = rewriter.getIndexType();
-    auto toIndex = [&](Value val) -> Value {
-      if (val.getType() != indexType)
-        return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
-
-      return val;
-    };
-    begin = toIndex(begin);
-    end = toIndex(end);
-    step = toIndex(step);
+    assert(begin.getType().isIntOrIndex());
+    assert(begin.getType() == end.getType());
+    assert(begin.getType() == step.getType());
 
     llvm::SmallVector<Value> mapping;
     mapping.reserve(loop.getInits().size());
@@ -161,6 +135,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       mapping.emplace_back(init);
     }
 
+    auto loc = loop.getLoc();
     auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
     auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
                                                emptyBuidler);
@@ -170,9 +145,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointToStart(newBody);
     Value newIterVar = newBody->getArgument(0);
-    if (newIterVar.getType() != iterVar.getType())
-      newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
-                                                       newIterVar);
 
     mapping.clear();
     auto newArgs = newBody->getArguments();
@@ -180,11 +152,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       if (i < argNumber) {
         mapping.emplace_back(newArgs[i + 1]);
       } else if (i == argNumber) {
-        Value arg = newArgs.front();
-        if (arg.getType() != iterVar.getType())
-          arg =
-              rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
-        mapping.emplace_back(arg);
+        mapping.emplace_back(newArgs.front());
       } else {
         mapping.emplace_back(newArgs[i]);
       }
@@ -207,7 +175,13 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
 
     rewriter.setInsertionPointAfter(newLoop);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value one;
+    if (isa<IndexType>(step.getType())) {
+      one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    } else {
+      one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+    }
+
     Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
     Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
     len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
@@ -215,8 +189,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     len = rewriter.create<arith::SubIOp>(loc, len, one);
     Value res = rewriter.create<arith::MulIOp>(loc, len, step);
     res = rewriter.create<arith::AddIOp>(loc, begin, res);
-    if (res.getType() != iterVar.getType())
-      res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
 
     mapping.clear();
     llvm::append_range(mapping, newLoop.getResults());
@@ -224,9 +196,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     rewriter.replaceOp(loop, mapping);
     return success();
   }
-
-private:
-  unsigned indexBitWidth = 0;
 };
 
 struct SCFUpliftWhileToFor final
@@ -237,14 +206,13 @@ struct SCFUpliftWhileToFor final
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
     RewritePatternSet patterns(ctx);
-    mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
+    mlir::scf::populateUpliftWhileToForPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
       signalPassFailure();
   }
 };
 } // namespace
 
-void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
-                                                 unsigned indexBitwidth) {
-  patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
+  patterns.add<UpliftWhileOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 52a5c0f3cd6347..57394fd3528a62 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
   %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
@@ -141,22 +141,17 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
 }
 
 // CHECK-LABEL: func @uplift_while
-//  CHECK-SAME:     (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
-//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
-//       CHECK:     %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
-//       CHECK:     %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
-//       CHECK:     %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
-//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
-//       CHECK:     %[[II:.*]] = arith.index_cast %[[I]] : index to i64
-//       CHECK:     "test.test1"(%[[II]]) : (i64) -> ()
-//       CHECK:     %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+//  CHECK-SAME:     (%[[BEGIN:.*]]: i64, %[[END:.*]]: i64, %[[STEP:.*]]: i64) -> i64
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : i64
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] : i64 {
+//       CHECK:     "test.test1"(%[[I]]) : (i64) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : i64
 //       CHECK:     "test.test2"(%[[INC]]) : (i64) -> ()
-//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
-//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
-//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
-//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
-//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
-//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
-//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
-//       CHECK:     %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
-//       CHECK:     return %[[RES]] : i64
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : i64
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : i64
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : i64
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : i64
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : i64
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
+//       CHECK:     return %[[R7]] : i64

>From dc973c4c3c53c6e7c6a2d33e92a472ff1277de54 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 16:18:47 +0100
Subject: [PATCH 3/6] cleanup

---
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  4 ++
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 50 +++++++++++--------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 907cc77554a8e0..42b02bb7ac6c4f 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -160,6 +160,10 @@ def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
     This pass tries to uplift `scf.while` ops to `scf.for` if they have a
     compatible form. `scf.while` are left unchanged if uplifting is not
     possible.
+
+    This pass expects a specific ops pattern:
+    * `before` block consisting of single arith.cmp op
+    * `after` block containing arith.addi
   }];
  }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ccb4535f3c8d8d..e1f9d25dee4d0e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -41,26 +41,31 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       return rewriter.notifyMatchFailure(loop,
                                          "Loop body must have single cmp op");
 
-    auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
-    if (!llvm::hasSingleElement(cmp->getUses()) &&
-        beforeTerm.getCondition() == cmp.getResult())
+    scf::ConditionOp beforeTerm = loop.getConditionOp();
+    if (!cmp->hasOneUse() && beforeTerm.getCondition() == cmp.getResult())
       return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
         diag << "Expected single condiditon use: " << *cmp;
       });
 
+    // All `before` block args must be directly forwarded to ConditionOp.
+    // They will be converted to `scf.for` `iter_vars` except induction var.
     if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
       return rewriter.notifyMatchFailure(loop, "Invalid args order");
 
     using Pred = arith::CmpIPredicate;
-    auto predicate = cmp.getPredicate();
+    Pred predicate = cmp.getPredicate();
     if (predicate != Pred::slt && predicate != Pred::sgt)
       return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
         diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
       });
 
-    BlockArgument iterVar;
+    BlockArgument indVar;
     Value end;
     DominanceInfo dom;
+
+    // Check if cmp has a suitable form. One of the arguments must be a `before`
+    // block arg, other must be defined outside `scf.while` and will be treated
+    // as upper bound.
     for (bool reverse : {false, true}) {
       auto expectedPred = reverse ? Pred::sgt : Pred::slt;
       if (cmp.getPredicate() != expectedPred)
@@ -76,36 +81,42 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       if (!dom.properlyDominates(arg2, loop))
         continue;
 
-      iterVar = blockArg;
+      indVar = blockArg;
       end = arg2;
       break;
     }
 
-    if (!iterVar)
+    if (!indVar)
       return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
         diag << "Unrecognized cmp form: " << *cmp;
       });
 
-    if (!llvm::hasNItems(iterVar.getUses(), 2))
+    // indVar must have 2 uses: one is in `cmp` and other is `condition` arg.
+    if (!llvm::hasNItems(indVar.getUses(), 2))
       return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
-        diag << "Unrecognized iter var: " << iterVar;
+        diag << "Unrecognized induction var: " << indVar;
       });
 
     Block *afterBody = loop.getAfterBody();
-    auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
-    auto argNumber = iterVar.getArgNumber();
+    scf::YieldOp afterTerm = loop.getYieldOp();
+    auto argNumber = indVar.getArgNumber();
     auto afterTermIterArg = afterTerm.getResults()[argNumber];
 
-    auto iterVarAfter = afterBody->getArgument(argNumber);
+    auto indVarAfter = afterBody->getArgument(argNumber);
 
     Value step;
-    for (auto &use : iterVarAfter.getUses()) {
+
+    // Find suitable `addi` op inside `after` block, one of the args must be an
+    // Induction var passed from `before` block and second arg must be defined
+    // outside of the loop and will be considered step value.
+    // TODO: Add `subi` support?
+    for (auto &use : indVarAfter.getUses()) {
       auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
       if (!owner)
         continue;
 
       auto other =
-          (iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+          (indVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
       if (!dom.properlyDominates(other, loop))
         continue;
 
@@ -118,7 +129,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
 
     if (!step)
       return rewriter.notifyMatchFailure(loop,
-                                         "Didn't found suitable 'add' op");
+                                         "Didn't found suitable 'addi' op");
 
     auto begin = loop.getInits()[argNumber];
 
@@ -136,16 +147,12 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     }
 
     auto loc = loop.getLoc();
-    auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
+    auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
     auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
-                                               emptyBuidler);
+                                               emptyBuilder);
 
     Block *newBody = newLoop.getBody();
 
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointToStart(newBody);
-    Value newIterVar = newBody->getArgument(0);
-
     mapping.clear();
     auto newArgs = newBody->getArguments();
     for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
@@ -171,6 +178,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       mapping.emplace_back(arg);
     }
 
+    OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(term);
     rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
 

>From 6bef63da04a74d4112051804927490086af604da Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 16:37:14 +0100
Subject: [PATCH 4/6] Renamings and comments

---
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 67 +++++++++++--------
 1 file changed, 39 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index e1f9d25dee4d0e..0e4dd60217a72f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -60,7 +60,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       });
 
     BlockArgument indVar;
-    Value end;
+    Value ub;
     DominanceInfo dom;
 
     // Check if cmp has a suitable form. One of the arguments must be a `before`
@@ -82,7 +82,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
         continue;
 
       indVar = blockArg;
-      end = arg2;
+      ub = arg2;
       break;
     }
 
@@ -131,57 +131,66 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       return rewriter.notifyMatchFailure(loop,
                                          "Didn't found suitable 'addi' op");
 
-    auto begin = loop.getInits()[argNumber];
+    auto lb = loop.getInits()[argNumber];
 
-    assert(begin.getType().isIntOrIndex());
-    assert(begin.getType() == end.getType());
-    assert(begin.getType() == step.getType());
+    assert(lb.getType().isIntOrIndex());
+    assert(lb.getType() == ub.getType());
+    assert(lb.getType() == step.getType());
 
-    llvm::SmallVector<Value> mapping;
-    mapping.reserve(loop.getInits().size());
+    llvm::SmallVector<Value> newArgs;
+
+    // Populate inits for new `scf.for`
+    newArgs.reserve(loop.getInits().size());
     for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
       if (i == argNumber)
         continue;
 
-      mapping.emplace_back(init);
+      newArgs.emplace_back(init);
     }
 
     auto loc = loop.getLoc();
+
+    // With `builder == nullptr`, ForOp::build will try to insert terminator at
+    // the end of newly created block and we don't want it. Provide empty
+    // dummy builder instead.
     auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
-    auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
-                                               emptyBuilder);
+    auto newLoop =
+        rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
 
     Block *newBody = newLoop.getBody();
 
-    mapping.clear();
-    auto newArgs = newBody->getArguments();
-    for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+    // Populate block args for `scf.for` body, move induction var to the front.
+    newArgs.clear();
+    ValueRange newBodyArgs = newBody->getArguments();
+    for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
       if (i < argNumber) {
-        mapping.emplace_back(newArgs[i + 1]);
+        newArgs.emplace_back(newBodyArgs[i + 1]);
       } else if (i == argNumber) {
-        mapping.emplace_back(newArgs.front());
+        newArgs.emplace_back(newBodyArgs.front());
       } else {
-        mapping.emplace_back(newArgs[i]);
+        newArgs.emplace_back(newBodyArgs[i]);
       }
     }
 
     rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
-                               mapping);
+                               newArgs);
 
     auto term = cast<scf::YieldOp>(newBody->getTerminator());
 
-    mapping.clear();
+    // Populate new yield args, skipping the induction var.
+    newArgs.clear();
     for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
       if (i == argNumber)
         continue;
 
-      mapping.emplace_back(arg);
+      newArgs.emplace_back(arg);
     }
 
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(term);
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
 
+    // Compute induction var value after loop execution.
     rewriter.setInsertionPointAfter(newLoop);
     Value one;
     if (isa<IndexType>(step.getType())) {
@@ -191,17 +200,19 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     }
 
     Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
-    Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+    Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
     len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
     len = rewriter.create<arith::DivSIOp>(loc, len, step);
     len = rewriter.create<arith::SubIOp>(loc, len, one);
     Value res = rewriter.create<arith::MulIOp>(loc, len, step);
-    res = rewriter.create<arith::AddIOp>(loc, begin, res);
-
-    mapping.clear();
-    llvm::append_range(mapping, newLoop.getResults());
-    mapping.insert(mapping.begin() + argNumber, res);
-    rewriter.replaceOp(loop, mapping);
+    res = rewriter.create<arith::AddIOp>(loc, lb, res);
+
+    // Reconstruct `scf.while` results, inserting final induction var value
+    // into proper place.
+    newArgs.clear();
+    llvm::append_range(newArgs, newLoop.getResults());
+    newArgs.insert(newArgs.begin() + argNumber, res);
+    rewriter.replaceOp(loop, newArgs);
     return success();
   }
 };

>From 26da9305f30ca67803e6def14d2ca5c29b47e6ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 21 Dec 2023 16:47:07 +0100
Subject: [PATCH 5/6] renaming

---
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 0e4dd60217a72f..a14c726fcd0537 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -59,7 +59,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
         diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
       });
 
-    BlockArgument indVar;
+    BlockArgument inductionVar;
     Value ub;
     DominanceInfo dom;
 
@@ -81,28 +81,29 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
       if (!dom.properlyDominates(arg2, loop))
         continue;
 
-      indVar = blockArg;
+      inductionVar = blockArg;
       ub = arg2;
       break;
     }
 
-    if (!indVar)
+    if (!inductionVar)
       return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
         diag << "Unrecognized cmp form: " << *cmp;
       });
 
-    // indVar must have 2 uses: one is in `cmp` and other is `condition` arg.
-    if (!llvm::hasNItems(indVar.getUses(), 2))
+    // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
+    // arg.
+    if (!llvm::hasNItems(inductionVar.getUses(), 2))
       return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
-        diag << "Unrecognized induction var: " << indVar;
+        diag << "Unrecognized induction var: " << inductionVar;
       });
 
     Block *afterBody = loop.getAfterBody();
     scf::YieldOp afterTerm = loop.getYieldOp();
-    auto argNumber = indVar.getArgNumber();
-    auto afterTermIterArg = afterTerm.getResults()[argNumber];
+    auto argNumber = inductionVar.getArgNumber();
+    auto afterTermIndArg = afterTerm.getResults()[argNumber];
 
-    auto indVarAfter = afterBody->getArgument(argNumber);
+    auto inductionVarAfter = afterBody->getArgument(argNumber);
 
     Value step;
 
@@ -110,17 +111,17 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
     // Induction var passed from `before` block and second arg must be defined
     // outside of the loop and will be considered step value.
     // TODO: Add `subi` support?
-    for (auto &use : indVarAfter.getUses()) {
+    for (auto &use : inductionVarAfter.getUses()) {
       auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
       if (!owner)
         continue;
 
-      auto other =
-          (indVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+      auto other = (inductionVarAfter == owner.getLhs() ? owner.getRhs()
+                                                        : owner.getLhs());
       if (!dom.properlyDominates(other, loop))
         continue;
 
-      if (afterTermIterArg != owner.getResult())
+      if (afterTermIndArg != owner.getResult())
         continue;
 
       step = other;
@@ -139,7 +140,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
 
     llvm::SmallVector<Value> newArgs;
 
-    // Populate inits for new `scf.for`
+    // Populate inits for new `scf.for`, skip induction var.
     newArgs.reserve(loop.getInits().size());
     for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
       if (i == argNumber)

>From b4ca7e0e22db9100da74aaf0c29f03d0b00e54a1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 4 Jan 2024 21:00:36 +0100
Subject: [PATCH 6/6] Renamed to test pass

---
 mlir/include/mlir/Dialect/SCF/Transforms/Passes.td   | 2 +-
 mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h  | 3 +++
 mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp | 8 ++++----
 mlir/test/Dialect/SCF/uplift-while.mlir              | 2 +-
 4 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 42b02bb7ac6c4f..f9fc4a0358c9ec 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,7 +154,7 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
-def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+def TestSCFUpliftWhileToFor : Pass<"test-scf-uplift-while-to-for"> {
   let summary = "Uplift scf.while ops to scf.for";
   let description = [{
     This pass tries to uplift `scf.while` ops to `scf.for` if they have a
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 3320019cd5a9d4..fdf25706269803 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -80,6 +80,9 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
 /// Populate patterns to uplift `scf.while` ops to `scf.for`.
+/// Uplifting expects a specific ops pattern:
+///  * `before` block consisting of a single arith.cmp op
+///  * `after` block containing arith.addi
 void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
 } // namespace scf
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index a14c726fcd0537..74152f7a5304ce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -20,7 +20,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#define GEN_PASS_DEF_TESTSCFUPLIFTWHILETOFOR
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -218,9 +218,9 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
   }
 };
 
-struct SCFUpliftWhileToFor final
-    : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
-  using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+struct TestSCFUpliftWhileToFor final
+    : impl::TestSCFUpliftWhileToForBase<TestSCFUpliftWhileToFor> {
+  using TestSCFUpliftWhileToForBase::TestSCFUpliftWhileToForBase;
 
   void runOnOperation() override {
     Operation *op = getOperation();
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 57394fd3528a62..25ea6142a332dc 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
   %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {



More information about the Mlir-commits mailing list