[flang-commits] [flang] [flang][fir]: New optimizer transform of `fir.if` to `scf.if`. (PR #149959)
Terapines MLIR via flang-commits
flang-commits at lists.llvm.org
Tue Jul 22 01:47:34 PDT 2025
https://github.com/terapines-osc-mlir updated https://github.com/llvm/llvm-project/pull/149959
>From de4de9c815273820376be8df2880a5a0d1bbd841 Mon Sep 17 00:00:00 2001
From: Terapines MLIR <osc-mlir at terapines.com>
Date: Tue, 22 Jul 2025 11:09:20 +0800
Subject: [PATCH] [flang][fir]: New optimizer transform of `fir.if` to
`scf.if`.
---
flang/lib/Optimizer/Transforms/FIRToSCF.cpp | 43 +++++++++++++++-
flang/test/Fir/FirToSCF/if.fir | 57 +++++++++++++++++++++
2 files changed, 98 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Fir/FirToSCF/if.fir
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index d7d1865bc56ba..f3198329028e9 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -87,13 +87,52 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
return success();
}
};
+
+/// Moves the body of a single-block `fir.if` region (`srcBlock`) into the
+/// matching `scf.if` region block (`dstBlock`), and replaces the trailing
+/// `fir.result` terminator with an `scf.yield` forwarding the same operands.
+///
+/// Operations are inserted at the front of `dstBlock`, ahead of any terminator
+/// the `scf.if` builder already placed there. For a result-less `scf.if` the
+/// builder provides an implicit `scf.yield`, which is why a new `scf.yield` is
+/// only created when the `fir.result` actually carries values.
+///
+/// Every mutation goes through `rewriter` instead of splicing the operation
+/// lists directly, so the dialect-conversion driver can track the moves (and
+/// undo them on rollback).
+void copyBlocksAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
+                                  Block &dstBlock) {
+  // `fir.if` regions are terminated by `fir.result`; cast asserts this.
+  auto resultOp = cast<fir::ResultOp>(srcBlock.getTerminator());
+
+  // Do not leak a changed insertion point back to the calling pattern.
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  // Insert before dstBlock's current first op (its pre-built terminator, if
+  // any); for an empty dstBlock this is its end. Reusing the same position for
+  // every move keeps the original operation order.
+  Block::iterator insertPt = dstBlock.begin();
+  for (Operation &op :
+       llvm::make_early_inc_range(srcBlock.without_terminator()))
+    rewriter.moveOpBefore(&op, &dstBlock, insertPt);
+
+  if (!resultOp.getOperands().empty()) {
+    rewriter.setInsertionPointToEnd(&dstBlock);
+    rewriter.create<scf::YieldOp>(resultOp.getLoc(), resultOp.getOperands());
+  }
+
+  rewriter.eraseOp(resultOp);
+}
+
+/// Rewrites a `fir.if` into an `scf.if` with the same condition, result types
+/// and attribute dictionary. The then/else bodies are moved across by
+/// `copyBlocksAndTransformResult`, and all uses of the old results are
+/// redirected to the new operation.
+struct IfConversion : public OpRewritePattern<fir::IfOp> {
+  using OpRewritePattern<fir::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(fir::IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    mlir::Region &srcThen = ifOp.getThenRegion();
+    mlir::Region &srcElse = ifOp.getElseRegion();
+    const bool withElse = !srcElse.empty();
+
+    // Only request an else block on the scf.if when the fir.if has one.
+    auto newIf = rewriter.create<scf::IfOp>(
+        ifOp.getLoc(), ifOp->getResultTypes(), ifOp.getCondition(), withElse);
+
+    copyBlocksAndTransformResult(rewriter, srcThen.front(),
+                                 newIf.getThenRegion().front());
+    if (withElse)
+      copyBlocksAndTransformResult(rewriter, srcElse.front(),
+                                   newIf.getElseRegion().front());
+
+    // Carry over every attribute attached to the original operation.
+    newIf->setAttrs(ifOp->getAttrs());
+    rewriter.replaceOp(ifOp, newIf);
+    return success();
+  }
+};
} // namespace
void FIRToSCFPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- patterns.add<DoLoopConversion>(patterns.getContext());
+ patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
ConversionTarget target(getContext());
- target.addIllegalOp<fir::DoLoopOp>();
+ target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/flang/test/Fir/FirToSCF/if.fir b/flang/test/Fir/FirToSCF/if.fir
new file mode 100644
index 0000000000000..03be264c4cdf5
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/if.fir
@@ -0,0 +1,57 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// A fir.if with no else region and no results must become a result-less
+// scf.if whose body holds the original operations unchanged.
+// CHECK-LABEL: func.func @test_only(
+// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) {
+// CHECK: scf.if %[[ARG0]] {
+// CHECK: %[[VAL_1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : i32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @test_only(%arg0 : i1, %arg1 : i32) {
+ fir.if %arg0 {
+ %0 = arith.addi %arg1, %arg1 : i32
+ }
+ return
+}
+
+// A fir.if with an else region and no results must become an scf.if with an
+// else region. The two branches use different constants so that swapping the
+// then/else bodies during conversion would make this test fail.
+// CHECK-LABEL: func.func @test_else() {
+// CHECK: %[[VAL_1:.*]] = arith.constant false
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK: scf.if %[[VAL_1]] {
+// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32
+// CHECK: } else {
+// CHECK: %[[VAL_4:.*]] = arith.constant 4 : i32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @test_else() {
+  %false = arith.constant false
+  %1 = arith.constant 2 : i32
+  fir.if %false {
+    %2 = arith.constant 3 : i32
+  } else {
+    %3 = arith.constant 4 : i32
+  }
+  return
+}
+
+// A fir.if producing two results must become an scf.if whose branches
+// scf.yield the same values, in the same order, that fir.result yielded.
+// Distinct values are yielded in a different order per branch so swapped
+// branches or swapped operands are detected. The results are returned so the
+// test also verifies that uses of the fir.if results were redirected to the
+// scf.if.
+// CHECK-LABEL: func.func @test_two_result() -> (f32, f32) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[VAL_3:.*]] = arith.constant false
+// CHECK: %[[RES:[0-9]+]]:2 = scf.if %[[VAL_3]] -> (f32, f32) {
+// CHECK: scf.yield %[[VAL_1]], %[[VAL_2]] : f32, f32
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : f32, f32
+// CHECK: }
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
+// CHECK: }
+func.func @test_two_result() -> (f32, f32) {
+  %1 = arith.constant 2.0 : f32
+  %2 = arith.constant 3.0 : f32
+  %cmp = arith.constant false
+  %x, %y = fir.if %cmp -> (f32, f32) {
+    fir.result %1, %2 : f32, f32
+  } else {
+    fir.result %2, %1 : f32, f32
+  }
+  return %x, %y : f32, f32
+}
More information about the flang-commits mailing list