[flang-commits] [flang] b7e915c - [flang] Conditional expressions lowering: use fir.if SSA results for trivial scalar types (#192338)
via flang-commits
flang-commits at lists.llvm.org
Fri Apr 17 08:20:47 PDT 2026
Author: Caroline Newcombe
Date: 2026-04-17T10:20:42-05:00
New Revision: b7e915c59354cf4e78a37363fab514ad21ebb29c
URL: https://github.com/llvm/llvm-project/commit/b7e915c59354cf4e78a37363fab514ad21ebb29c
DIFF: https://github.com/llvm/llvm-project/commit/b7e915c59354cf4e78a37363fab514ad21ebb29c.diff
LOG: [flang] Conditional expressions lowering: use fir.if SSA results for trivial scalar types (#192338)
For trivial scalar types (INTEGER, REAL, COMPLEX, LOGICAL, UNSIGNED),
generate `fir.if` with SSA results instead of allocating a temporary and
using `hlfir.assign`. This avoids the alloca/declare/assign/load pattern
for types that can be passed directly as SSA values.
Non-trivial scalar types (derived types, characters) continue to use the
existing temporary-based paths.
The LIT test expectations have been updated accordingly, and a test case
was added.
Added:
Modified:
flang/lib/Lower/ConvertExprToHLFIR.cpp
flang/test/Lower/HLFIR/conditional-expr.f90
Removed:
################################################################################
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index cad54fc441d88..e3ffd9f7194c2 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1894,12 +1894,52 @@ class HlfirBuilder {
buildConditionalIfChain(
condExpr, [&](const Fortran::evaluate::Expr<T> &expr) {
hlfir::Entity entity{gen(expr)};
- entity = hlfir::loadTrivialScalar(loc, builder, entity);
hlfir::AssignOp::create(builder, loc, entity, temp);
});
return temp;
}
+ /// Generate scalar conditional for trivial scalar types using fir.if SSA
+ /// results; avoids temporary and assignment.
+ template <typename T>
+ hlfir::Entity genTrivialScalarConditional(
+ const Fortran::evaluate::ConditionalExpr<T> &condExpr,
+ mlir::Type elementType) {
+ assert(fir::isa_trivial(elementType) &&
+ "genTrivialScalarConditional only handles trivial scalar types");
+ const mlir::Location loc{getLoc()};
+ fir::FirOpBuilder &builder{getBuilder()};
+ getStmtCtx().pushScope();
+ const hlfir::EntityWithAttributes condEntity{gen(condExpr.condition())};
+ mlir::Value condition{hlfir::loadTrivialScalar(loc, builder, condEntity)};
+ condition = builder.createConvert(loc, builder.getI1Type(), condition);
+ auto results =
+ builder
+ .genIfOp(loc, {elementType}, condition,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ getStmtCtx().pushScope();
+ hlfir::Entity entity{gen(condExpr.thenValue())};
+ entity = hlfir::loadTrivialScalar(loc, builder, entity);
+ getStmtCtx().finalizeAndPop();
+ mlir::Value result =
+ builder.createConvert(loc, elementType, entity);
+ fir::ResultOp::create(builder, loc, result);
+ })
+ .genElse([&]() {
+ getStmtCtx().pushScope();
+ hlfir::Entity entity{gen(condExpr.elseValue())};
+ entity = hlfir::loadTrivialScalar(loc, builder, entity);
+ getStmtCtx().finalizeAndPop();
+ mlir::Value result =
+ builder.createConvert(loc, elementType, entity);
+ fir::ResultOp::create(builder, loc, result);
+ })
+ .getResults();
+ getStmtCtx().finalizeAndPop();
+ return hlfir::Entity{results[0]};
+ }
+
/// Generate conditional expression using an allocatable temporary with lazy
/// evaluation. Creates an unallocated allocatable, then uses assignment to
/// set the value from the chosen branch (allocation/reallocation handled by
@@ -1992,8 +2032,12 @@ class HlfirBuilder {
return *result;
}
// Scalar types (INTEGER, REAL, COMPLEX, LOGICAL, UNSIGNED, Derived).
- return hlfir::EntityWithAttributes{genScalarConditional(
- condExpr, hlfir::getFortranElementType(resultType), {})};
+ const mlir::Type elementType{hlfir::getFortranElementType(resultType)};
+ if (fir::isa_trivial(elementType))
+ return hlfir::EntityWithAttributes{
+ genTrivialScalarConditional(condExpr, elementType)};
+ return hlfir::EntityWithAttributes{
+ genScalarConditional(condExpr, elementType, {})};
}
hlfir::EntityWithAttributes
diff --git a/flang/test/Lower/HLFIR/conditional-expr.f90 b/flang/test/Lower/HLFIR/conditional-expr.f90
index d0d7f41a92124..5ec5db26fa970 100644
--- a/flang/test/Lower/HLFIR/conditional-expr.f90
+++ b/flang/test/Lower/HLFIR/conditional-expr.f90
@@ -1,5 +1,5 @@
! Test lowering of conditional expressions (Fortran 2023)
-! RUN: %flang_fc1 -emit-hlfir -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -funsigned -o - %s 2>&1 | FileCheck %s
! CHECK-LABEL: func.func @_QPtest_scalar_integer(
! CHECK-SAME: %[[FLAG:.*]]: !fir.ref<!fir.logical<4>> {fir.bindc_name = "flag"},
@@ -8,24 +8,21 @@
subroutine test_scalar_integer(flag, x, y)
logical :: flag
integer :: x, y, result
- ! CHECK: %[[TEMP:.*]] = fir.alloca i32 {bindc_name = ".cond.scalar"
! CHECK-DAG: %[[FLAG_DECL:.*]]:2 = hlfir.declare %[[FLAG]]
! CHECK-DAG: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]]
! CHECK-DAG: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]]
- ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
result = (flag ? x : y)
! CHECK: %[[FLAG_LOAD:.*]] = fir.load %[[FLAG_DECL]]#0
! CHECK: %[[FLAG_CONV:.*]] = fir.convert %[[FLAG_LOAD]] : (!fir.logical<4>) -> i1
- ! CHECK: fir.if %[[FLAG_CONV]] {
+ ! CHECK: %[[RESULT:.*]] = fir.if %[[FLAG_CONV]] -> (i32) {
! CHECK: %[[X_LOAD:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
- ! CHECK: hlfir.assign %[[X_LOAD]] to %[[TEMP_DECL]]#0 : i32, !fir.ref<i32>
+ ! CHECK: fir.result %[[X_LOAD]] : i32
! CHECK: } else {
! CHECK: %[[Y_LOAD:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
- ! CHECK: hlfir.assign %[[Y_LOAD]] to %[[TEMP_DECL]]#0 : i32, !fir.ref<i32>
+ ! CHECK: fir.result %[[Y_LOAD]] : i32
! CHECK: }
- ! CHECK: %[[LOAD:.*]] = fir.load %[[TEMP_DECL]]#0
- ! CHECK: hlfir.assign %[[LOAD]] to %{{.*}} : i32, !fir.ref<i32>
+ ! CHECK: hlfir.assign %[[RESULT]] to %{{.*}} : i32, !fir.ref<i32>
end subroutine
! CHECK-LABEL: func.func @_QPtest_scalar_real(
@@ -33,12 +30,10 @@ subroutine test_scalar_real(flag, x, y)
logical :: flag
real :: x, y, result
result = (flag ? x : y)
- ! CHECK: %[[TEMP:.*]] = fir.alloca f32 {bindc_name = ".cond.scalar"
- ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
- ! CHECK: fir.if
- ! CHECK: hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : f32, !fir.ref<f32>
+ ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (f32) {
+ ! CHECK: fir.result {{.*}} : f32
! CHECK: } else {
- ! CHECK: hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : f32, !fir.ref<f32>
+ ! CHECK: fir.result {{.*}} : f32
! CHECK: }
end subroutine
@@ -47,12 +42,10 @@ subroutine test_scalar_complex(flag, x, y)
logical :: flag
complex :: x, y, result
result = (flag ? x : y)
- ! CHECK: %[[TEMP:.*]] = fir.alloca complex<f32> {bindc_name = ".cond.scalar"
- ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
- ! CHECK: fir.if
- ! CHECK: hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : complex<f32>, !fir.ref<complex<f32>>
+ ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (complex<f32>) {
+ ! CHECK: fir.result {{.*}} : complex<f32>
! CHECK: } else {
- ! CHECK: hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : complex<f32>, !fir.ref<complex<f32>>
+ ! CHECK: fir.result {{.*}} : complex<f32>
! CHECK: }
end subroutine
@@ -60,9 +53,37 @@ subroutine test_scalar_complex(flag, x, y)
subroutine test_scalar_logical(flag, x, y)
logical :: flag, x, y, result
result = (flag ? x : y)
- ! CHECK: %[[TEMP:.*]] = fir.alloca !fir.logical<4> {bindc_name = ".cond.scalar"
- ! CHECK: fir.if
+ ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (!fir.logical<4>) {
+ ! CHECK: fir.result {{.*}} : !fir.logical<4>
+ ! CHECK: } else {
+ ! CHECK: fir.result {{.*}} : !fir.logical<4>
+ ! CHECK: }
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_scalar_unsigned(
+subroutine test_scalar_unsigned(flag, x, y)
+ logical :: flag
+ unsigned :: x, y, result
+ result = (flag ? x : y)
+ ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (ui32) {
+ ! CHECK: fir.result {{.*}} : ui32
+ ! CHECK: } else {
+ ! CHECK: fir.result {{.*}} : ui32
+ ! CHECK: }
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_logical_literal(
+subroutine test_logical_literal(flag)
+ logical :: flag, result
+ result = (flag ? .true. : .false.)
+ ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (!fir.logical<4>) {
+ ! CHECK: %[[TRUE:.*]] = arith.constant true
+ ! CHECK: %[[CONV:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+ ! CHECK: fir.result %[[CONV]] : !fir.logical<4>
! CHECK: } else {
+ ! CHECK: %[[FALSE:.*]] = arith.constant false
+ ! CHECK: %[[CONV:.*]] = fir.convert %[[FALSE]] : (i1) -> !fir.logical<4>
+ ! CHECK: fir.result %[[CONV]] : !fir.logical<4>
! CHECK: }
end subroutine
@@ -71,25 +92,19 @@ subroutine test_multi_branch(x)
integer :: x, result
! Multi-branch: x > 10 ? 100 : x > 5 ? 50 : 0
result = (x > 10 ? 100 : x > 5 ? 50 : 0)
- ! Both outer and inner temps are hoisted to function entry.
- ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
- ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
- ! Outer temp declaration and first condition: x > 10
- ! CHECK: hlfir.declare {{.*}} {uniq_name = ".cond.result"}
+ ! Outer condition: x > 10
! CHECK: arith.cmpi sgt
- ! CHECK: fir.if {{.*}} {
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: %[[OUTER:.*]] = fir.if {{.*}} -> (i32) {
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: } else {
- ! Inner temp for the nested conditional: x > 5 ? 50 : 0
- ! CHECK: hlfir.declare {{.*}} {uniq_name = ".cond.result"}
+ ! Inner conditional: x > 5 ? 50 : 0
! CHECK: arith.cmpi sgt
- ! CHECK: fir.if {{.*}} {
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: %[[INNER:.*]] = fir.if {{.*}} -> (i32) {
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: } else {
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: }
- ! CHECK: fir.load
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: fir.result %[[INNER]] : i32
! CHECK: }
end subroutine
@@ -176,22 +191,17 @@ subroutine test_nested_conditionals(flag1, flag2, x, y, z)
integer :: x, y, z, result
! Nested: flag1 ? (flag2 ? x : y) : z
result = (flag1 ? (flag2 ? x : y) : z)
- ! Both outer and inner temps are hoisted to function entry.
- ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
- ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
- ! Outer temp declaration and conditional
- ! CHECK: hlfir.declare {{.*}} {uniq_name = ".cond.result"}
- ! CHECK: fir.if {{%.*}} {
- ! Inner temp declaration and conditional
- ! CHECK: hlfir.declare {{.*}} {uniq_name = ".cond.result"}
- ! CHECK: fir.if {{%.*}} {
- ! CHECK: hlfir.assign {{.*}}
+ ! Outer conditional
+ ! CHECK: %[[OUTER:.*]] = fir.if {{%.*}} -> (i32) {
+ ! Inner conditional
+ ! CHECK: %[[INNER:.*]] = fir.if {{%.*}} -> (i32) {
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: } else {
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: }
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: fir.result %[[INNER]] : i32
! CHECK: } else {
- ! CHECK: hlfir.assign {{.*}}
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: }
end subroutine
@@ -201,11 +211,11 @@ subroutine test_in_expression(flag, x, y)
integer :: x, y, z
! Conditional in larger expression: (flag ? x : y) + 10
z = (flag ? x : y) + 10
- ! CHECK: %[[TEMP:.*]] = fir.alloca i32 {bindc_name = ".cond.scalar"
- ! CHECK: fir.if
+ ! CHECK: %[[COND_RESULT:.*]] = fir.if {{.*}} -> (i32) {
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: } else {
+ ! CHECK: fir.result {{.*}} : i32
! CHECK: }
- ! CHECK: %[[COND_RESULT:.*]] = fir.load
! CHECK: %[[C10:.*]] = arith.constant 10
! CHECK: %[[SUM:.*]] = arith.addi %[[COND_RESULT]], %[[C10]]
! CHECK: hlfir.assign %[[SUM]]
@@ -232,17 +242,23 @@ subroutine test_different_kinds(flag)
integer(kind=4) :: i4_1, i4_2, i4_result
integer(kind=8) :: i8_1, i8_2, i8_result
- ! Both temps allocated at function start
- ! CHECK-DAG: %{{.*}} = fir.alloca i64 {bindc_name = ".cond.scalar"}
- ! CHECK-DAG: %{{.*}} = fir.alloca i32 {bindc_name = ".cond.scalar"}
-
i4_1 = 1
i4_2 = 2
i4_result = (flag ? i4_1 : i4_2)
+ ! CHECK: %{{.*}} = fir.if {{.*}} -> (i32) {
+ ! CHECK: fir.result {{.*}} : i32
+ ! CHECK: } else {
+ ! CHECK: fir.result {{.*}} : i32
+ ! CHECK: }
i8_1 = 3
i8_2 = 4
i8_result = (flag ? i8_1 : i8_2)
+ ! CHECK: %{{.*}} = fir.if {{.*}} -> (i64) {
+ ! CHECK: fir.result {{.*}} : i64
+ ! CHECK: } else {
+ ! CHECK: fir.result {{.*}} : i64
+ ! CHECK: }
end subroutine
! CHECK-LABEL: func.func @_QPtest_array_section(
More information about the flang-commits
mailing list