[flang-commits] [flang] b8e4232 - [flang] Cast fir.select[_rank] selector to i64. (#153239)

via flang-commits flang-commits at lists.llvm.org
Tue Aug 12 16:43:48 PDT 2025


Author: Slava Zakharin
Date: 2025-08-12T16:43:44-07:00
New Revision: b8e4232bd2fb326cca994dd88cfb249266d6c53e

URL: https://github.com/llvm/llvm-project/commit/b8e4232bd2fb326cca994dd88cfb249266d6c53e
DIFF: https://github.com/llvm/llvm-project/commit/b8e4232bd2fb326cca994dd88cfb249266d6c53e.diff

LOG: [flang] Cast fir.select[_rank] selector to i64. (#153239)

Properly cast the selector to `i64` regardless of its integer type.
We used to generate llvm.trunc always.

We have to use `i64` as long as the case values may exceed INT_MAX.

Fixes #153050.

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Fir/convert-to-llvm.fir
    flang/test/Fir/select.fir

Removed: 
    


################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae690cbe..ba5fef97c83ed 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3525,114 +3525,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
   }
 };
 
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
-                  const mlir::TypeConverter *converter,
-                  mlir::Operation *branchOp, mlir::Block *block,
-                  mlir::TypeRange expectedTypes) {
-  assert(converter && "expected non-null type converter");
-  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
-  // There is nothing to do if the types already match.
-  if (block->getArgumentTypes() == expectedTypes)
-    return block;
-
-  // Compute the new block argument types and convert the block.
-  std::optional<mlir::TypeConverter::SignatureConversion> conversion =
-      converter->convertBlockSignature(block);
-  if (!conversion)
-    return rewriter.notifyMatchFailure(branchOp,
-                                       "could not compute block signature");
-  if (expectedTypes != conversion->getConvertedTypes())
-    return rewriter.notifyMatchFailure(
-        branchOp,
-        "mismatch between adaptor operand types and computed block signature");
-  return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
 template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
-                      typename OP::Adaptor adaptor,
-                      mlir::ConversionPatternRewriter &rewriter,
-                      const mlir::TypeConverter *converter) {
-  unsigned conds = select.getNumConditions();
-  auto cases = select.getCases().getValue();
-  mlir::Value selector = adaptor.getSelector();
-  auto loc = select.getLoc();
-  assert(conds > 0 && "select must have cases");
-
-  llvm::SmallVector<mlir::Block *> destinations;
-  llvm::SmallVector<mlir::ValueRange> destinationsOperands;
-  mlir::Block *defaultDestination;
-  mlir::ValueRange defaultOperands;
-  llvm::SmallVector<int32_t> caseValues;
-
-  for (unsigned t = 0; t != conds; ++t) {
-    mlir::Block *dest = select.getSuccessor(t);
-    auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
-    const mlir::Attribute &attr = cases[t];
-    if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
-      destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
-      auto convertedBlock =
-          getConvertedBlock(rewriter, converter, select, dest,
-                            mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+  using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+  /// Helper function for converting select ops. This function converts the
+  /// signature of the given block. If the new block signature is different from
+  /// `expectedTypes`, returns "failure".
+  llvm::FailureOr<mlir::Block *>
+  getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+                    mlir::Operation *branchOp, mlir::Block *block,
+                    mlir::TypeRange expectedTypes) const {
+    const mlir::TypeConverter *converter = this->getTypeConverter();
+    assert(converter && "expected non-null type converter");
+    assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+    // There is nothing to do if the types already match.
+    if (block->getArgumentTypes() == expectedTypes)
+      return block;
+
+    // Compute the new block argument types and convert the block.
+    std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+        converter->convertBlockSignature(block);
+    if (!conversion)
+      return rewriter.notifyMatchFailure(branchOp,
+                                         "could not compute block signature");
+    if (expectedTypes != conversion->getConvertedTypes())
+      return rewriter.notifyMatchFailure(branchOp,
+                                         "mismatch between adaptor operand "
+                                         "types and computed block signature");
+    return rewriter.applySignatureConversion(block, *conversion, converter);
+  }
+
+protected:
+  llvm::LogicalResult
+  selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+                        mlir::ConversionPatternRewriter &rewriter) const {
+    unsigned conds = select.getNumConditions();
+    auto cases = select.getCases().getValue();
+    mlir::Value selector = adaptor.getSelector();
+    auto loc = select.getLoc();
+    assert(conds > 0 && "select must have cases");
+
+    llvm::SmallVector<mlir::Block *> destinations;
+    llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+    mlir::Block *defaultDestination;
+    mlir::ValueRange defaultOperands;
+    // LLVM::SwitchOp selector type and the case values types
+    // must have the same bit width, so cast the selector to i64,
+    // and use i64 for the case values. It is hard to imagine
+    // a computed GO TO with the number of labels in the label-list
+    // bigger than INT_MAX, but let's use i64 to be on the safe side.
+    // Moreover, fir.select operation is more relaxed than
+    // a Fortran computed GO TO, so it may specify such a case value
+    // even if there is just a single label/case.
+    llvm::SmallVector<int64_t> caseValues;
+
+    for (unsigned t = 0; t != conds; ++t) {
+      mlir::Block *dest = select.getSuccessor(t);
+      auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+      const mlir::Attribute &attr = cases[t];
+      if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+        destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+        auto convertedBlock =
+            getConvertedBlock(rewriter, select, dest,
+                              mlir::TypeRange(destinationsOperands.back()));
+        if (mlir::failed(convertedBlock))
+          return mlir::failure();
+        destinations.push_back(*convertedBlock);
+        caseValues.push_back(intAttr.getInt());
+        continue;
+      }
+      assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+      assert((t + 1 == conds) && "unit must be last");
+      defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+      auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+                                              mlir::TypeRange(defaultOperands));
       if (mlir::failed(convertedBlock))
         return mlir::failure();
-      destinations.push_back(*convertedBlock);
-      caseValues.push_back(intAttr.getInt());
-      continue;
+      defaultDestination = *convertedBlock;
     }
