[flang-commits] [flang] [flang] Region-based HLFIR operation for conditional expressions lowering (PR #194411)

Caroline Newcombe via flang-commits flang-commits at lists.llvm.org
Mon Apr 27 09:51:20 PDT 2026


https://github.com/cenewcombe created https://github.com/llvm/llvm-project/pull/194411

Introduces `hlfir.conditional`, a region-based HLFIR operation that represents Fortran 2023 conditional expressions (10.1.2.3) with lazy branch evaluation.

## Motivation

The previous lowering emitted `fir.if` directly in ConvertExprToHLFIR with per-category strategies (scalar temp + assign, allocatable temp + realloc, CHARACTER-specific handling). This leaked bufferization concerns into lowering and required four separate code paths.

A naive approach of lowering directly to `fir.if` with the branch values as results doesn't work, because `fir.if` requires both branches to yield results of identical MLIR types. The two branches can produce values with different MLIR types: for example, a constant-length CHARACTER variable (`fir.ref<fir.char<1,10>>`) vs. a dynamic-length CHARACTER expression (`fir.ref<fir.char<1,?>>`), or arrays with different descriptor shapes. These mismatches can only be resolved once the canonical type of the temporary that will hold the result is known.
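To make the mismatch concrete, here is a toy model (not flang code) that renders MLIR types as strings. `branchRefType` and `firIfCanYieldDirectly` are hypothetical helpers for illustration; the only rule they encode is the one stated above, that `fir.if` results need one type shared by both branches.

```cpp
#include <cassert>
#include <string>

// Hypothetical model: MLIR types are rendered as strings. Returns the type a
// CHARACTER(KIND=1) branch would yield if it were lowered on its own.
std::string branchRefType(bool constantLen, long len) {
  assert((!constantLen || len >= 0) && "a constant length is non-negative");
  return constantLen ? "!fir.ref<!fir.char<1," + std::to_string(len) + ">>"
                     : std::string("!fir.ref<!fir.char<1,?>>");
}

// fir.if requires both branches to yield values of one identical MLIR type,
// so the branch values can only be used as results when the types match.
bool firIfCanYieldDirectly(const std::string &thenTy,
                           const std::string &elseTy) {
  return thenTy == elseTy;
}
```

With `thenTy = !fir.ref<!fir.char<1,10>>` and `elseTy = !fir.ref<!fir.char<1,?>>` the check fails, which is why a common temporary type has to be chosen before the `fir.if` can be built.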

## Approach

`hlfir.conditional` captures each branch in its own region (terminated by `hlfir.yield`), deferring materialization to the bufferization pass. During bufferization, `ConditionalOpConversion`:

- Computes a single canonical temp type from the `hlfir.expr` result type (`computeTempBaseType`), ensuring both branches produce an identical MLIR type for the `fir.if` results.
- Clones each region into a `fir.if` branch, creating a temp of that canonical type and assigning the yielded value into it.
- Clones the region's `hlfir.destroy` ops only after the `hlfir.assign` into the temp, so the yielded value is not freed before it has been copied (which would be a use-after-free).
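The two key decisions in `ConditionalOpConversion` can be modeled outside MLIR. The sketch below is a simplified, string-based stand-in, assuming a hypothetical `ExprTypeModel` struct in place of `hlfir::ExprType`: `tempBaseType` follows the same decision order as `computeTempBaseType` in the patch (polymorphic, then array, then dynamic-length CHARACTER, then parameterized derived type, then plain reference), and `branchEmissionOrder` reproduces the order in which each branch is emitted. Neither function exists in flang.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for hlfir::ExprType, keeping only the properties
// that the type computation inspects. Types are rendered as strings.
struct ExprTypeModel {
  std::string eleTy;            // e.g. "i32", "!fir.char<1,?>"
  bool isArray = false;         // rank > 0
  std::string shape;            // e.g. "?" or "?x?" when isArray
  bool isPolymorphic = false;
  bool dynamicCharLen = false;  // element is fir.char<kind,?>
  int charKind = 1;
  bool isPDT = false;           // record with length type parameters
};

// Canonical temp type shared by both fir.if branches.
std::string tempBaseType(const ExprTypeModel &t) {
  assert((!t.isArray || !t.shape.empty()) && "arrays need a shape");
  const std::string elemOrSeq =
      t.isArray ? "!fir.array<" + t.shape + "x" + t.eleTy + ">" : t.eleTy;
  if (t.isPolymorphic)
    return "!fir.class<" + elemOrSeq + ">";
  if (t.isArray)
    return "!fir.box<" + elemOrSeq + ">";
  if (t.dynamicCharLen)
    return "!fir.boxchar<" + std::to_string(t.charKind) + ">";
  if (t.isPDT)
    return "!fir.box<" + t.eleTy + ">";
  return "!fir.ref<" + t.eleTy + ">";
}

// One op of a branch region, tagged by whether it is an hlfir.destroy.
struct OpModel {
  std::string name;
  bool isDestroy;
};

// Emission order for one branch: every non-destroy op, then the assign into
// the temp, then the deferred destroys, then the fir.result terminator.
std::vector<std::string>
branchEmissionOrder(const std::vector<OpModel> &region) {
  std::vector<std::string> out;
  for (const OpModel &op : region)
    if (!op.isDestroy)
      out.push_back(op.name);
  out.push_back("hlfir.assign");
  for (const OpModel &op : region)
    if (op.isDestroy)
      out.push_back(op.name);
  out.push_back("fir.result");
  return out;
}
```

Because the temp type is derived from the expression type alone, both branches agree on it regardless of how each branch's value happens to be represented.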

Trivial scalar types (INTEGER, REAL, COMPLEX, LOGICAL, UNSIGNED) bypass `hlfir.conditional` entirely and use `fir.if` SSA results directly, since no temporary is needed.
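The bypass condition in `gen(ConditionalExpr)` (rank 0, not polymorphic, trivial element type) can be sketched as a predicate. This is an approximation: the real check calls `fir::isa_trivial` on the MLIR element type, whereas the hypothetical `Category` enum and `usesFirIfDirectly` function below work on Fortran type categories instead.

```cpp
#include <cassert>

// Hypothetical simplification of Fortran type categories.
enum class Category {
  Integer, Real, Complex, Logical, Unsigned, Character, Derived
};

// True when the conditional can be lowered to fir.if with SSA results,
// false when it must go through hlfir.conditional.
bool usesFirIfDirectly(Category cat, int rank, bool isPolymorphic) {
  assert(rank >= 0 && "rank cannot be negative");
  if (rank != 0 || isPolymorphic)
    return false;
  switch (cat) {
  case Category::Integer:
  case Category::Real:
  case Category::Complex:
  case Category::Logical:
  case Category::Unsigned:
    return true;
  case Category::Character:
  case Category::Derived:
    return false;
  }
  return false;
}
```

Everything that fails this predicate (arrays, CHARACTER, polymorphic values, derived types) takes the `hlfir.conditional` path.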

**AI Usage Disclosure**: AI tools (Claude Sonnet 4.5) were used to assist with implementation of this feature and test code generation. I have reviewed, modified, and tested all AI-generated code.

From ba53c17154af2ebeab2bd870e5bd062b1e6faf8b Mon Sep 17 00:00:00 2001
From: Caroline Newcombe <caroline.newcombe at hpe.com>
Date: Wed, 22 Apr 2026 09:27:20 -0500
Subject: [PATCH] [flang] Region-based HLFIR operation for conditional
 expressions lowering

---
 .../include/flang/Optimizer/HLFIR/HLFIROps.td |  57 +++++-
 flang/lib/Lower/ConvertExprToHLFIR.cpp        | 177 +++++-------------
 flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp     |  35 ++++
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       | 148 ++++++++++++++-
 flang/test/Lower/HLFIR/conditional-expr.f90   | 115 ++++--------
 5 files changed, 310 insertions(+), 222 deletions(-)

diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 69647fe7a2c6d..f05554079816c 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -1510,11 +1510,14 @@ def hlfir_RegionAssignOp : hlfir_Op<"region_assign", [hlfir_OrderedAssignmentTre
   let hasVerifier = 1;
 }
 
-def hlfir_YieldOp : hlfir_Op<"yield", [Terminator, ParentOneOf<["RegionAssignOp",
-    "ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp",
-    "ExactlyOnceOp"]>,
-    SingleBlockImplicitTerminator<"fir::FirEndOp">, RecursivelySpeculatable,
-        RecursiveMemoryEffects]> {
+def hlfir_YieldOp
+    : hlfir_Op<"yield", [Terminator,
+                         ParentOneOf<["RegionAssignOp", "ElementalAddrOp",
+                                      "ForallOp", "ForallMaskOp", "WhereOp",
+                                      "ElseWhereOp", "ExactlyOnceOp",
+                                      "ConditionalOp"]>,
+                         SingleBlockImplicitTerminator<"fir::FirEndOp">,
+                         RecursivelySpeculatable, RecursiveMemoryEffects]> {
 
   let summary = "Yield a value or variable inside a forall, where or region assignment";
 
@@ -1988,5 +1991,49 @@ def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments
   let hasVerifier = 1;
 }
 
+def hlfir_ConditionalOp : hlfir_Op<"conditional", [RecursiveMemoryEffects,
+                                                   MemoryEffects<[MemAlloc]>]> {
+  let summary = "Fortran conditional expression";
+  let description = [{
+    Represent a Fortran conditional expression (F2023 10.1.2.3) that produces a
+    value by lazily evaluating one of two branches based on a condition.
+
+    Only one of the two regions is evaluated at runtime. Each region must
+    be terminated by an hlfir.yield that yields the branch's result value.
+
+    The result is an hlfir.expr. The dynamic type, length type parameters, and
+    shape of the result are determined by the selected branch at runtime (F2023
+    10.1.4(7)), so they are not operands of this operation.
+
+    Chained conditional expressions are represented by nesting an
+    hlfir.conditional inside the else region.
+
+    Example: ( X ? Y : Z ) where X is logical and Y, Z are real scalars
+    ```
+      %0 = hlfir.conditional %cond : (i1) -> !hlfir.expr<f32> {
+        hlfir.yield %y : f32
+      } else {
+        hlfir.yield %z : f32
+      }
+    ```
+  }];
+
+  let arguments = (ins I1:$condition);
+
+  let results = (outs hlfir_ExprType);
+  let regions = (region SizedRegion<1>:$then_region,
+      SizedRegion<1>:$else_region);
+
+  let assemblyFormat = [{
+    $condition attr-dict `:` functional-type(operands, results)
+    $then_region `else` $else_region
+  }];
+
+  let skipDefaultBuilders = 1;
+  let builders = [OpBuilder<(ins "mlir::Type":$result_type,
+      "mlir::Value":$condition)>];
+
+  let hasVerifier = 1;
+}
 
 #endif // FORTRAN_DIALECT_HLFIR_OPS
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index a57fce53c0ca5..d9b450d65878f 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1847,56 +1847,48 @@ class HlfirBuilder {
     llvm_unreachable("unknown descriptor inquiry");
   }
 
-  /// Build nested if-then-else chain by walking the right-skewed
-  /// ConditionalExpr tree. The assignValue callback generates and assigns
-  /// each value to avoid evaluating non-taken branches.
-  template <typename T, typename Callback>
-  void
-  buildConditionalIfChain(const Fortran::evaluate::ConditionalExpr<T> &condExpr,
-                          const Callback &assignValue) {
+  /// Generate a conditional expression as an hlfir.conditional op whose
+  /// regions yield the then/else values. Materialization into memory is
+  /// deferred to the bufferization pass.
+  template <typename T>
+  hlfir::Entity
+  genConditionalOp(const Fortran::evaluate::ConditionalExpr<T> &condExpr,
+                   mlir::Type elementType, bool isPolymorphic) {
     const mlir::Location loc{getLoc()};
     fir::FirOpBuilder &builder{getBuilder()};
+    // Lower the condition to i1.
     getStmtCtx().pushScope();
     const hlfir::EntityWithAttributes condEntity{gen(condExpr.condition())};
     mlir::Value condition{hlfir::loadTrivialScalar(loc, builder, condEntity)};
     condition = builder.createConvert(loc, builder.getI1Type(), condition);
-    builder.genIfOp(loc, {}, condition, /*withElseRegion=*/true)
-        .genThen([&]() {
-          getStmtCtx().pushScope();
-          assignValue(condExpr.thenValue());
-          getStmtCtx().finalizeAndPop();
-        })
-        .genElse([&]() {
-          getStmtCtx().pushScope();
-          assignValue(condExpr.elseValue());
-          getStmtCtx().finalizeAndPop();
-        })
-        .end();
     getStmtCtx().finalizeAndPop();
-  }
-
-  /// Generate scalar conditional with lazy evaluation using assignment.
-  /// Creates a temporary and assigns the selected branch value to it.
-  template <typename T>
-  hlfir::Entity
-  genScalarConditional(const Fortran::evaluate::ConditionalExpr<T> &condExpr,
-                       mlir::Type elementType,
-                       const llvm::SmallVector<mlir::Value, 1> &typeParams) {
-    const mlir::Location loc{getLoc()};
-    fir::FirOpBuilder &builder{getBuilder()};
-    const mlir::Value tempStorage{builder.createTemporary(
-        loc, elementType, ".cond.scalar",
-        /*shape=*/mlir::ValueRange{}, /*typeParams=*/typeParams)};
-    const hlfir::DeclareOp tempDecl{hlfir::DeclareOp::create(
-        builder, loc, tempStorage, ".cond.result",
-        /*shape=*/mlir::Value{}, /*typeParams=*/typeParams)};
-    const hlfir::Entity temp{tempDecl};
-    buildConditionalIfChain(
-        condExpr, [&](const Fortran::evaluate::Expr<T> &expr) {
-          hlfir::Entity entity{gen(expr)};
-          hlfir::AssignOp::create(builder, loc, entity, temp);
-        });
-    return temp;
+    // Build the hlfir.expr result type.
+    const hlfir::ExprType::Shape shape(condExpr.Rank(),
+                                       hlfir::ExprType::getUnknownExtent());
+    const mlir::Type exprType{hlfir::ExprType::get(builder.getContext(), shape,
+                                                   elementType, isPolymorphic)};
+    auto condOp =
+        hlfir::ConditionalOp::create(builder, loc, exprType, condition);
+    // Populate each region inside a scope so that any cleanups
+    // (hlfir.destroy) emitted by gen() stay inside the region, avoiding
+    // dominance violations. The ConditionalOpConversion in bufferization
+    // defers these destroy ops until after the assign into the temp.
+    builder.setInsertionPointToStart(&condOp.getThenRegion().front());
+    getStmtCtx().pushScope();
+    const hlfir::Entity thenEntity{gen(condExpr.thenValue())};
+    getStmtCtx().finalizeAndPop();
+    hlfir::YieldOp::create(builder, loc, thenEntity);
+    builder.setInsertionPointToStart(&condOp.getElseRegion().front());
+    getStmtCtx().pushScope();
+    const hlfir::Entity elseEntity{gen(condExpr.elseValue())};
+    getStmtCtx().finalizeAndPop();
+    hlfir::YieldOp::create(builder, loc, elseEntity);
+    builder.setInsertionPointAfter(condOp);
+    fir::FirOpBuilder *const bldr{&builder};
+    mlir::Value result{condOp.getResult()};
+    getStmtCtx().attachCleanup(
+        [=]() { hlfir::DestroyOp::create(*bldr, loc, result); });
+    return hlfir::Entity{result};
   }
 
   /// Generate scalar conditional for trivial scalar types using fir.if SSA
@@ -1940,110 +1932,27 @@ class HlfirBuilder {
     return hlfir::Entity{results[0]};
   }
 
-  /// Generate conditional expression using an allocatable temporary with lazy
-  /// evaluation. Creates an unallocated allocatable, then uses assignment to
-  /// set the value from the chosen branch (allocation/reallocation handled by
-  /// runtime).
-  template <typename T>
-  hlfir::Entity genAllocatableConditional(
-      const Fortran::evaluate::ConditionalExpr<T> &condExpr,
-      mlir::Type resultType, llvm::StringRef debugName) {
-    const mlir::Location loc{getLoc()};
-    fir::FirOpBuilder &builder{getBuilder()};
-    // Polymorphic types need fir.class (not fir.box) to carry dynamic type
-    // info. Both scalar and array polymorphic types reach here.
-    const bool isPolymorphic{fir::isPolymorphicType(resultType)};
-    const mlir::Type allocType{
-        hlfir::getFortranElementOrSequenceType(resultType)};
-    const mlir::Type heapType{fir::HeapType::get(allocType)};
-    const mlir::Type boxHeapType{isPolymorphic
-                                     ? mlir::Type{fir::ClassType::get(heapType)}
-                                     : mlir::Type{fir::BoxType::get(heapType)}};
-    const mlir::Value tempStorage{
-        builder.createTemporary(loc, boxHeapType, debugName)};
-    const mlir::Value unallocBox{fir::factory::createUnallocatedBox(
-        builder, loc, boxHeapType, /*nonDeferredParams=*/{})};
-    builder.createStoreWithConvert(loc, unallocBox, tempStorage);
-    const hlfir::DeclareOp tempDecl{
-        hlfir::DeclareOp::create(builder, loc, tempStorage, ".cond.result")};
-    const hlfir::Entity temp{tempDecl};
-    // Lazy evaluation: only the selected branch is evaluated and assigned.
-    buildConditionalIfChain(
-        condExpr, [&](const Fortran::evaluate::Expr<T> &expr) {
-          const hlfir::Entity entity{gen(expr)};
-          hlfir::AssignOp::create(builder, loc, entity, temp,
-                                  /*isWholeAllocatableAssignment=*/true,
-                                  /*keepLhsLengthIfRealloc=*/false,
-                                  /*temporary_lhs=*/true);
-        });
-    fir::FirOpBuilder *const bldr{&builder};
-    getStmtCtx().attachCleanup([=]() {
-      fir::factory::genFreememIfAllocated(
-          *bldr, loc,
-          fir::MutableBoxValue{tempStorage, /*lenParams=*/{},
-                               fir::MutableProperties{}});
-    });
-    return temp;
-  }
-
-  /// Generate scalar CHARACTER conditional with proper length handling.
-  template <typename T>
-  std::optional<hlfir::EntityWithAttributes> genCharacterConditional(
-      const Fortran::evaluate::ConditionalExpr<T> &condExpr) {
-    const mlir::Location loc{getLoc()};
-    fir::FirOpBuilder &builder{getBuilder()};
-    const mlir::Type resultType{Fortran::lower::translateSomeExprToFIRType(
-        converter, toEvExpr(condExpr))};
-    const mlir::Type elementType{hlfir::getFortranElementType(resultType)};
-    if (auto charType = mlir::dyn_cast<fir::CharacterType>(elementType)) {
-      if (charType.hasConstantLen()) {
-        llvm::SmallVector<mlir::Value, 1> typeParams;
-        const mlir::Value len{builder.createIntegerConstant(
-            loc, builder.getCharacterLengthType(), charType.getLen())};
-        typeParams.push_back(len);
-        return hlfir::EntityWithAttributes{
-            genScalarConditional(condExpr, elementType, typeParams)};
-      }
-      // Non-constant/varying length: use allocatable conditional to get length
-      // from selected branch.
-      return hlfir::EntityWithAttributes{
-          genAllocatableConditional(condExpr, elementType, ".cond.char")};
-    }
-    return std::nullopt;
-  }
-
   /// Conditional expression (Fortran 2023)
   template <typename T>
   hlfir::EntityWithAttributes
   gen(const Fortran::evaluate::ConditionalExpr<T> &condExpr) {
-    const int rank{condExpr.Rank()};
     mlir::Type resultType{Fortran::lower::translateSomeExprToFIRType(
         converter, toEvExpr(condExpr))};
     if (fir::isRecordWithTypeParameters(
             hlfir::getFortranElementType(resultType)))
       TODO(getLoc(), "conditional expression with length-parameterized "
                      "derived type");
-    // Arrays: handle early to avoid unnecessary type checks.
-    // Per F2023 10.1.4(7), the shape is determined by the chosen branch.
-    if (rank != 0) {
-      return hlfir::EntityWithAttributes{
-          genAllocatableConditional(condExpr, resultType, ".cond.array")};
-    }
-    // CHARACTER scalars require special handling for type parameters.
-    if constexpr (T::category == Fortran::common::TypeCategory::Character) {
-      if (auto result = genCharacterConditional(condExpr))
-        return *result;
-    }
-    // Scalar types (INTEGER, REAL, COMPLEX, LOGICAL, UNSIGNED, Derived).
+    // Trivial scalar types (INTEGER, REAL, COMPLEX, LOGICAL, UNSIGNED)
+    // use fir.if SSA results directly — no temporary needed.
     const mlir::Type elementType{hlfir::getFortranElementType(resultType)};
-    if (fir::isPolymorphicType(resultType))
-      return hlfir::EntityWithAttributes{
-          genAllocatableConditional(condExpr, resultType, ".cond.polymorphic")};
-    if (fir::isa_trivial(elementType))
+    if (condExpr.Rank() == 0 && !fir::isPolymorphicType(resultType) &&
+        fir::isa_trivial(elementType))
       return hlfir::EntityWithAttributes{
           genTrivialScalarConditional(condExpr, elementType)};
-    return hlfir::EntityWithAttributes{
-        genScalarConditional(condExpr, elementType, {})};
+    // All other cases: arrays, CHARACTER, polymorphic, non-trivial derived.
+    // Emit hlfir.conditional to delay materialization to bufferization.
+    return hlfir::EntityWithAttributes{genConditionalOp(
+        condExpr, elementType, fir::isPolymorphicType(resultType))};
   }
 
   hlfir::EntityWithAttributes
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index e42c064794176..3230066f2305f 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -2461,6 +2461,41 @@ llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConditionalOp
+//===----------------------------------------------------------------------===//
+
+void hlfir::ConditionalOp::build(mlir::OpBuilder &builder,
+                                 mlir::OperationState &odsState,
+                                 mlir::Type resultType, mlir::Value condition) {
+  odsState.addTypes(resultType);
+  odsState.addOperands(condition);
+  // Create the then and else regions, each with one empty block.
+  odsState.addRegion()->push_back(new mlir::Block{});
+  odsState.addRegion()->push_back(new mlir::Block{});
+}
+
+llvm::LogicalResult hlfir::ConditionalOp::verify() {
+  if (!mlir::isa<hlfir::ExprType>(getResult().getType()))
+    return emitOpError("result must be an hlfir.expr type");
+  const auto checkRegion = [&](mlir::Region &region,
+                               llvm::StringRef name) -> llvm::LogicalResult {
+    if (region.empty())
+      return emitOpError(name) << " region must not be empty";
+    if (!region.hasOneBlock())
+      return emitOpError(name) << " region must have exactly one block";
+    if (!mlir::isa_and_nonnull<hlfir::YieldOp>(getTerminator(region)))
+      return emitOpError(name)
+             << " region must be terminated by an hlfir.yield";
+    return mlir::success();
+  };
+  if (const auto res = checkRegion(getThenRegion(), "then"); failed(res))
+    return res;
+  if (const auto res = checkRegion(getElseRegion(), "else"); failed(res))
+    return res;
+  return mlir::success();
+}
+
 #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
 #define GET_OP_CLASSES
 #include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 35cbdd59cf5d8..24e4099ec98b7 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -29,6 +29,7 @@
 #include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -857,6 +858,137 @@ struct CharExtremumOpConversion
   }
 };
 
+struct ConditionalOpConversion
+    : public mlir::OpConversionPattern<hlfir::ConditionalOp> {
+  using mlir::OpConversionPattern<hlfir::ConditionalOp>::OpConversionPattern;
+  explicit ConditionalOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::ConditionalOp>{ctx} {
+    // This pattern recursively converts nested ConditionalOp's
+    // by cloning and then converting them, so we have to allow
+    // for recursive pattern application. The recursion is bounded
+    // by the nesting level of ConditionalOp's.
+    setHasBoundedRewriteRecursion();
+  }
+  /// Compute the MLIR type of the temp's DeclareOp base result,
+  /// given the hlfir.expr type of the conditional. This must match
+  /// what createAndDeclareTemp + hlfir::DeclareOp would produce.
+  static mlir::Type computeTempBaseType(fir::FirOpBuilder &builder,
+                                        hlfir::ExprType exprType) {
+    const mlir::Type eleTy{exprType.getEleTy()};
+    const bool isPolymorphic{exprType.isPolymorphic()};
+    const bool isArray{exprType.isArray()};
+    mlir::Type elemOrSeqType{eleTy};
+    if (isArray)
+      elemOrSeqType = fir::SequenceType::get(exprType.getShape(), eleTy);
+    // Polymorphic: runtime allocate produces fir.class<fir.heap<T>>,
+    // DeclareOp strips the heap attribute → fir.class<T>.
+    if (isPolymorphic)
+      return fir::ClassType::get(elemOrSeqType);
+    // Arrays (non-polymorphic): heap alloc → fir.heap<seqTy>,
+    // DeclareOp wraps in box → fir.box<seqTy>.
+    if (isArray)
+      return fir::BoxType::get(elemOrSeqType);
+    // Scalar, non-polymorphic.
+    if (auto charTy{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+      if (charTy.hasDynamicLen())
+        return fir::BoxCharType::get(builder.getContext(), charTy.getFKind());
+    if (fir::isRecordWithTypeParameters(eleTy))
+      return fir::BoxType::get(eleTy);
+    return fir::ReferenceType::get(eleTy);
+  }
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::ConditionalOp condOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    const mlir::Location loc{condOp->getLoc()};
+    fir::FirOpBuilder builder(rewriter, condOp.getOperation());
+    HLFIRListener listener{builder, rewriter};
+    builder.setListener(&listener);
+    // Use ExprType to ensure both branches produce identical MLIR temp types.
+    const auto exprType{
+        mlir::cast<hlfir::ExprType>(condOp.getResult().getType())};
+    const bool isPolymorphic{exprType.isPolymorphic()};
+    const bool isArray{exprType.isArray()};
+    mlir::Type elemOrSeqType{exprType.getEleTy()};
+    if (isArray)
+      elemOrSeqType =
+          fir::SequenceType::get(exprType.getShape(), elemOrSeqType);
+    const bool useStack{!isArray && !isPolymorphic};
+    const mlir::Type tempBaseType{computeTempBaseType(builder, exprType)};
+    // Callback for hlfir::DeclareOp.
+    auto genTempDeclareOp =
+        [](fir::FirOpBuilder &bldr, mlir::Location l, mlir::Value memref,
+           llvm::StringRef name, mlir::Value shape,
+           llvm::ArrayRef<mlir::Value> typeParams,
+           fir::FortranVariableFlagsAttr attrs) -> mlir::Value {
+      auto declareOp =
+          hlfir::DeclareOp::create(bldr, l, memref, name, shape, typeParams,
+                                   /*dummy_scope=*/nullptr, /*storage=*/nullptr,
+                                   /*storage_offset=*/0, attrs);
+      return declareOp.getBase();
+    };
+
+    // Emit one branch: clone ops, create temp, assign, run deferred
+    // destroys, yield (temp, mustFree).
+    auto emitBranch = [&](mlir::Region &region) {
+      mlir::IRMapping mapper;
+      // Clone all ops except hlfir.destroy and the terminator.
+      for (auto &op : region.front().without_terminator())
+        if (!mlir::isa<hlfir::DestroyOp>(op))
+          builder.clone(op, mapper);
+      auto yield{mlir::cast<hlfir::YieldOp>(region.front().getTerminator())};
+      // Dereference allocatable/pointer values.
+      const hlfir::Entity val{hlfir::derefPointersAndAllocatables(
+          loc, builder,
+          hlfir::Entity{mapper.lookupOrDefault(yield.getEntity())})};
+      // Obtain runtime length/shape from the actual yielded value.
+      llvm::SmallVector<mlir::Value> lenParams;
+      hlfir::genLengthParameters(loc, builder, val, lenParams);
+      mlir::Value shape{};
+      llvm::SmallVector<mlir::Value> extents;
+      if (isArray) {
+        shape = hlfir::genShape(loc, builder, val);
+        extents = hlfir::getExplicitExtentsFromShape(shape, builder);
+      }
+      // Create temp with common MLIR type but runtime params from the yielded
+      // value.
+      const auto [base, isHeapAlloc]{builder.createAndDeclareTemp(
+          loc, elemOrSeqType, shape, extents, lenParams, genTempDeclareOp,
+          isPolymorphic ? val.getBase() : nullptr, useStack, ".tmp.cond")};
+      const hlfir::Entity temp{base};
+      assert(temp.getType() == tempBaseType &&
+             "temp type mismatch with fir.if result type");
+      hlfir::AssignOp::create(builder, loc, val, temp,
+                              /*realloc=*/false,
+                              /*keep_lhs_length_if_realloc=*/false,
+                              /*temporary_lhs=*/true);
+      // Clone hlfir.destroy ops after the assign to avoid
+      // use-after-free of the source operand.
+      for (auto &op : region.front().without_terminator())
+        if (mlir::isa<hlfir::DestroyOp>(op))
+          builder.clone(op, mapper);
+      // Return temp and mustFree as separate fir.result values.
+      mlir::Value mustFreeVal{builder.createBool(loc, isHeapAlloc)};
+      fir::ResultOp::create(builder, loc, mlir::ValueRange{temp, mustFreeVal});
+    };
+
+    // Generate fir.if returning (temp, mustFree) as two results.
+    auto ifOp{builder.genIfOp(
+        loc, /*resultTypes=*/{tempBaseType, builder.getI1Type()},
+        adaptor.getCondition(),
+        /*withElseRegion=*/true)};
+    ifOp.genThen([&]() { emitBranch(condOp.getThenRegion()); })
+        .genElse([&]() { emitBranch(condOp.getElseRegion()); })
+        .end();
+    // Package fir.if results into the bufferized expr tuple.
+    const mlir::Value bufferizedExpr{
+        packageBufferizedExpr(loc, builder, hlfir::Entity{ifOp.getResults()[0]},
+                              ifOp.getResults()[1])};
+    rewriter.replaceOp(condOp, bufferizedExpr);
+    return mlir::success();
+  }
+};
+
 struct EvaluateInMemoryOpConversion
     : public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
   using mlir::OpConversionPattern<
@@ -892,12 +1024,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
-                    AssociateOpConversion, CharExtremumOpConversion,
-                    ConcatOpConversion, DestroyOpConversion,
-                    EndAssociateOpConversion, EvaluateInMemoryOpConversion,
-                    NoReassocOpConversion, SetLengthOpConversion,
-                    ShapeOfOpConversion, GetLengthOpConversion>(context);
+    patterns.insert<
+        ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
+        AssociateOpConversion, CharExtremumOpConversion, ConcatOpConversion,
+        ConditionalOpConversion, DestroyOpConversion, EndAssociateOpConversion,
+        EvaluateInMemoryOpConversion, NoReassocOpConversion,
+        SetLengthOpConversion, ShapeOfOpConversion, GetLengthOpConversion>(
+        context);
     patterns.insert<ElementalOpConversion>(context, optimizeEmptyElementals);
     mlir::ConversionTarget target(*context);
     // Note that YieldElementOp is not marked as an illegal operation.
@@ -905,7 +1038,8 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
     // conversion pattern to YieldElementOp itself. If any YieldElementOp
     // survives this pass, the verifier will detect it because it has to be
     // a child of ElementalOp and ElementalOp's are explicitly illegal.
-    target.addIllegalOp<hlfir::ApplyOp, hlfir::AssociateOp, hlfir::ElementalOp,
+    target.addIllegalOp<hlfir::ApplyOp, hlfir::AssociateOp,
+                        hlfir::ConditionalOp, hlfir::ElementalOp,
                         hlfir::EndAssociateOp, hlfir::SetLengthOp>();
 
     target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
diff --git a/flang/test/Lower/HLFIR/conditional-expr.f90 b/flang/test/Lower/HLFIR/conditional-expr.f90
index 56d9bc4c45369..83b8cb8531304 100644
--- a/flang/test/Lower/HLFIR/conditional-expr.f90
+++ b/flang/test/Lower/HLFIR/conditional-expr.f90
@@ -115,13 +115,10 @@ subroutine test_char_constant_len(flag)
   str1 = "HELLO"
   str2 = "WORLD"
   result = (flag ? str1 : str2)
-  ! Constant length: use scalar temp path.
-  ! CHECK: %[[TEMP:.*]] = fir.alloca !fir.char<1,5> {bindc_name = ".cond.scalar"
-  ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] typeparams {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<!fir.char<1,5>> {
+  ! CHECK:   hlfir.yield %{{.*}} : !fir.ref<!fir.char<1,5>>
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0
+  ! CHECK:   hlfir.yield %{{.*}} : !fir.ref<!fir.char<1,5>>
   ! CHECK: }
 end subroutine
 
@@ -133,16 +130,10 @@ subroutine test_char_deferred_len(flag)
   str2 = "A MUCH LONGER STRING"
   ! Result length comes from selected branch
   result = (flag ? str1 : str2)
-  ! CHECK-DAG: %[[BOX_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".cond.char"
-  ! CHECK-DAG: %[[UNALLOC:.*]] = fir.zero_bits !fir.heap<!fir.char<1,?>>
-  ! CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  ! CHECK: %[[BOX:.*]] = fir.embox %[[UNALLOC]] typeparams %[[C0]]
-  ! CHECK: fir.store %[[BOX]] to %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
-  ! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX_ALLOC]] {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[BOX_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<!fir.char<1,?>> {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[BOX_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -153,16 +144,10 @@ subroutine test_array(flag)
   arr1 = 1
   arr2 = 2
   result = (flag ? arr1 : arr2)
-  ! CHECK: %[[BOX_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<{{.*}}xi32>>> {bindc_name = ".cond.array"
-  ! CHECK: %[[UNALLOC:.*]] = fir.zero_bits !fir.heap<!fir.array<{{.*}}xi32>>
-  ! CHECK: %[[SHAPE:.*]] = fir.shape
-  ! CHECK: %[[BOX:.*]] = fir.embox %[[UNALLOC]](%[[SHAPE]])
-  ! CHECK: fir.store %[[BOX]] to %[[BOX_ALLOC]]
-  ! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX_ALLOC]] {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[BOX_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<?xi32> {
+  ! CHECK:   hlfir.yield %{{.*}} : !fir.ref<!fir.array<10xi32>>
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[BOX_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}} : !fir.ref<!fir.array<10xi32>>
   ! CHECK: }
 end subroutine
 
@@ -176,12 +161,10 @@ subroutine test_derived_type(flag)
   p1 = point(1.0, 2.0)
   p2 = point(3.0, 4.0)
   result = (flag ? p1 : p2)
-  ! CHECK: %[[TEMP:.*]] = fir.alloca !fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}> {bindc_name = ".cond.scalar"
-  ! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : !fir.ref<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>>, !fir.ref<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>>
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>> {
+  ! CHECK:   hlfir.yield %{{.*}} : !fir.ref<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>>
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[TEMP_DECL]]#0 : !fir.ref<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>>, !fir.ref<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>>
+  ! CHECK:   hlfir.yield %{{.*}} : !fir.ref<!fir.type<_QFtest_derived_typeTpoint{x:f32,y:f32}>>
   ! CHECK: }
 end subroutine
 
@@ -227,12 +210,10 @@ subroutine test_assumed_length_char(flag, str1, str2)
   character(len=*) :: str1, str2
   character(len=100) :: result
   result = (flag ? str1 : str2)
-  ! Deferred length path since len=* is not constant
-  ! CHECK: %[[BOX_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".cond.char"
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to {{.*}} realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<!fir.char<1,?>> {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to {{.*}} realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -266,11 +247,10 @@ subroutine test_array_section(flag)
   logical :: flag
   integer :: arr1(20), arr2(20), result(10)
   result = (flag ? arr1(1:10) : arr2(11:20))
-  ! CHECK: %[[BOX_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<{{.*}}xi32>>> {bindc_name = ".cond.array"
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to {{.*}} realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<?xi32> {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to {{.*}} realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -280,11 +260,10 @@ subroutine test_noncontiguous_section(flag)
   integer :: arr1(20), arr2(20), result(5)
   ! Non-contiguous stride-2 sections: result must be contiguous.
   result = (flag ? arr1(1:10:2) : arr2(2:10:2))
-  ! CHECK: %[[BOX_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<{{.*}}xi32>>> {bindc_name = ".cond.array"
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to {{.*}} realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} : (i1) -> !hlfir.expr<?xi32> {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to {{.*}} realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -296,17 +275,12 @@ subroutine test_polymorphic(flag, x, y)
   logical :: flag
   class(base_type), intent(in) :: x, y
   type(base_type) :: result
-  ! Polymorphic conditional: uses fir.class (not fir.box) to carry dynamic type.
+  ! Polymorphic conditional: uses hlfir.conditional with polymorphic expr type.
   result = (flag ? x : y)
-  ! CHECK: %[[BOX_ALLOC:.*]] = fir.alloca !fir.class<!fir.heap<!fir.type<_QFtest_polymorphicTbase_type{val:i32}>>> {bindc_name = ".cond.polymorphic"
-  ! CHECK: %[[UNALLOC:.*]] = fir.zero_bits !fir.heap<!fir.type<_QFtest_polymorphicTbase_type{val:i32}>>
-  ! CHECK: %[[BOX:.*]] = fir.embox %[[UNALLOC]] : (!fir.heap<!fir.type<_QFtest_polymorphicTbase_type{val:i32}>>) -> !fir.class<!fir.heap<!fir.type<_QFtest_polymorphicTbase_type{val:i32}>>>
-  ! CHECK: fir.store %[[BOX]] to %[[BOX_ALLOC]]
-  ! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX_ALLOC]] {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[BOX_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} {{.*}} {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[BOX_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -318,14 +292,12 @@ subroutine test_polymorphic_array(flag, x, y)
   logical :: flag
   class(base_type), intent(in) :: x(:), y(:)
   type(base_type), allocatable :: result(:)
-  ! Polymorphic array: uses fir.class (not fir.box) for the allocatable temp.
+  ! Polymorphic array: hlfir.conditional with polymorphic array expr type.
   result = (flag ? x : y)
-  ! CHECK: fir.alloca !fir.class<!fir.heap<!fir.array<?x!fir.type<_QFtest_polymorphic_arrayTbase_type{val:i32}>>>> {bindc_name = ".cond.array"
-  ! CHECK: %[[PA_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[PA_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} {{.*}} {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[PA_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -339,13 +311,10 @@ subroutine test_polymorphic_char_component(flag, x, y)
   class(named_type), intent(in) :: x, y
   type(named_type), allocatable :: result
   result = (flag ? x : y)
-  ! The alloca type proves fir.class is used for the polymorphic temp.
-  ! CHECK: fir.alloca !fir.class<!fir.heap<!fir.type<_QFtest_polymorphic_char_componentTnamed_type{name:!fir.char<1,20>,id:i32}>>> {bindc_name = ".cond.polymorphic"
-  ! CHECK: %[[PC_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[PC_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} {{.*}} {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[PC_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -358,15 +327,12 @@ subroutine test_mixed_type_class(flag, x, y)
   type(base_type), intent(in) :: x
   class(base_type), intent(in) :: y
   type(base_type) :: result
-  ! Mixed TYPE(t)/CLASS(t): GetType() returns polymorphic when either branch
-  ! is polymorphic, so the result must use fir.class (not fir.box).
+  ! Mixed TYPE(t)/CLASS(t): result is polymorphic when either branch is.
   result = (flag ? x : y)
-  ! CHECK: fir.alloca !fir.class<!fir.heap<!fir.type<_QFtest_mixed_type_classTbase_type{val:i32}>>> {bindc_name = ".cond.polymorphic"
-  ! CHECK: %[[MX_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[MX_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} {{.*}} {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[MX_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine
 
@@ -382,13 +348,10 @@ subroutine test_polymorphic_extends(flag, x, y)
   class(base_type), intent(in) :: x, y
   type(base_type) :: result
   ! Polymorphic with type hierarchy: x or y may hold extended_type at runtime.
-  ! The lowering must use fir.class to preserve dynamic type info.
   result = (flag ? x : y)
-  ! CHECK: fir.alloca !fir.class<!fir.heap<!fir.type<_QFtest_polymorphic_extendsTbase_type{val:i32}>>> {bindc_name = ".cond.polymorphic"
-  ! CHECK: %[[PE_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = ".cond.result"}
-  ! CHECK: fir.if
-  ! CHECK:   hlfir.assign {{.*}} to %[[PE_DECL]]#0 realloc temporary_lhs
+  ! CHECK: %[[RESULT:.*]] = hlfir.conditional %{{.*}} {{.*}} {
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: } else {
-  ! CHECK:   hlfir.assign {{.*}} to %[[PE_DECL]]#0 realloc temporary_lhs
+  ! CHECK:   hlfir.yield %{{.*}}
   ! CHECK: }
 end subroutine


