[flang-commits] [flang] [flang] Conditional expressions lowering: use fir.if SSA results for trivial scalar types (PR #192338)

Caroline Newcombe via flang-commits flang-commits at lists.llvm.org
Wed Apr 15 13:55:29 PDT 2026


https://github.com/cenewcombe updated https://github.com/llvm/llvm-project/pull/192338

From 88bb0fe2be7b46ba9c53697d89c30ff463563222 Mon Sep 17 00:00:00 2001
From: Caroline Newcombe <caroline.newcombe at hpe.com>
Date: Wed, 15 Apr 2026 15:17:38 -0500
Subject: [PATCH] [flang] Conditional expressions lowering: use fir.if SSA
 results for trivial scalar types

---
 flang/lib/Lower/ConvertExprToHLFIR.cpp      |  50 ++++++++-
 flang/test/Lower/HLFIR/conditional-expr.f90 | 110 ++++++++++----------
 2 files changed, 104 insertions(+), 56 deletions(-)

diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index cad54fc441d88..e3ffd9f7194c2 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1894,12 +1894,52 @@ class HlfirBuilder {
     buildConditionalIfChain(
         condExpr, [&](const Fortran::evaluate::Expr<T> &expr) {
           hlfir::Entity entity{gen(expr)};
-          entity = hlfir::loadTrivialScalar(loc, builder, entity);
           hlfir::AssignOp::create(builder, loc, entity, temp);
         });
     return temp;
   }
 
+  /// Lower a conditional expression of trivial scalar type to a fir.if that
+  /// yields the selected value as an SSA result; no temporary or assignment.
+  template <typename T>
+  hlfir::Entity genTrivialScalarConditional(
+      const Fortran::evaluate::ConditionalExpr<T> &condExpr,
+      mlir::Type elementType) {
+    assert(fir::isa_trivial(elementType) &&
+           "genTrivialScalarConditional only handles trivial scalar types");
+    const mlir::Location loc{getLoc()};
+    fir::FirOpBuilder &builder{getBuilder()};
+    getStmtCtx().pushScope();
+    const hlfir::EntityWithAttributes condEntity{gen(condExpr.condition())};
+    mlir::Value condition{hlfir::loadTrivialScalar(loc, builder, condEntity)};
+    condition = builder.createConvert(loc, builder.getI1Type(), condition);
+    auto results =
+        builder
+            .genIfOp(loc, {elementType}, condition,
+                     /*withElseRegion=*/true)
+            .genThen([&]() {
+              getStmtCtx().pushScope();
+              hlfir::Entity entity{gen(condExpr.thenValue())};
+              entity = hlfir::loadTrivialScalar(loc, builder, entity);
+              getStmtCtx().finalizeAndPop();
+              mlir::Value result =
+                  builder.createConvert(loc, elementType, entity);
+              fir::ResultOp::create(builder, loc, result);
+            })
+            .genElse([&]() {
+              getStmtCtx().pushScope();
+              hlfir::Entity entity{gen(condExpr.elseValue())};
+              entity = hlfir::loadTrivialScalar(loc, builder, entity);
+              getStmtCtx().finalizeAndPop();
+              mlir::Value result =
+                  builder.createConvert(loc, elementType, entity);
+              fir::ResultOp::create(builder, loc, result);
+            })
+            .getResults();
+    getStmtCtx().finalizeAndPop();
+    return hlfir::Entity{results[0]};
+  }
+
   /// Generate conditional expression using an allocatable temporary with lazy
   /// evaluation. Creates an unallocated allocatable, then uses assignment to
   /// set the value from the chosen branch (allocation/reallocation handled by
@@ -1992,8 +2032,12 @@ class HlfirBuilder {
         return *result;
     }
     // Scalar types (INTEGER, REAL, COMPLEX, LOGICAL, UNSIGNED, Derived).
-    return hlfir::EntityWithAttributes{genScalarConditional(
-        condExpr, hlfir::getFortranElementType(resultType), {})};
+    const mlir::Type elementType{hlfir::getFortranElementType(resultType)};
+    if (fir::isa_trivial(elementType))
+      return hlfir::EntityWithAttributes{
+          genTrivialScalarConditional(condExpr, elementType)};
+    return hlfir::EntityWithAttributes{
+        genScalarConditional(condExpr, elementType, {})};
   }
 
   hlfir::EntityWithAttributes
diff --git a/flang/test/Lower/HLFIR/conditional-expr.f90 b/flang/test/Lower/HLFIR/conditional-expr.f90
index d0d7f41a92124..5122d67c08e92 100644
--- a/flang/test/Lower/HLFIR/conditional-expr.f90
+++ b/flang/test/Lower/HLFIR/conditional-expr.f90
@@ -8,24 +8,21 @@
 subroutine test_scalar_integer(flag, x, y)
   logical :: flag
   integer :: x, y, result
