[llvm] r326617 - [InstCombine] Rewrite the binary op shrinking in visitFPTrunc to avoid creating overly small ConstantFPs that we'll just need to extend again.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 2 13:25:19 PST 2018


Author: ctopper
Date: Fri Mar  2 13:25:18 2018
New Revision: 326617

URL: http://llvm.org/viewvc/llvm-project?rev=326617&view=rev
Log:
[InstCombine] Rewrite the binary op shrinking in visitFPTrunc to avoid creating overly small ConstantFPs that we'll just need to extend again.

Instead of returning the smaller FP constant, we now return the minimal Type the constant can fit into. We also return the Type of the input to any fp extends. The legality checks are then done on just the size of these Types. If we find something profitable, we then emit FPTruncs in front of the smaller binop and assume those FPTruncs will be constant folded or combined with any ConstantFPs or fpextends.

Differential Revision: https://reviews.llvm.org/D44038

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp?rev=326617&r1=326616&r2=326617&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp Fri Mar  2 13:25:18 2018
@@ -1411,45 +1411,43 @@ Instruction *InstCombiner::visitSExt(SEx
 
 /// Return a Constant* for the specified floating-point constant if it fits
 /// in the specified FP type without changing its value.
-static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
   bool losesInfo;
   APFloat F = CFP->getValueAPF();
   (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
-  if (!losesInfo)
-    return ConstantFP::get(CFP->getContext(), F);
-  return nullptr;
+  return !losesInfo;
 }
 
-static Constant *shrinkFPConstant(ConstantFP *CFP) {
+static Type *shrinkFPConstant(ConstantFP *CFP) {
   if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
     return nullptr;  // No constant folding of this.
   // See if the value can be truncated to half and then reextended.
-  if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEhalf()))
-    return NewCFP;
+  if (fitsInFPType(CFP, APFloat::IEEEhalf()))
+    return Type::getHalfTy(CFP->getContext());
   // See if the value can be truncated to float and then reextended.
-  if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEsingle()))
-    return NewCFP;
+  if (fitsInFPType(CFP, APFloat::IEEEsingle()))
+    return Type::getFloatTy(CFP->getContext());
   if (CFP->getType()->isDoubleTy())
     return nullptr;  // Won't shrink.
-  if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEdouble()))
-    return NewCFP;
+  if (fitsInFPType(CFP, APFloat::IEEEdouble()))
+    return Type::getDoubleTy(CFP->getContext());
   // Don't try to shrink to various long double types.
   return nullptr;
 }
 
-/// Look through floating-point extensions until we get the source value.
-static Value *lookThroughFPExtensions(Value *V) {
-  while (auto *FPExt = dyn_cast<FPExtInst>(V))
-    V = FPExt->getOperand(0);
+/// Find the minimum FP type we can safely truncate to.
+static Type *getMinimumFPType(Value *V) {
+  if (auto *FPExt = dyn_cast<FPExtInst>(V))
+    return FPExt->getOperand(0)->getType();
 
   // If this value is a constant, return the constant in the smallest FP type
   // that can accurately represent it.  This allows us to turn
   // (float)((double)X+2.0) into x+2.0f.
   if (auto *CFP = dyn_cast<ConstantFP>(V))
-    if (Constant *NewCFP = shrinkFPConstant(CFP))
-      return NewCFP;
+    if (Type *T = shrinkFPConstant(CFP))
+      return T;
 
-  return V;
+  return V->getType();
 }
 
 Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
@@ -1464,11 +1462,11 @@ Instruction *InstCombiner::visitFPTrunc(
   // is explained below in the various case statements.
   BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0));
   if (OpI && OpI->hasOneUse()) {
-    Value *LHSOrig = lookThroughFPExtensions(OpI->getOperand(0));
-    Value *RHSOrig = lookThroughFPExtensions(OpI->getOperand(1));
+    Type *LHSMinType = getMinimumFPType(OpI->getOperand(0));
+    Type *RHSMinType = getMinimumFPType(OpI->getOperand(1));
     unsigned OpWidth = OpI->getType()->getFPMantissaWidth();
-    unsigned LHSWidth = LHSOrig->getType()->getFPMantissaWidth();
-    unsigned RHSWidth = RHSOrig->getType()->getFPMantissaWidth();
+    unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
+    unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
     unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
     unsigned DstWidth = CI.getType()->getFPMantissaWidth();
     switch (OpI->getOpcode()) {
@@ -1494,12 +1492,10 @@ Instruction *InstCombiner::visitFPTrunc(
         // could be tightened for those cases, but they are rare (the main
         // case of interest here is (float)((double)float + float)).
         if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
-          if (LHSOrig->getType() != CI.getType())
-            LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType());
-          if (RHSOrig->getType() != CI.getType())
-            RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
+          Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
+          Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
           Instruction *RI =
-            BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig);
+            BinaryOperator::Create(OpI->getOpcode(), LHS, RHS);
           RI->copyFastMathFlags(OpI);
           return RI;
         }