-    assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
-    assert((t + 1 == conds) && "unit must be last");
-    defaultOperands = destOps ? *destOps : mlir::ValueRange{};
-    auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
-                                            mlir::TypeRange(defaultOperands));
-    if (mlir::failed(convertedBlock))
-      return mlir::failure();
-    defaultDestination = *convertedBlock;
-  }
-
-  // LLVM::SwitchOp takes a i32 type for the selector.
-  if (select.getSelector().getType() != rewriter.getI32Type())
-    selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
-                                           selector);
-
-  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
-      select, selector,
-      /*defaultDestination=*/defaultDestination,
-      /*defaultOperands=*/defaultOperands,
-      /*caseValues=*/caseValues,
-      /*caseDestinations=*/destinations,
-      /*caseOperands=*/destinationsOperands,
-      /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
-  return mlir::success();
-}
 
+    selector =
+        this->integerCast(loc, rewriter, rewriter.getI64Type(), selector);
+
+    rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+        select, selector,
+        /*defaultDestination=*/defaultDestination,
+        /*defaultOperands=*/defaultOperands,
+        /*caseValues=*/rewriter.getI64VectorAttr(caseValues),
+        /*caseDestinations=*/destinations,
+        /*caseOperands=*/destinationsOperands,
+        /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+    return mlir::success();
+  }
+};
 /// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
-  using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+  using SelectOpConversionBase::SelectOpConversionBase;
 
   llvm::LogicalResult
   matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
-                                                rewriter, getTypeConverter());
+    return this->selectMatchAndRewrite(op, adaptor, rewriter);
   }
 };
 
 /// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
-  using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+    : public SelectOpConversionBase<fir::SelectRankOp> {
+  using SelectOpConversionBase::SelectOpConversionBase;
 
   llvm::LogicalResult
   matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    return selectMatchAndRewrite<fir::SelectRankOp>(
-        lowerTy(), op, adaptor, rewriter, getTypeConverter());
+    return this->selectMatchAndRewrite(op, adaptor, rewriter);
   }
 };
 

diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 50a98466f0d4b..cd87bf8d28ed5 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -338,8 +338,7 @@ func.func @select(%arg : index, %arg2 : i32) -> i32 {
 // CHECK:         %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK:         %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:         %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:         %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
-// CHECK:         llvm.switch %[[SELECTOR]] : i32, ^bb5 [
+// CHECK:         llvm.switch %[[SELECTVALUE]] : i64, ^bb5 [
 // CHECK:           1: ^bb1(%[[C0]] : i32),
 // CHECK:           2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
 // CHECK:           3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -384,7 +383,8 @@ func.func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
 // CHECK:         %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK:         %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:         %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:         llvm.switch %[[SELECTVALUE]] : i32, ^bb5 [
+// CHECK:         %[[SELECTOR:.*]] = llvm.sext %[[SELECTVALUE]] : i{{.*}} to i64
+// CHECK:         llvm.switch %[[SELECTOR]] : i64, ^bb5 [
 // CHECK:           1: ^bb1(%[[C0]] : i32),
 // CHECK:           2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
 // CHECK:           3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -2853,6 +2853,8 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @test_call_arg_attrs_indirect
 func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
@@ -2860,6 +2862,8 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   return %0 : i16
 }
 
+// -----
+
 // CHECK-LABEL: @test_byval
 func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
   //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
@@ -2867,9 +2871,56 @@ func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64)
   return
 }
 
+// -----
+
 // CHECK-LABEL: @test_sret
 func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
   //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
   fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
   return
 }
+
+// -----
+
+func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
+  fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
+  ^bb1:
+  fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
+  ^bb2:
+  fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
+  ^bb3:
+  fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
+  ^bb4:
+  fir.select %arg3 : i64 [ 4294967296, ^bb5, unit, ^bb5 ]
+  ^bb5:
+  return
+}
+// CHECK-LABEL:   llvm.func @select_with_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: i8,
+// CHECK-SAME:      %[[ARG1:.*]]: i16,
+// CHECK-SAME:      %[[ARG2:.*]]: i64,
+// CHECK-SAME:      %[[ARG3:.*]]: i64) {
+// CHECK:           %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i64
+// CHECK:           llvm.switch %[[VAL_0]] : i64, ^bb1 [
+// CHECK:             1: ^bb1
+// CHECK:           ]
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i64
+// CHECK:           llvm.switch %[[VAL_1]] : i64, ^bb2 [
+// CHECK:             1: ^bb2
+// CHECK:           ]
+// CHECK:         ^bb2:
+// CHECK:           llvm.switch %[[ARG2]] : i64, ^bb3 [
+// CHECK:             1: ^bb3
+// CHECK:           ]
+// CHECK:         ^bb3:
+// CHECK:           llvm.switch %[[ARG3]] : i64, ^bb4 [
+// CHECK:             1: ^bb4
+// CHECK:           ]
+// CHECK:         ^bb4:
+// CHECK:           llvm.switch %[[ARG2]] : i64, ^bb5 [
+// CHECK:             4294967296: ^bb5
+// CHECK:           ]
+// CHECK:         ^bb5:
+// CHECK:           llvm.return
+// CHECK:         }

diff --git a/flang/test/Fir/select.fir b/flang/test/Fir/select.fir
index 47cc5e4122076..5e88048446407 100644
--- a/flang/test/Fir/select.fir
+++ b/flang/test/Fir/select.fir
@@ -7,8 +7,8 @@
 func.func @f(%a : i32) -> i32 {
    %1 = arith.constant 1 : i32
    %2 = arith.constant 42 : i32
-// CHECK: switch i32 %{{.*}}, label %{{.*}} [
-// CHECK:   i32 1, label %{{.*}}
+// CHECK: switch i64 %{{.*}}, label %{{.*}} [
+// CHECK:   i64 1, label %{{.*}}
 // CHECK: ]
    fir.select %a : i32 [1, ^bb2(%1:i32), unit, ^bb3(%2:i32)]
 ^bb2(%3 : i32) :
@@ -24,9 +24,9 @@ func.func @g(%a : i32) -> i32 {
    %1 = arith.constant 1 : i32
    %2 = arith.constant 42 : i32
 
-// CHECK: switch i32 %{{.*}}, label %{{.*}} [
-// CHECK:  i32 1, label %{{.*}}
-// CHECK:  i32 -1, label %{{.*}}
+// CHECK: switch i64 %{{.*}}, label %{{.*}} [
+// CHECK:  i64 1, label %{{.*}}
+// CHECK:  i64 -1, label %{{.*}}
 // CHECK: ]
    fir.select_rank %a : i32 [1, ^bb2(%1:i32), -1, ^bb4, unit, ^bb3(%2:i32)]
 ^bb2(%3 : i32) :


        


More information about the flang-commits mailing list