-  ! CHECK: %[[TEMP:.*]] = fir.alloca i32 {bindc_name = ".cond.scalar"
   ! CHECK-DAG: %[[FLAG_DECL:.*]]:2 = hlfir.declare %[[FLAG]]
   ! CHECK-DAG: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]]
   ! CHECK-DAG: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]]
-  ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
 
   result = (flag ? x : y)
   ! CHECK: %[[FLAG_LOAD:.*]] = fir.load %[[FLAG_DECL]]#0
   ! CHECK: %[[FLAG_CONV:.*]] = fir.convert %[[FLAG_LOAD]] : (!fir.logical<4>) -> i1
-  ! CHECK: fir.if %[[FLAG_CONV]] {
+  ! CHECK: %[[RESULT:.*]] = fir.if %[[FLAG_CONV]] -> (i32) {
   ! CHECK:   %[[X_LOAD:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
-  ! CHECK:   hlfir.assign %[[X_LOAD]] to %[[TEMP_DECL]]#0 : i32, !fir.ref<i32>
+  ! CHECK:   fir.result %[[X_LOAD]] : i32
   ! CHECK: } else {
   ! CHECK:   %[[Y_LOAD:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
-  ! CHECK:   hlfir.assign %[[Y_LOAD]] to %[[TEMP_DECL]]#0 : i32, !fir.ref<i32>
+  ! CHECK:   fir.result %[[Y_LOAD]] : i32
   ! CHECK: }
-  ! CHECK: %[[LOAD:.*]] = fir.load %[[TEMP_DECL]]#0
-  ! CHECK: hlfir.assign %[[LOAD]] to %{{.*}} : i32, !fir.ref<i32>
+  ! CHECK: hlfir.assign %[[RESULT]] to %{{.*}} : i32, !fir.ref<i32>
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_scalar_real(
@@ -33,12 +30,10 @@ subroutine test_scalar_real(flag, x, y)
   logical :: flag
   real :: x, y, result
   result = (flag ? x : y)
-  ! CHECK: %[[TEMP:.*]] = fir.alloca f32 {bindc_name = ".cond.scalar"
-  ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : f32, !fir.ref<f32>
+  ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (f32) {
+  ! CHECK:   fir.result {{.*}} : f32
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : f32, !fir.ref<f32>
+  ! CHECK:   fir.result {{.*}} : f32
   ! CHECK: }
 end subroutine
 
@@ -47,12 +42,10 @@ subroutine test_scalar_complex(flag, x, y)
   logical :: flag
   complex :: x, y, result
   result = (flag ? x : y)
-  ! CHECK: %[[TEMP:.*]] = fir.alloca complex<f32> {bindc_name = ".cond.scalar"
-  ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : complex<f32>, !fir.ref<complex<f32>>
+  ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (complex<f32>) {
+  ! CHECK:   fir.result {{.*}} : complex<f32>
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : complex<f32>, !fir.ref<complex<f32>>
+  ! CHECK:   fir.result {{.*}} : complex<f32>
   ! CHECK: }
 end subroutine
 
@@ -60,9 +53,25 @@ subroutine test_scalar_complex(flag, x, y)
 subroutine test_scalar_logical(flag, x, y)
   logical :: flag, x, y, result
   result = (flag ? x : y)
-  ! CHECK: %[[TEMP:.*]] = fir.alloca !fir.logical<4> {bindc_name = ".cond.scalar"
-  ! CHECK: fir.if
+  ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (!fir.logical<4>) {
+  ! CHECK:   fir.result {{.*}} : !fir.logical<4>
+  ! CHECK: } else {
+  ! CHECK:   fir.result {{.*}} : !fir.logical<4>
+  ! CHECK: }
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_logical_literal(
+subroutine test_logical_literal(flag)
+  logical :: flag, result
+  result = (flag ? .true. : .false.)
+  ! CHECK: %[[RESULT:.*]] = fir.if {{.*}} -> (!fir.logical<4>) {
+  ! CHECK:   %[[TRUE:.*]] = arith.constant true
+  ! CHECK:   %[[CONV:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+  ! CHECK:   fir.result %[[CONV]] : !fir.logical<4>
   ! CHECK: } else {
+  ! CHECK:   %[[FALSE:.*]] = arith.constant false
+  ! CHECK:   %[[CONV:.*]] = fir.convert %[[FALSE]] : (i1) -> !fir.logical<4>
+  ! CHECK:   fir.result %[[CONV]] : !fir.logical<4>
   ! CHECK: }
 end subroutine
 
@@ -71,25 +80,19 @@ subroutine test_multi_branch(x)
   integer :: x, result
   ! Multi-branch: x > 10 ? 100 : x > 5 ? 50 : 0
   result = (x > 10 ? 100 : x > 5 ? 50 : 0)
-  ! Both outer and inner temps are hoisted to function entry.
-  ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
-  ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
-  ! Outer temp declaration and first condition: x > 10
-  ! CHECK: hlfir.declare {{.*}} {uniq_name = ".cond.result"}
+  ! Outer condition: x > 10
   ! CHECK: arith.cmpi sgt
-  ! CHECK: fir.if {{.*}} {
-  ! CHECK:   hlfir.assign {{.*}}
+  ! CHECK: %[[OUTER:.*]] = fir.if {{.*}} -> (i32) {
+  ! CHECK:   fir.result {{.*}} : i32
   ! CHECK: } else {
-  ! Inner temp for the nested conditional: x > 5 ? 50 : 0
-  ! CHECK:   hlfir.declare {{.*}} {uniq_name = ".cond.result"}
+  ! Inner conditional: x > 5 ? 50 : 0
   ! CHECK:   arith.cmpi sgt
-  ! CHECK:   fir.if {{.*}} {
-  ! CHECK:     hlfir.assign {{.*}}
+  ! CHECK:   %[[INNER:.*]] = fir.if {{.*}} -> (i32) {
+  ! CHECK:     fir.result {{.*}} : i32
   ! CHECK:   } else {
-  ! CHECK:     hlfir.assign {{.*}}
+  ! CHECK:     fir.result {{.*}} : i32
   ! CHECK:   }
-  ! CHECK:   fir.load
-  ! CHECK:   hlfir.assign {{.*}}
+  ! CHECK:   fir.result %[[INNER]] : i32
   ! CHECK: }
 end subroutine
 
@@ -176,22 +179,17 @@ subroutine test_nested_conditionals(flag1, flag2, x, y, z)
   integer :: x, y, z, result
   ! Nested: flag1 ? (flag2 ? x : y) : z
   result = (flag1 ? (flag2 ? x : y) : z)
-  ! Both outer and inner temps are hoisted to function entry.
-  ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
-  ! CHECK-DAG: fir.alloca i32 {bindc_name = ".cond.scalar"
-  ! Outer temp declaration and conditional
-  ! CHECK: hlfir.declare {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK: fir.if {{%.*}} {
-  ! Inner temp declaration and conditional
-  ! CHECK:   hlfir.declare {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK:   fir.if {{%.*}} {
-  ! CHECK:     hlfir.assign {{.*}}
+  ! Outer conditional
+  ! CHECK: %[[OUTER:.*]] = fir.if {{%.*}} -> (i32) {
+  ! Inner conditional
+  ! CHECK:   %[[INNER:.*]] = fir.if {{%.*}} -> (i32) {
+  ! CHECK:     fir.result {{.*}} : i32
   ! CHECK:   } else {
-  ! CHECK:     hlfir.assign {{.*}}
+  ! CHECK:     fir.result {{.*}} : i32
   ! CHECK:   }
-  ! CHECK:   hlfir.assign {{.*}}
+  ! CHECK:   fir.result %[[INNER]] : i32
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}}
+  ! CHECK:   fir.result {{.*}} : i32
   ! CHECK: }
 end subroutine
 
@@ -201,11 +199,11 @@ subroutine test_in_expression(flag, x, y)
   integer :: x, y, z
   ! Conditional in larger expression: (flag ? x : y) + 10
   z = (flag ? x : y) + 10
-  ! CHECK: %[[TEMP:.*]] = fir.alloca i32 {bindc_name = ".cond.scalar"
-  ! CHECK: fir.if
+  ! CHECK: %[[COND_RESULT:.*]] = fir.if {{.*}} -> (i32) {
+  ! CHECK:   fir.result {{.*}} : i32
   ! CHECK: } else {
+  ! CHECK:   fir.result {{.*}} : i32
   ! CHECK: }
-  ! CHECK: %[[COND_RESULT:.*]] = fir.load
   ! CHECK: %[[C10:.*]] = arith.constant 10
   ! CHECK: %[[SUM:.*]] = arith.addi %[[COND_RESULT]], %[[C10]]
   ! CHECK: hlfir.assign %[[SUM]]
@@ -232,17 +230,23 @@ subroutine test_different_kinds(flag)
   integer(kind=4) :: i4_1, i4_2, i4_result
   integer(kind=8) :: i8_1, i8_2, i8_result
 
-  ! Both temps allocated at function start
-  ! CHECK-DAG: %{{.*}} = fir.alloca i64 {bindc_name = ".cond.scalar"}
-  ! CHECK-DAG: %{{.*}} = fir.alloca i32 {bindc_name = ".cond.scalar"}
-
   i4_1 = 1
   i4_2 = 2
   i4_result = (flag ? i4_1 : i4_2)
+  ! CHECK: %{{.*}} = fir.if {{.*}} -> (i32) {
+  ! CHECK:   fir.result {{.*}} : i32
+  ! CHECK: } else {
+  ! CHECK:   fir.result {{.*}} : i32
+  ! CHECK: }
 
   i8_1 = 3
   i8_2 = 4
   i8_result = (flag ? i8_1 : i8_2)
+  ! CHECK: %{{.*}} = fir.if {{.*}} -> (i64) {
+  ! CHECK:   fir.result {{.*}} : i64
+  ! CHECK: } else {
+  ! CHECK:   fir.result {{.*}} : i64
+  ! CHECK: }
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_array_section(



More information about the flang-commits mailing list