[llvm] bdcbc1e - [LLVM][InstCombine] Preserve vector types when shrinking FP constants. (#163598)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 21 03:55:25 PDT 2025
Author: Paul Walker
Date: 2025-10-21T11:55:21+01:00
New Revision: bdcbc1e5e7df220d9ae2afa6529268524fdde8ca
URL: https://github.com/llvm/llvm-project/commit/bdcbc1e5e7df220d9ae2afa6529268524fdde8ca
DIFF: https://github.com/llvm/llvm-project/commit/bdcbc1e5e7df220d9ae2afa6529268524fdde8ca.diff
LOG: [LLVM][InstCombine] Preserve vector types when shrinking FP constants. (#163598)
While my objective is to make the FP-constant shrinking path (shrinkFPConstant)
safe for ConstantFP-based splats, I discovered that the following issues also
affect ConstantVector-based splats:
1. PreferBFloat is not set when the fptrunc result type is a bfloat vector.
2. getMinimumFPType() returns a scalar type, rather than a vector type, for
vector constants whose getSplatValue() succeeds.
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
llvm/test/Transforms/InstCombine/fpextend.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index cdc559b489e9d..9b9fe265c7bce 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
/// Return a Constant* for the specified floating-point constant if it fits
/// in the specified FP type without changing its value.
-static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
bool losesInfo;
- APFloat F = CFP->getValueAPF();
(void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}
-static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
- if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
- return nullptr; // No constant folding of this.
+static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F,
+ bool PreferBFloat) {
// See if the value can be truncated to bfloat and then reextended.
- if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
- return Type::getBFloatTy(CFP->getContext());
+ if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
+ return Type::getBFloatTy(Ctx);
// See if the value can be truncated to half and then reextended.
- if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
- return Type::getHalfTy(CFP->getContext());
+ if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
+ return Type::getHalfTy(Ctx);
// See if the value can be truncated to float and then reextended.
- if (fitsInFPType(CFP, APFloat::IEEEsingle()))
- return Type::getFloatTy(CFP->getContext());
- if (CFP->getType()->isDoubleTy())
- return nullptr; // Won't shrink.
- if (fitsInFPType(CFP, APFloat::IEEEdouble()))
- return Type::getDoubleTy(CFP->getContext());
+ if (fitsInFPType(F, APFloat::IEEEsingle()))
+ return Type::getFloatTy(Ctx);
+ if (&F.getSemantics() == &APFloat::IEEEdouble())
+ return nullptr; // Won't shrink.
+ // See if the value can be truncated to double and then reextended.
+ if (fitsInFPType(F, APFloat::IEEEdouble()))
+ return Type::getDoubleTy(Ctx);
// Don't try to shrink to various long double types.
return nullptr;
}
+static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
+ Type *Ty = CFP->getType();
+ if (Ty->getScalarType()->isPPC_FP128Ty())
+ return nullptr; // No constant folding of this.
+
+ Type *ShrinkTy =
+ shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
+ if (ShrinkTy)
+ if (auto *VecTy = dyn_cast<VectorType>(Ty))
+ ShrinkTy = VectorType::get(ShrinkTy, VecTy);
+
+ return ShrinkTy;
+}
+
// Determine if this is a vector of ConstantFPs and if so, return the minimal
// type we can safely truncate all elements to.
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
@@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
// Try to shrink scalable and fixed splat vectors.
if (auto *FPC = dyn_cast<Constant>(V))
- if (isa<VectorType>(V->getType()))
+ if (auto *VTy = dyn_cast<VectorType>(V->getType()))
if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
- return T;
+ return VectorType::get(T, VTy);
// Try to shrink a vector of FP constants. This returns nullptr on scalable
// vectors
@@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
Type *Ty = FPT.getType();
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
if (BO && BO->hasOneUse()) {
- Type *LHSMinType =
- getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
- Type *RHSMinType =
- getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
+ bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
+ Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
+ Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
unsigned OpWidth = BO->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index 9125339c00ecf..a65b73b1ca75a 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-fp-for-fixed-length-splat -S | FileCheck %s
define float @test(float %x) nounwind {
; CHECK-LABEL: @test(
@@ -449,6 +450,28 @@ define bfloat @bf16_frem(bfloat %x) {
ret bfloat %t3
}
+define <4 x bfloat> @v4bf16_frem_x_const(<4 x bfloat> %x) {
+; CHECK-LABEL: @v4bf16_frem_x_const(
+; CHECK-NEXT: [[TMP1:%.*]] = frem <4 x bfloat> [[X:%.*]], splat (bfloat 0xR40C9)
+; CHECK-NEXT: ret <4 x bfloat> [[TMP1]]
+;
+ %t1 = fpext <4 x bfloat> %x to <4 x float>
+ %t2 = frem <4 x float> %t1, splat(float 6.281250e+00)
+ %t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
+ ret <4 x bfloat> %t3
+}
+
+define <4 x bfloat> @v4bf16_frem_const_x(<4 x bfloat> %x) {
+; CHECK-LABEL: @v4bf16_frem_const_x(
+; CHECK-NEXT: [[TMP1:%.*]] = frem <4 x bfloat> splat (bfloat 0xR40C9), [[X:%.*]]
+; CHECK-NEXT: ret <4 x bfloat> [[TMP1]]
+;
+ %t1 = fpext <4 x bfloat> %x to <4 x float>
+ %t2 = frem <4 x float> splat(float 6.281250e+00), %t1
+ %t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
+ ret <4 x bfloat> %t3
+}
+
define <4 x float> @v4f32_fadd(<4 x float> %a) {
; CHECK-LABEL: @v4f32_fadd(
; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
@@ -459,3 +482,16 @@ define <4 x float> @v4f32_fadd(<4 x float> %a) {
%5 = fptrunc <4 x double> %4 to <4 x float>
ret <4 x float> %5
}
+
+define <4 x float> @v4f32_fadd_const_not_shrinkable(<4 x float> %a) {
+; CHECK-LABEL: @v4f32_fadd_const_not_shrinkable(
+; CHECK-NEXT: [[TMP1:%.*]] = fpext <4 x float> [[A:%.*]] to <4 x double>
+; CHECK-NEXT: [[TMP2:%.*]] = fadd <4 x double> [[TMP1]], splat (double -1.000000e+100)
+; CHECK-NEXT: [[TMP3:%.*]] = fptrunc <4 x double> [[TMP2]] to <4 x float>
+; CHECK-NEXT: ret <4 x float> [[TMP3]]
+;
+ %2 = fpext <4 x float> %a to <4 x double>
+ %4 = fadd <4 x double> %2, splat (double -1.000000e+100)
+ %5 = fptrunc <4 x double> %4 to <4 x float>
+ ret <4 x float> %5
+}
More information about the llvm-commits mailing list