@@ -1511,12 +1507,10 @@ Instruction *InstCombiner::visitFPTrunc(
         // rounding can possibly occur; we can safely perform the operation
         // in the destination format if it can represent both sources.
         if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
-          if (LHSOrig->getType() != CI.getType())
-            LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType());
-          if (RHSOrig->getType() != CI.getType())
-            RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
+          Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
+          Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
           Instruction *RI =
-            BinaryOperator::CreateFMul(LHSOrig, RHSOrig);
+            BinaryOperator::CreateFMul(LHS, RHS);
           RI->copyFastMathFlags(OpI);
           return RI;
         }
@@ -1529,33 +1523,35 @@ Instruction *InstCombiner::visitFPTrunc(
         // condition used here is a good conservative first pass.
         // TODO: Tighten bound via rigorous analysis of the unbalanced case.
         if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
-          if (LHSOrig->getType() != CI.getType())
-            LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType());
-          if (RHSOrig->getType() != CI.getType())
-            RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
+          Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
+          Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
           Instruction *RI =
-            BinaryOperator::CreateFDiv(LHSOrig, RHSOrig);
+            BinaryOperator::CreateFDiv(LHS, RHS);
           RI->copyFastMathFlags(OpI);
           return RI;
         }
         break;
-      case Instruction::FRem:
+      case Instruction::FRem: {
         // Remainder is straightforward.  Remainder is always exact, so the
         // type of OpI doesn't enter into things at all.  We simply evaluate
         // in whichever source type is larger, then convert to the
         // destination type.
         if (SrcWidth == OpWidth)
           break;
-        if (LHSWidth < SrcWidth)
-          LHSOrig = Builder.CreateFPExt(LHSOrig, RHSOrig->getType());
-        else if (RHSWidth <= SrcWidth)
-          RHSOrig = Builder.CreateFPExt(RHSOrig, LHSOrig->getType());
-        if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) {
-          Value *ExactResult = Builder.CreateFRem(LHSOrig, RHSOrig);
-          if (Instruction *RI = dyn_cast<Instruction>(ExactResult))
-            RI->copyFastMathFlags(OpI);
-          return CastInst::CreateFPCast(ExactResult, CI.getType());
+        Value *LHS, *RHS;
+        if (LHSWidth == SrcWidth) {
+           LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType);
+           RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType);
+        } else {
+           LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType);
+           RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType);
         }
+
+        Value *ExactResult = Builder.CreateFRem(LHS, RHS);
+        if (Instruction *RI = dyn_cast<Instruction>(ExactResult))
+          RI->copyFastMathFlags(OpI);
+        return CastInst::CreateFPCast(ExactResult, CI.getType());
+      }
     }
 
     // (fptrunc (fneg x)) -> (fneg (fptrunc x))




More information about the llvm-commits mailing list