[flang-commits] [flang] b8e4232 - [flang] Cast fir.select[_rank] selector to i64. (#153239)
via flang-commits
flang-commits at lists.llvm.org
Tue Aug 12 16:43:48 PDT 2025
Author: Slava Zakharin
Date: 2025-08-12T16:43:44-07:00
New Revision: b8e4232bd2fb326cca994dd88cfb249266d6c53e
URL: https://github.com/llvm/llvm-project/commit/b8e4232bd2fb326cca994dd88cfb249266d6c53e
DIFF: https://github.com/llvm/llvm-project/commit/b8e4232bd2fb326cca994dd88cfb249266d6c53e.diff
LOG: [flang] Cast fir.select[_rank] selector to i64. (#153239)
Properly cast the selector to `i64` regardless of its integer type.
We used to generate llvm.trunc always.
We have to use `i64` as long as the case values may exceed INT_MAX.
Fixes #153050.
Added:
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/test/Fir/convert-to-llvm.fir
flang/test/Fir/select.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae690cbe..ba5fef97c83ed 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3525,114 +3525,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
}
};
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter,
- mlir::Operation *branchOp, mlir::Block *block,
- mlir::TypeRange expectedTypes) {
- assert(converter && "expected non-null type converter");
- assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
- // There is nothing to do if the types already match.
- if (block->getArgumentTypes() == expectedTypes)
- return block;
-
- // Compute the new block argument types and convert the block.
- std::optional<mlir::TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion)
- return rewriter.notifyMatchFailure(branchOp,
- "could not compute block signature");
- if (expectedTypes != conversion->getConvertedTypes())
- return rewriter.notifyMatchFailure(
- branchOp,
- "mismatch between adaptor operand types and computed block signature");
- return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
- typename OP::Adaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter) {
- unsigned conds = select.getNumConditions();
- auto cases = select.getCases().getValue();
- mlir::Value selector = adaptor.getSelector();
- auto loc = select.getLoc();
- assert(conds > 0 && "select must have cases");
-
- llvm::SmallVector<mlir::Block *> destinations;
- llvm::SmallVector<mlir::ValueRange> destinationsOperands;
- mlir::Block *defaultDestination;
- mlir::ValueRange defaultOperands;
- llvm::SmallVector<int32_t> caseValues;
-
- for (unsigned t = 0; t != conds; ++t) {
- mlir::Block *dest = select.getSuccessor(t);
- auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
- const mlir::Attribute &attr = cases[t];
- if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
- destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
- auto convertedBlock =
- getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+ using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+ /// Helper function for converting select ops. This function converts the
+ /// signature of the given block. If the new block signature is different from
+ /// `expectedTypes`, returns "failure".
+ llvm::FailureOr<mlir::Block *>
+ getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+ mlir::Operation *branchOp, mlir::Block *block,
+ mlir::TypeRange expectedTypes) const {
+ const mlir::TypeConverter *converter = this->getTypeConverter();
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(branchOp,
+ "mismatch between adaptor operand "
+ "types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+ }
+
+protected:
+ llvm::LogicalResult
+ selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ unsigned conds = select.getNumConditions();
+ auto cases = select.getCases().getValue();
+ mlir::Value selector = adaptor.getSelector();
+ auto loc = select.getLoc();
+ assert(conds > 0 && "select must have cases");
+
+ llvm::SmallVector<mlir::Block *> destinations;
+ llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+ mlir::Block *defaultDestination;
+ mlir::ValueRange defaultOperands;
+ // LLVM::SwitchOp selector type and the case values types
+ // must have the same bit width, so cast the selector to i64,
+ // and use i64 for the case values. It is hard to imagine
+ // a computed GO TO with the number of labels in the label-list
+ // bigger than INT_MAX, but let's use i64 to be on the safe side.
+ // Moreover, fir.select operation is more relaxed than
+ // a Fortran computed GO TO, so it may specify such a case value
+ // even if there is just a single label/case.
+ llvm::SmallVector<int64_t> caseValues;
+
+ for (unsigned t = 0; t != conds; ++t) {
+ mlir::Block *dest = select.getSuccessor(t);
+ auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+ const mlir::Attribute &attr = cases[t];
+ if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+ destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+ auto convertedBlock =
+ getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(destinationsOperands.back()));
+ if (mlir::failed(convertedBlock))
+ return mlir::failure();
+ destinations.push_back(*convertedBlock);
+ caseValues.push_back(intAttr.getInt());
+ continue;
+ }
+ assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+ assert((t + 1 == conds) && "unit must be last");
+ defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+ auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
- destinations.push_back(*convertedBlock);
- caseValues.push_back(intAttr.getInt());
- continue;
+ defaultDestination = *convertedBlock;
}
- assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
- assert((t + 1 == conds) && "unit must be last");
- defaultOperands = destOps ? *destOps : mlir::ValueRange{};
- auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(defaultOperands));
- if (mlir::failed(convertedBlock))
- return mlir::failure();
- defaultDestination = *convertedBlock;
- }
-
- // LLVM::SwitchOp takes a i32 type for the selector.
- if (select.getSelector().getType() != rewriter.getI32Type())
- selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
- selector);
-
- rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
- select, selector,
- /*defaultDestination=*/defaultDestination,
- /*defaultOperands=*/defaultOperands,
- /*caseValues=*/caseValues,
- /*caseDestinations=*/destinations,
- /*caseOperands=*/destinationsOperands,
- /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
- return mlir::success();
-}
+ selector =
+ this->integerCast(loc, rewriter, rewriter.getI64Type(), selector);
+
+ rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+ select, selector,
+ /*defaultDestination=*/defaultDestination,
+ /*defaultOperands=*/defaultOperands,
+ /*caseValues=*/rewriter.getI64VectorAttr(caseValues),
+ /*caseDestinations=*/destinations,
+ /*caseOperands=*/destinationsOperands,
+ /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+ return mlir::success();
+ }
+};
/// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
- rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
/// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+ : public SelectOpConversionBase<fir::SelectRankOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectRankOp>(
- lowerTy(), op, adaptor, rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 50a98466f0d4b..cd87bf8d28ed5 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -338,8 +338,7 @@ func.func @select(%arg : index, %arg2 : i32) -> i32 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
-// CHECK: llvm.switch %[[SELECTOR]] : i32, ^bb5 [
+// CHECK: llvm.switch %[[SELECTVALUE]] : i64, ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -384,7 +383,8 @@ func.func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: llvm.switch %[[SELECTVALUE]] : i32, ^bb5 [
+// CHECK: %[[SELECTOR:.*]] = llvm.sext %[[SELECTVALUE]] : i{{.*}} to i64
+// CHECK: llvm.switch %[[SELECTOR]] : i64, ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -2853,6 +2853,8 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
return
}
+// -----
+
// CHECK-LABEL: @test_call_arg_attrs_indirect
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
// CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
@@ -2860,6 +2862,8 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
return %0 : i16
}
+// -----
+
// CHECK-LABEL: @test_byval
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
@@ -2867,9 +2871,56 @@ func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64)
return
}
+// -----
+
// CHECK-LABEL: @test_sret
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}
+
+// -----
+
+func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
+ fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
+ ^bb1:
+ fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
+ ^bb2:
+ fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
+ ^bb3:
+ fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
+ ^bb4:
+ fir.select %arg3 : i64 [ 4294967296, ^bb5, unit, ^bb5 ]
+ ^bb5:
+ return
+}
+// CHECK-LABEL: llvm.func @select_with_cast(
+// CHECK-SAME: %[[ARG0:.*]]: i8,
+// CHECK-SAME: %[[ARG1:.*]]: i16,
+// CHECK-SAME: %[[ARG2:.*]]: i64,
+// CHECK-SAME: %[[ARG3:.*]]: i64) {
+// CHECK: %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i64
+// CHECK: llvm.switch %[[VAL_0]] : i64, ^bb1 [
+// CHECK: 1: ^bb1
+// CHECK: ]
+// CHECK: ^bb1:
+// CHECK: %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i64
+// CHECK: llvm.switch %[[VAL_1]] : i64, ^bb2 [
+// CHECK: 1: ^bb2
+// CHECK: ]
+// CHECK: ^bb2:
+// CHECK: llvm.switch %[[ARG2]] : i64, ^bb3 [
+// CHECK: 1: ^bb3
+// CHECK: ]
+// CHECK: ^bb3:
+// CHECK: llvm.switch %[[ARG3]] : i64, ^bb4 [
+// CHECK: 1: ^bb4
+// CHECK: ]
+// CHECK: ^bb4:
+// CHECK: llvm.switch %[[ARG2]] : i64, ^bb5 [
+// CHECK: 4294967296: ^bb5
+// CHECK: ]
+// CHECK: ^bb5:
+// CHECK: llvm.return
+// CHECK: }
diff --git a/flang/test/Fir/select.fir b/flang/test/Fir/select.fir
index 47cc5e4122076..5e88048446407 100644
--- a/flang/test/Fir/select.fir
+++ b/flang/test/Fir/select.fir
@@ -7,8 +7,8 @@
func.func @f(%a : i32) -> i32 {
%1 = arith.constant 1 : i32
%2 = arith.constant 42 : i32
-// CHECK: switch i32 %{{.*}}, label %{{.*}} [
-// CHECK: i32 1, label %{{.*}}
+// CHECK: switch i64 %{{.*}}, label %{{.*}} [
+// CHECK: i64 1, label %{{.*}}
// CHECK: ]
fir.select %a : i32 [1, ^bb2(%1:i32), unit, ^bb3(%2:i32)]
^bb2(%3 : i32) :
@@ -24,9 +24,9 @@ func.func @g(%a : i32) -> i32 {
%1 = arith.constant 1 : i32
%2 = arith.constant 42 : i32
-// CHECK: switch i32 %{{.*}}, label %{{.*}} [
-// CHECK: i32 1, label %{{.*}}
-// CHECK: i32 -1, label %{{.*}}
+// CHECK: switch i64 %{{.*}}, label %{{.*}} [
+// CHECK: i64 1, label %{{.*}}
+// CHECK: i64 -1, label %{{.*}}
// CHECK: ]
fir.select_rank %a : i32 [1, ^bb2(%1:i32), -1, ^bb4, unit, ^bb3(%2:i32)]
^bb2(%3 : i32) :
More information about the flang-commits
mailing list