[flang-commits] [flang] [flang][fir]: Add conversion of `fir.iterate_while` to `scf.while`. (PR #152439)
Terapines MLIR via flang-commits
flang-commits at lists.llvm.org
Fri Aug 8 02:58:48 PDT 2025
https://github.com/terapines-osc-mlir updated https://github.com/llvm/llvm-project/pull/152439
>From 8a8a319b6a0ba2eaf79a89060ce90d5ad88b87ac Mon Sep 17 00:00:00 2001
From: Terapines MLIR <osc-mlir at terapines.com>
Date: Thu, 7 Aug 2025 09:27:37 +0800
Subject: [PATCH 1/2] [flang][fir]: Add conversion of `fir.iterate_while` to
`scf.while`.
---
flang/lib/Optimizer/Transforms/FIRToSCF.cpp | 90 ++++++++++++++++++-
flang/test/Fir/FirToSCF/iter-while.fir | 99 +++++++++++++++++++++
2 files changed, 187 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Fir/FirToSCF/iter-while.fir
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 1902757e83bf3..b779a21089549 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -88,6 +88,91 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
}
};
+struct IterWhileConversion : public OpRewritePattern<fir::IterWhileOp> {
+ using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
+ /// Lowers fir.iterate_while to scf.while carrying (iv, ok, iter_args...).
+ LogicalResult matchAndRewrite(fir::IterWhileOp iterWhileOp,
+ PatternRewriter &rewriter) const override {
+ // NOTE(review): assumes results/fir.result operands are (iv, ok, args...) - confirm non-finalValue ops.
+ Location loc = iterWhileOp.getLoc();
+ Value lowerBound = iterWhileOp.getLowerBound();
+ Value upperBound = iterWhileOp.getUpperBound();
+ Value step = iterWhileOp.getStep();
+ // NOTE(review): the exit test below is only `iv sle ub`; likely wrong for step < 0 - confirm.
+ Value okInit = iterWhileOp.getIterateIn();
+ ValueRange iterArgs = iterWhileOp.getInitArgs();
+ // Initial loop-carried values: lower bound, iterate_in flag, then the iter_args.
+ SmallVector<Value> initVals;
+ initVals.push_back(lowerBound);
+ initVals.push_back(okInit);
+ initVals.append(iterArgs.begin(), iterArgs.end());
+ // Types of those values; reused for the scf.while results and both region block signatures.
+ SmallVector<Type> loopTypes;
+ loopTypes.push_back(lowerBound.getType());
+ loopTypes.push_back(okInit.getType());
+ for (auto val : iterArgs)
+ loopTypes.push_back(val.getType());
+
+ auto scfWhileOp = scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
+ rewriter.createBlock(&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(),
+ loopTypes,
+ SmallVector<Location>(loopTypes.size(), loc));
+
+ rewriter.createBlock(&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(),
+ loopTypes,
+ SmallVector<Location>(loopTypes.size(), loc));
+
+ {
+ rewriter.setInsertionPointToStart(&scfWhileOp.getBefore().front());
+ auto args = scfWhileOp.getBefore().getArguments();
+ auto iv = args[0];
+ auto ok = args[1];
+
+ Value inductionCmp = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::sle, iv, upperBound);
+ Value cmp = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, ok);
+
+ mlir::scf::ConditionOp::create(rewriter, loc, cmp, args);
+ }
+
+ {
+ rewriter.setInsertionPointToStart(&scfWhileOp.getAfter().front());
+ auto args = scfWhileOp.getAfter().getArguments();
+ auto iv = args[0];
+
+ mlir::IRMapping mapping;
+ for (auto [oldArg, newVal] :
+ llvm::zip(iterWhileOp.getBody()->getArguments(), args))
+ mapping.map(oldArg, newVal);
+
+ for (auto &op : iterWhileOp.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ auto resultOp =
+ cast<fir::ResultOp>(iterWhileOp.getBody()->getTerminator());
+ auto results = resultOp.getResults();
+
+ SmallVector<Value> yieldedVals;
+
+ Value nextIv = mlir::arith::AddIOp::create(rewriter, loc, iv, step);
+ yieldedVals.push_back(nextIv);
+
+ for (auto val : results.drop_front()) {
+ if (mapping.contains(val)) {
+ yieldedVals.push_back(mapping.lookup(val));
+ } else {
+ yieldedVals.push_back(val);
+ }
+ }
+
+ mlir::scf::YieldOp::create(rewriter, loc, yieldedVals);
+ }
+
+ rewriter.replaceOp(iterWhileOp, scfWhileOp);
+ return success();
+ }
+};
+
void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
Block &dstBlock) {
Operation *srcTerminator = srcBlock.getTerminator();
@@ -130,9 +215,10 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
void FIRToSCFPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
+ patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
+ patterns.getContext());
ConversionTarget target(getContext());
- target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
+ target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir
new file mode 100644
index 0000000000000..a5de48f2ba848
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/iter-while.fir
@@ -0,0 +1,99 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// CHECK-LABEL: func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 11 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 22 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant true
+// CHECK: %[[VAL_4:.*]] = arith.constant 123 : i16
+// CHECK: %[[VAL_5:.*]] = arith.constant 456 : i32
+// CHECK: %[[VAL_6:.*]]:4 = scf.while (%[[VAL_7:.*]] = %[[VAL_0]], %[[VAL_8:.*]] = %[[VAL_3]], %[[VAL_9:.*]] = %[[VAL_4]], %[[VAL_10:.*]] = %[[VAL_5]]) : (index, i1, i16, i32) -> (index, i1, i16, i32) {
+// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_11]], %[[VAL_8]] : i1
+// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32):
+// CHECK: %[[VAL_17:.*]] = arith.constant true
+// CHECK: %[[VAL_18:.*]] = arith.constant 22 : i16
+// CHECK: %[[VAL_19:.*]] = arith.constant 33 : i32
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: scf.yield %[[VAL_20]], %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : index, i1, i16, i32
+// CHECK: }
+// CHECK: return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32
+// CHECK: }
+func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+ %lo = arith.constant 11 : index
+ %up = arith.constant 22 : index
+ %step = arith.constant 2 : index
+ %ok = arith.constant 1 : i1
+ %val1 = arith.constant 123 : i16
+ %val2 = arith.constant 456 : i32
+ // Body yields constants; the yielded %i is expected to be replaced by iv + step.
+ %res:4 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%v1 = %val1, %v2 = %val2) -> (index, i1, i16, i32) {
+ %new_c = arith.constant 1 : i1
+ %new_v1 = arith.constant 22 : i16
+ %new_v2 = arith.constant 33 : i32
+ fir.result %i, %new_c, %new_v1, %new_v2 : index, i1, i16, i32
+ }
+ // Results are (iv, ok flag, iter_args...) with types (index, i1, i16, i32).
+ return %res#0, %res#1, %res#2, %res#3 : index, i1, i16, i32
+}
+
+// CHECK-LABEL: func.func @test_simple_iterate_while_2(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: i32) -> (index, i1, i32) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]]:3 = scf.while (%[[VAL_2:.*]] = %[[ARG0]], %[[VAL_3:.*]] = %[[ARG2]], %[[VAL_4:.*]] = %[[ARG3]]) : (index, i1, i32) -> (index, i1, i32) {
+// CHECK: %[[VAL_5:.*]] = arith.cmpi sle, %[[VAL_2]], %[[ARG1]] : index
+// CHECK: %[[VAL_6:.*]] = arith.andi %[[VAL_5]], %[[VAL_3]] : i1
+// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32):
+// CHECK: %[[VAL_10:.*]] = arith.constant 123 : i32
+// CHECK: %[[VAL_11:.*]] = arith.constant true
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
+// CHECK: scf.yield %[[VAL_12]], %[[VAL_11]], %[[VAL_10]] : index, i1, i32
+// CHECK: }
+// CHECK: return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32
+// CHECK: }
+func.func @test_simple_iterate_while_2(%start: index, %stop: index, %cond: i1, %val: i32) -> (index, i1, i32) {
+ %step = arith.constant 1 : index
+ // Bounds and the initial flag come from function arguments (not constants).
+ %res:3 = fir.iterate_while (%i = %start to %stop step %step) and (%ok = %cond) iter_args(%x = %val) -> (index, i1, i32) {
+ %new_x = arith.constant 123 : i32
+ %new_ok = arith.constant 1 : i1
+ fir.result %i, %new_ok, %new_x : index, i1, i32
+ }
+ // Results: iv, ok flag, and the single i32 iter_arg.
+ return %res#0, %res#1, %res#2 : index, i1, i32
+}
+
+// CHECK-LABEL: func.func @test_zero_iterations() -> (index, i1, i8) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant true
+// CHECK: %[[VAL_4:.*]] = arith.constant 42 : i8
+// CHECK: %[[VAL_5:.*]]:3 = scf.while (%[[VAL_6:.*]] = %[[VAL_0]], %[[VAL_7:.*]] = %[[VAL_3]], %[[VAL_8:.*]] = %[[VAL_4]]) : (index, i1, i8) -> (index, i1, i8) {
+// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_9]], %[[VAL_7]] : i1
+// CHECK: scf.condition(%[[VAL_10]]) %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : index, i1, i8
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i8):
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_2]] : index
+// CHECK: scf.yield %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index, i1, i8
+// CHECK: }
+// CHECK: return %[[VAL_15:.*]]#0, %[[VAL_15]]#1, %[[VAL_15]]#2 : index, i1, i8
+// CHECK: }
+func.func @test_zero_iterations() -> (index, i1, i8) {
+ %lo = arith.constant 10 : index
+ %up = arith.constant 5 : index
+ %step = arith.constant 1 : index
+ %ok = arith.constant 1 : i1
+ %x = arith.constant 42 : i8
+ // lo (10) > up (5): the `sle` test in the before-region fails, so the body never runs.
+ %res:3 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%xv = %x) -> (index, i1, i8) {
+ fir.result %i, %c, %xv : index, i1, i8
+ }
+ // With zero iterations the results should equal the initial values (10, true, 42).
+ return %res#0, %res#1, %res#2 : index, i1, i8
+}
>From 0e1f67ba92d447ff8995f0396b0191d1b3aa61ac Mon Sep 17 00:00:00 2001
From: Terapines MLIR <osc-mlir at terapines.com>
Date: Fri, 8 Aug 2025 17:53:20 +0800
Subject: [PATCH 2/2] Refactor: Use `rewriter.inlineBlockBefore` instead of
`mlir::IRMapping`.
---
flang/lib/Optimizer/Transforms/FIRToSCF.cpp | 77 ++++++++++-----------
flang/test/Fir/FirToSCF/iter-while.fir | 18 ++---
2 files changed, 45 insertions(+), 50 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index b779a21089549..60a794dfb8e8e 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -10,6 +10,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <mlir/Support/LLVM.h>
namespace fir {
#define GEN_PASS_DEF_FIRTOSCFPASS
@@ -114,59 +115,53 @@ struct IterWhileConversion : public OpRewritePattern<fir::IterWhileOp> {
loopTypes.push_back(val.getType());
auto scfWhileOp = scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
- rewriter.createBlock(&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(),
- loopTypes,
- SmallVector<Location>(loopTypes.size(), loc));
- rewriter.createBlock(&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(),
- loopTypes,
- SmallVector<Location>(loopTypes.size(), loc));
+ auto &beforeBlock = *rewriter.createBlock(
+ &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes,
+ SmallVector<Location>(loopTypes.size(), loc));
- {
- rewriter.setInsertionPointToStart(&scfWhileOp.getBefore().front());
- auto args = scfWhileOp.getBefore().getArguments();
- auto iv = args[0];
- auto ok = args[1];
+ auto &afterBlock = *rewriter.createBlock(
+ &scfWhileOp.getAfter(), scfWhileOp.getAfter().end(), loopTypes,
+ SmallVector<Location>(loopTypes.size(), loc));
- Value inductionCmp = mlir::arith::CmpIOp::create(
- rewriter, loc, mlir::arith::CmpIPredicate::sle, iv, upperBound);
- Value cmp = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, ok);
+ auto beforeArgs = scfWhileOp.getBefore().getArguments();
+ auto beforeIv = beforeArgs[0];
+ auto beforeOk = beforeArgs[1];
- mlir::scf::ConditionOp::create(rewriter, loc, cmp, args);
- }
+ rewriter.setInsertionPointToStart(&beforeBlock);
- {
- rewriter.setInsertionPointToStart(&scfWhileOp.getAfter().front());
- auto args = scfWhileOp.getAfter().getArguments();
- auto iv = args[0];
+ Value inductionCmp = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::sle, beforeIv, upperBound);
+ Value cond =
+ mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, beforeOk);
- mlir::IRMapping mapping;
- for (auto [oldArg, newVal] :
- llvm::zip(iterWhileOp.getBody()->getArguments(), args))
- mapping.map(oldArg, newVal);
+ mlir::scf::ConditionOp::create(rewriter, loc, cond, beforeArgs);
- for (auto &op : iterWhileOp.getBody()->without_terminator())
- rewriter.clone(op, mapping);
+ auto afterArgs = scfWhileOp.getAfter().getArguments();
- auto resultOp =
- cast<fir::ResultOp>(iterWhileOp.getBody()->getTerminator());
- auto results = resultOp.getResults();
+ SmallVector<Value> argReplacements;
+ for (auto [oldArg, newVal] :
+ llvm::zip(iterWhileOp.getBody()->getArguments(), afterArgs))
+ argReplacements.push_back(newVal);
- SmallVector<Value> yieldedVals;
+ auto resultOp = cast<fir::ResultOp>(iterWhileOp.getBody()->getTerminator());
+ SmallVector<Value> results(resultOp->getOperands().begin(),
+ resultOp->getOperands().end());
- Value nextIv = mlir::arith::AddIOp::create(rewriter, loc, iv, step);
- yieldedVals.push_back(nextIv);
+ rewriter.inlineBlockBefore(iterWhileOp.getBody(), &afterBlock,
+ afterBlock.begin(), argReplacements);
- for (auto val : results.drop_front()) {
- if (mapping.contains(val)) {
- yieldedVals.push_back(mapping.lookup(val));
- } else {
- yieldedVals.push_back(val);
- }
- }
+ Value afterIv = afterArgs[0];
- mlir::scf::YieldOp::create(rewriter, loc, yieldedVals);
- }
+ rewriter.setInsertionPointToStart(&afterBlock);
+
+ results[0] = mlir::arith::AddIOp::create(rewriter, loc, afterIv, step);
+
+ Operation *movedTerminator = afterBlock.getTerminator();
+ rewriter.setInsertionPoint(movedTerminator);
+
+ mlir::scf::YieldOp::create(rewriter, loc, results);
+ rewriter.eraseOp(movedTerminator);
rewriter.replaceOp(iterWhileOp, scfWhileOp);
return success();
diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir
index a5de48f2ba848..0de7aabed120e 100644
--- a/flang/test/Fir/FirToSCF/iter-while.fir
+++ b/flang/test/Fir/FirToSCF/iter-while.fir
@@ -13,11 +13,11 @@
// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32):
-// CHECK: %[[VAL_17:.*]] = arith.constant true
-// CHECK: %[[VAL_18:.*]] = arith.constant 22 : i16
-// CHECK: %[[VAL_19:.*]] = arith.constant 33 : i32
-// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
-// CHECK: scf.yield %[[VAL_20]], %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : index, i1, i16, i32
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: %[[VAL_18:.*]] = arith.constant true
+// CHECK: %[[VAL_19:.*]] = arith.constant 22 : i16
+// CHECK: %[[VAL_20:.*]] = arith.constant 33 : i32
+// CHECK: scf.yield %[[VAL_17]], %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, i1, i16, i32
// CHECK: }
// CHECK: return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32
// CHECK: }
@@ -48,10 +48,10 @@ func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32):
-// CHECK: %[[VAL_10:.*]] = arith.constant 123 : i32
-// CHECK: %[[VAL_11:.*]] = arith.constant true
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
-// CHECK: scf.yield %[[VAL_12]], %[[VAL_11]], %[[VAL_10]] : index, i1, i32
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
+// CHECK: %[[VAL_11:.*]] = arith.constant 123 : i32
+// CHECK: %[[VAL_12:.*]] = arith.constant true
+// CHECK: scf.yield %[[VAL_10]], %[[VAL_12]], %[[VAL_11]] : index, i1, i32
// CHECK: }
// CHECK: return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32
// CHECK: }
More information about the flang-commits
mailing list