[Mlir-commits] [mlir] b153c05 - [mlir][scf] Uplift `scf.while` to `scf.for` (#76108)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 12:17:03 PDT 2024


Author: Ivan Butygin
Date: 2024-04-15T22:16:59+03:00
New Revision: b153c05cba9be7f009b8ad8413c5840baf7d278c

URL: https://github.com/llvm/llvm-project/commit/b153c05cba9be7f009b8ad8413c5840baf7d278c
DIFF: https://github.com/llvm/llvm-project/commit/b153c05cba9be7f009b8ad8413c5840baf7d278c.diff

LOG: [mlir][scf] Uplift `scf.while` to `scf.for` (#76108)

Add uplifting from `scf.while` to `scf.for`.

This uplifting expects a very specific op pattern:
* `before` block consisting of a single `arith.cmpi` op
* `after` block containing an `arith.addi`

We also have a set of patterns to clean up `scf.while` loops and bring them
close to the desired form; they will be added in separate PRs.

This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf ->
scf.while -> scf.for -> scf.parallel`

Original code:
https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp

Added: 
    mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
    mlir/test/Dialect/SCF/uplift-while.mlir
    mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
    mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/SCF/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..fdf25706269803 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,12 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+/// Uplifting expects a specific ops pattern:
+///  * `before` block consisting of a single `arith.cmpi` (slt/sgt) op
+///  * `after` block containing an `arith.addi` of the induction var and step
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
+
 } // namespace scf
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 690cd146c606e9..220dcb35571d27 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -222,6 +222,12 @@ FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
                                                 RewriterBase &rewriter,
                                                 bool forceCreateCheck = false);
 
+/// Try to uplift `scf.while` op to `scf.for`, leaving IR unchanged on failure.
+/// Uplifting expects a specific ops pattern:
+///  * `before` block consisting of a single `arith.cmpi` (slt/sgt) op
+///  * `after` block containing an `arith.addi` of the induction var and step
+FailureOr<ForOp> upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop);
+
 } // namespace scf
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e5494205e086ac..a2925aef17ca78 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp
+  UpliftWhileToFor.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF

diff  --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..fea2f659535bb4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,233 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms `scf.while` ops into `scf.for` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+/// Pattern wrapper around `scf::upliftWhileToForLoop`.
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    return upliftWhileToForLoop(rewriter, loop);
+  }
+};
+} // namespace
+
+FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
+                                                      scf::WhileOp loop) {
+  Block *beforeBody = loop.getBeforeBody();
+  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+    return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+  if (!cmp)
+    return rewriter.notifyMatchFailure(loop,
+                                       "Loop body must have single cmp op");
+
+  scf::ConditionOp beforeTerm = loop.getConditionOp();
+  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Expected single condition use: " << *cmp;
+    });
+
+  // All `before` block args must be directly forwarded to ConditionOp.
+  // They will be converted to `scf.for` `iter_args` except induction var.
+  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+    return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+  using Pred = arith::CmpIPredicate;
+  Pred predicate = cmp.getPredicate();
+  if (predicate != Pred::slt && predicate != Pred::sgt)
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+    });
+
+  BlockArgument inductionVar;
+  Value ub;
+  DominanceInfo dom;
+
+  // Check if cmp has a suitable form (`iv slt ub` or `ub sgt iv`). `iv` must
+  // be a `before` block arg, `ub` must be defined outside `scf.while` and
+  // will be treated as upper bound.
+  for (bool reverse : {false, true}) {
+    Pred expectedPred = reverse ? Pred::sgt : Pred::slt;
+    if (predicate != expectedPred)
+      continue;
+
+    Value arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+    Value arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+    auto blockArg = dyn_cast<BlockArgument>(arg1);
+    if (!blockArg || blockArg.getOwner() != beforeBody)
+      continue;
+
+    if (!dom.properlyDominates(arg2, loop))
+      continue;
+
+    inductionVar = blockArg;
+    ub = arg2;
+    break;
+  }
+
+  if (!inductionVar)
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Unrecognized cmp form: " << *cmp;
+    });
+
+  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
+  // arg.
+  if (!llvm::hasNItems(inductionVar.getUses(), 2))
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Unrecognized induction var: " << inductionVar;
+    });
+
+  Block *afterBody = loop.getAfterBody();
+  scf::YieldOp afterTerm = loop.getYieldOp();
+  unsigned argNumber = inductionVar.getArgNumber();
+  Value afterTermIndArg = afterTerm.getResults()[argNumber];
+
+  Value inductionVarAfter = afterBody->getArgument(argNumber);
+
+  Value step;
+
+  // Find suitable `addi` op inside `after` block: one of the args must be the
+  // induction var passed from `before` block, the other must be defined
+  // outside of the loop (it becomes the step), and the result must be yielded
+  // back as the next induction var value.
+  // TODO: Add `subi` support?
+  for (OpOperand &use : inductionVarAfter.getUses()) {
+    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
+    if (!owner)
+      continue;
+
+    Value other =
+        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+    if (!dom.properlyDominates(other, loop))
+      continue;
+
+    if (afterTermIndArg != owner.getResult())
+      continue;
+
+    step = other;
+    break;
+  }
+
+  if (!step)
+    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+
+  // NOTE: `scf.for` requires `step > 0`; that is assumed here, not verified.
+  Value lb = loop.getInits()[argNumber];
+
+  assert(lb.getType().isIntOrIndex());
+  assert(lb.getType() == ub.getType());
+  assert(lb.getType() == step.getType());
+
+  llvm::SmallVector<Value> newArgs;
+
+  // Populate inits for new `scf.for`, skip induction var.
+  newArgs.reserve(loop.getInits().size());
+  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+    if (i == argNumber)
+      continue;
+
+    newArgs.emplace_back(init);
+  }
+
+  Location loc = loop.getLoc();
+
+  // With `builder == nullptr`, ForOp::build will try to insert terminator at
+  // the end of newly created block and we don't want it. Provide empty
+  // dummy builder instead.
+  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+  auto newLoop =
+      rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
+
+  Block *newBody = newLoop.getBody();
+
+  // Populate block args for `scf.for` body, move induction var to the front.
+  newArgs.clear();
+  ValueRange newBodyArgs = newBody->getArguments();
+  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
+    if (i < argNumber) {
+      newArgs.emplace_back(newBodyArgs[i + 1]);
+    } else if (i == argNumber) {
+      newArgs.emplace_back(newBodyArgs.front());
+    } else {
+      newArgs.emplace_back(newBodyArgs[i]);
+    }
+  }
+
+  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                             newArgs);
+
+  auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+  // Populate new yield args, skipping the induction var.
+  newArgs.clear();
+  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+    if (i == argNumber)
+      continue;
+
+    newArgs.emplace_back(arg);
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(term);
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
+
+  // Compute induction var value after loop execution. `scf.while` returns
+  // the first value that fails `iv < ub`, i.e. one `step` past the last value
+  // seen by the body, or `lb` if the body never runs:
+  //   lb + max(ceildiv(ub - lb, step), 0) * step
+  rewriter.setInsertionPointAfter(newLoop);
+  Type ivType = step.getType();
+  auto createConst = [&](int64_t value) -> Value {
+    if (isa<IndexType>(ivType))
+      return rewriter.create<arith::ConstantIndexOp>(loc, value);
+    return rewriter.create<arith::ConstantIntOp>(loc, value, ivType);
+  };
+  Value zero = createConst(0);
+  Value one = createConst(1);
+
+  // Trip count: (ub - lb + step - 1) / step. For `ub <= lb` the signed
+  // division (rounding towards zero) yields a non-positive value, so clamp it.
+  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+  Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+  len = rewriter.create<arith::DivSIOp>(loc, len, step);
+  len = rewriter.create<arith::MaxSIOp>(loc, len, zero);
+  Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+  res = rewriter.create<arith::AddIOp>(loc, lb, res);
+
+  // Reconstruct `scf.while` results, inserting final induction var value
+  // into proper place.
+  newArgs.clear();
+  llvm::append_range(newArgs, newLoop.getResults());
+  newArgs.insert(newArgs.begin() + argNumber, res);
+  rewriter.replaceOp(loop, newArgs);
+  return newLoop;
+}
+
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
+  patterns.add<UpliftWhileOp>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..25ea6142a332dc
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.maxsi %[[R4]], %[[C0]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi sgt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.maxsi %[[R4]], %[[C0]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg2, %arg3 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.maxsi %[[R4]], %[[C0]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+  } do {
+  ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+  %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3 : i64
+  } do {
+  ^bb0(%arg3: i64):
+    "test.test1"(%arg3) : (i64) -> ()
+    %added = arith.addi %arg3, %arg2 : i64
+    "test.test2"(%added) : (i64) -> ()
+    scf.yield %added : i64
+  }
+  return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: i64, %[[END:.*]]: i64, %[[STEP:.*]]: i64) -> i64
+//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : i64
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i64
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] : i64 {
+//       CHECK:     "test.test1"(%[[I]]) : (i64) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : i64
+//       CHECK:     "test.test2"(%[[INC]]) : (i64) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : i64
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : i64
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : i64
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : i64
+//       CHECK:     %[[R5:.*]] = arith.maxsi %[[R4]], %[[C0]] : i64
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
+//       CHECK:     return %[[R7]] : i64

diff  --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index d93bd559151829..792430cc84b650 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
   TestSCFWrapInZeroTripCheck.cpp
+  TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..468bc0ca78489f
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
@@ -0,0 +1,53 @@
+//===- TestUpliftWhileToFor.cpp - while to for loop uplifting test pass ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test pass that greedily applies the `scf.while` -> `scf.for` uplifting
+// patterns to the operation it runs on.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Runs `scf::populateUpliftWhileToForPatterns` through the greedy pattern
+/// driver; exposed in mlir-opt as `-test-scf-uplift-while-to-for`.
+struct TestSCFUpliftWhileToFor
+    : public PassWrapper<TestSCFUpliftWhileToFor, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFUpliftWhileToFor)
+
+  StringRef getArgument() const final { return "test-scf-uplift-while-to-for"; }
+
+  StringRef getDescription() const final {
+    return "test scf while to for uplifting";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet upliftPatterns(&getContext());
+    scf::populateUpliftWhileToForPatterns(upliftPatterns);
+
+    // Surface a greedy-driver failure as a pass failure.
+    if (succeeded(applyPatternsAndFoldGreedily(getOperation(),
+                                               std::move(upliftPatterns))))
+      return;
+    signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+void registerTestSCFUpliftWhileToFor() {
+  PassRegistration<TestSCFUpliftWhileToFor>();
+}
+} // namespace mlir::test

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 6ce9f3041d6f48..237ebeb166dc99 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -130,6 +130,7 @@ void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
 void registerTestRecursiveTypesPass();
+void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
 void registerTestSCFWrapInZeroTripCheckPasses();
@@ -258,6 +259,7 @@ void registerTestPasses() {
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestRecursiveTypesPass();
+  mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
   mlir::test::registerTestSCFWrapInZeroTripCheckPasses();


        


More information about the Mlir-commits mailing list