[llvm] [InstCombine] Pick bfloat over half when shrinking ops that started with an fpext from bfloat (PR #82493)

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 06:09:27 PST 2024


https://github.com/d0k created https://github.com/llvm/llvm-project/pull/82493

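This fixes the case where we would shrink an frem to half and then bitcast the result to bfloat, producing invalid results: half and bfloat are both 16 bits wide, so the cast back to bfloat became a bitcast rather than a conversion. The transformation was written under the assumption that there is only one floating-point type with a given bit width.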

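Also add an assert to CastInst::CreateFPCast that rejects casts between distinct floating-point types of the same bit width (which would otherwise silently become a bitcast), turning this kind of miscompilation into a crash in builds with assertions enabled.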

>From 446d270802421663fb2ead6f5d7b00e0e9ca9f01 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer <benny.kra at googlemail.com>
Date: Wed, 21 Feb 2024 15:03:23 +0100
Subject: [PATCH] [InstCombine] Pick bfloat over half when shrinking ops that
 started with an fpext from bfloat

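This fixes the case where we would shrink an frem to half and then
bitcast the result to bfloat, producing invalid results: half and
bfloat are both 16 bits wide, so the cast back became a bitcast rather
than a conversion. This code was written under the assumption that
there is only one floating-point type with a given bit width.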

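Also add an assert to CastInst::CreateFPCast that rejects casts
between distinct floating-point types of the same bit width, which
would otherwise silently become a bitcast. This turns this kind of
miscompilation into a crash in builds with assertions enabled.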
---
 llvm/lib/IR/Instructions.cpp                  |  1 +
 .../InstCombine/InstCombineCasts.cpp          | 23 +++++++++++--------
 llvm/test/Transforms/InstCombine/fpextend.ll  | 11 +++++++++
 3 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index ce0df53d9ffb9a..fc5c9b201487e0 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3525,6 +3525,7 @@ CastInst *CastInst::CreateFPCast(Value *C, Type *Ty,
          "Invalid cast");
   unsigned SrcBits = C->getType()->getScalarSizeInBits();
   unsigned DstBits = Ty->getScalarSizeInBits();
+  assert((C->getType() == Ty || SrcBits != DstBits) && "Invalid cast");
   Instruction::CastOps opcode =
     (SrcBits == DstBits ? Instruction::BitCast :
      (SrcBits > DstBits ? Instruction::FPTrunc : Instruction::FPExt));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index ed47de287302ed..33ed1d5575375a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1543,11 +1543,14 @@ static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
   return !losesInfo;
 }
 
-static Type *shrinkFPConstant(ConstantFP *CFP) {
+static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
   if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
     return nullptr;  // No constant folding of this.
+  // See if the value can be truncated to bfloat and then reextended.
+  if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
+    return Type::getBFloatTy(CFP->getContext());
   // See if the value can be truncated to half and then reextended.
-  if (fitsInFPType(CFP, APFloat::IEEEhalf()))
+  if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
     return Type::getHalfTy(CFP->getContext());
   // See if the value can be truncated to float and then reextended.
   if (fitsInFPType(CFP, APFloat::IEEEsingle()))
@@ -1562,7 +1565,7 @@ static Type *shrinkFPConstant(ConstantFP *CFP) {
 
 // Determine if this is a vector of ConstantFPs and if so, return the minimal
 // type we can safely truncate all elements to.
-static Type *shrinkFPConstantVector(Value *V) {
+static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
   auto *CV = dyn_cast<Constant>(V);
   auto *CVVTy = dyn_cast<FixedVectorType>(V->getType());
   if (!CV || !CVVTy)
@@ -1582,7 +1585,7 @@ static Type *shrinkFPConstantVector(Value *V) {
     if (!CFP)
       return nullptr;
 
-    Type *T = shrinkFPConstant(CFP);
+    Type *T = shrinkFPConstant(CFP, PreferBFloat);
     if (!T)
       return nullptr;
 
@@ -1597,7 +1600,7 @@ static Type *shrinkFPConstantVector(Value *V) {
 }
 
 /// Find the minimum FP type we can safely truncate to.
-static Type *getMinimumFPType(Value *V) {
+static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
   if (auto *FPExt = dyn_cast<FPExtInst>(V))
     return FPExt->getOperand(0)->getType();
 
@@ -1605,7 +1608,7 @@ static Type *getMinimumFPType(Value *V) {
   // that can accurately represent it.  This allows us to turn
   // (float)((double)X+2.0) into x+2.0f.
   if (auto *CFP = dyn_cast<ConstantFP>(V))
-    if (Type *T = shrinkFPConstant(CFP))
+    if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
       return T;
 
   // We can only correctly find a minimum type for a scalable vector when it is
@@ -1617,7 +1620,7 @@ static Type *getMinimumFPType(Value *V) {
 
   // Try to shrink a vector of FP constants. This returns nullptr on scalable
   // vectors
-  if (Type *T = shrinkFPConstantVector(V))
+  if (Type *T = shrinkFPConstantVector(V, PreferBFloat))
     return T;
 
   return V->getType();
@@ -1686,8 +1689,10 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
   Type *Ty = FPT.getType();
   auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
   if (BO && BO->hasOneUse()) {
-    Type *LHSMinType = getMinimumFPType(BO->getOperand(0));
-    Type *RHSMinType = getMinimumFPType(BO->getOperand(1));
+    Type *LHSMinType =
+        getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
+    Type *RHSMinType =
+        getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
     unsigned OpWidth = BO->getType()->getFPMantissaWidth();
     unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
     unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index a41f2a4ca300f6..f96b902c503ed4 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -437,3 +437,14 @@ define half @bf16_to_f32_to_f16(bfloat %a) nounwind {
   %z = fptrunc float %y to half
   ret half %z
 }
+
+define bfloat @bf16_frem(bfloat %x) nounwind {
+; CHECK-LABEL: @bf16_frem(
+; CHECK-NEXT:    [[FREM:%.*]] = frem bfloat [[X:%.*]], 0xR40C9
+; CHECK-NEXT:    ret bfloat [[FREM]]
+;
+  %t1 = fpext bfloat %x to float
+  %t2 = frem float %t1, 6.281250e+00
+  %t3 = fptrunc float %t2 to bfloat
+  ret bfloat %t3
+}



More information about the llvm-commits mailing list