[llvm-commits] [PATCH][Instcombine, FAST-MATH] Some enhancement to Fmul

Shuxin Yang shuxin.llvm at gmail.com
Wed Jan 2 17:01:11 PST 2013


Hi,

    The attached patch implements the following rules. (Please ignore
the changes to APFloat.{h,cpp}.)

    1. X/C1 * C2 => X * (C2/C1)   (if C2/C1 is neither a special FP value
       -- zero, infinity or NaN -- nor a denormal)
    2. X/C1 * C2 => X / (C1/C2)   (if C2/C1 is a special FP value or a
       denormal, but C1/C2 is a normal FP value)

     Let MDC denote a multiplication or division with one and only one
operand being a constant.
   3. (MDC +/- C1) * C2 => (MDC * C2) +/- (C1 * C2)
     (so long as the constant folding doesn't yield any denormal or
special value)

Thanks
Shuxin
-------------- next part --------------
Index: test/Transforms/InstCombine/fast-math.ll
===================================================================
--- test/Transforms/InstCombine/fast-math.ll	(revision 171432)
+++ test/Transforms/InstCombine/fast-math.ll	(working copy)
@@ -160,3 +160,88 @@
 ; CHECK: @select3
 ; CHECK: fmul nnan nsz double %cond1, %x
 }
+
+; =========================================================================
+;
+;   Test cases for fmul begin here
+;
+; =========================================================================
+
+; ((X*C1) + C2) * C3 => (X * (C1*C3)) + (C2*C3) (i.e. distribution)
+define float @fmul_distribute1(float %f1) {
+  %t1 = fmul float %f1, 6.0e+3       ; X * C1
+  %t2 = fadd float %t1, 2.0e+3       ; (X * C1) + C2
+  %t3 = fmul fast float %t2, 5.0e+3  ; * C3 (only this fmul is 'fast')
+  ret float %t3 
+; CHECK: @fmul_distribute1
+; CHECK: %1 = fmul fast float %f1, 3.000000e+07
+; CHECK: %t3 = fadd fast float %1, 1.000000e+07
+}
+
+; (X/C1 + C2) * C3 => X/(C1/C3) + C2*C3
+define double @fmul_distribute2(double %f1, double %f2) {
+  %t1 = fdiv double %f1, 3.0e+0      ; X/C1
+  %t2 = fadd double %t1, 5.0e+1      ; X/C1 + C2
+  ; 0x10000000000000 = DBL_MIN (C3/C1 = DBL_MIN/3.0 is a denormal)
+  %t3 = fmul fast double %t2, 0x10000000000000 ; (X/C1 + C2) * C3
+  ret double %t3
+
+; CHECK: @fmul_distribute2
+; CHECK: %1 = fdiv fast double %f1, 0x7FE8000000000000
+; CHECK: fadd fast double %1, 0x69000000000000
+}
+
+; 5.0e-1 * DBL_MIN yields a denormal, so "(f1/3.0 + 5.0e-1) * DBL_MIN" cannot
+; be distributed into f1/(3.0/DBL_MIN) + (5.0e-1*DBL_MIN).
+define double @fmul_distribute3(double %f1) {
+  %t1 = fdiv double %f1, 3.0e+0      ; X/C1
+  %t2 = fadd double %t1, 5.0e-1      ; + C2, where C2 * DBL_MIN is a denormal
+  %t3 = fmul fast double %t2, 0x10000000000000 ; * DBL_MIN
+  ret double %t3
+
+; CHECK: @fmul_distribute3
+; CHECK: fmul fast double %t2, 0x10000000000000
+}
+
+; C1/X * C2 => (C1*C2) / X
+define float @fmul2(float %f1) {
+  %t1 = fdiv float 2.0e+3, %f1 ; C1/X
+  %t3 = fmul fast float %t1, 6.0e+3  ; * C2
+  ret float %t3 
+; CHECK: @fmul2
+; CHECK: fdiv fast float 1.200000e+07, %f1
+}
+
+; X/C1 * C2 => X * (C2/C1) (if C2/C1 is a normal FP value)
+define float @fmul3(float %f1, float %f2) {
+  %t1 = fdiv float %f1, 2.0e+3       ; X/C1
+  %t3 = fmul fast float %t1, 6.0e+3  ; * C2; C2/C1 = 3.0
+  ret float %t3 
+; CHECK: @fmul3
+; CHECK: fmul fast float %f1, 3.000000e+00
+}
+
+; Rule "X/C1 * C2 => X * (C2/C1)" is not applicable if C2/C1 is either a special
+; value or a denormal; here X/(C1/C2) fails too, since C1/C2 overflows to +Inf.
+; 0x3810000000000000 is FLT_MIN (float hex constants use the double encoding).
+define float @fmul4(float %f1, float %f2) {
+  %t1 = fdiv float %f1, 2.0e+3       ; X/C1
+  %t3 = fmul fast float %t1, 0x3810000000000000 ; * FLT_MIN
+  ret float %t3 
+; CHECK: @fmul4
+; CHECK: fmul fast float %t1, 0x3810000000000000
+}
+
+; X / C1 * C2 => X / (C1/C2) if C2/C1 is either a special value or a denormal,
+; and C1/C2 is a normal value.
+;
+define float @fmul5(float %f1, float %f2) {
+  %t1 = fdiv float %f1, 3.0e+0       ; X/C1
+  %t3 = fmul fast float %t1, 0x3810000000000000 ; * FLT_MIN; C1/C2 = 3.0/FLT_MIN
+  ret float %t3 
+; CHECK: @fmul5
+; CHECK: fdiv fast float %f1, 0x47E8000000000000
+}
+
+
+
Index: include/llvm/ADT/APFloat.h
===================================================================
--- include/llvm/ADT/APFloat.h	(revision 171432)
+++ include/llvm/ADT/APFloat.h	(working copy)
@@ -327,6 +327,7 @@
     bool isNegative() const { return sign; }
     bool isPosZero() const { return isZero() && !isNegative(); }
     bool isNegZero() const { return isZero() && isNegative(); }
+    bool isDenormal() const; ///< True iff the value is a denormal.
 
     APFloat& operator=(const APFloat &);
 
Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp	(revision 171432)
+++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp	(working copy)
@@ -291,10 +291,90 @@
      Y = I->getOperand(0);
 } 
 
+/// Helper for InstCombiner::visitFMul(BinaryOperator &). Returns true iff
+/// \p V is an FMul or FDiv instruction with exactly one ConstantFP operand,
+/// and that constant is normal per APFloat::isNormal() (not zero/NaN/Inf).
+static bool isFMulOrFDivWithConstant(Value *V) {
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return false;
+  unsigned Opc = Inst->getOpcode();
+  if (Opc != Instruction::FMul && Opc != Instruction::FDiv)
+    return false;
+
+  ConstantFP *LHSC = dyn_cast<ConstantFP>(Inst->getOperand(0));
+  ConstantFP *RHSC = dyn_cast<ConstantFP>(Inst->getOperand(1));
+  // Both operands constant: not the "one and only one" shape wanted here.
+  if (LHSC && RHSC)
+    return false;
+  ConstantFP *TheConst = LHSC ? LHSC : RHSC;
+  return TheConst && TheConst->getValueAPF().isNormal();
+}
+
+/// True iff \p C is in APFloat's normal category and is not a denormal.
+static bool isNormalFp(const ConstantFP *C) {
+  return C->getValueAPF().isNormal() && !C->getValueAPF().isDenormal();
+}
+
+/// foldFMulConst() is a helper routine of InstCombiner::visitFMul().
+/// \p FMulOrDiv must be an FMul/FDiv with one and only one constant operand
+/// (i.e. isFMulOrFDivWithConstant(FMulOrDiv) == true). This simplifies
+/// "FMulOrDiv * C", inserts the result before \p InsertBefore and returns
+/// it. Returns NULL if the folded constant would not be a normal FP value,
+/// or if the rewrite would need a new FDiv while FMulOrDiv has other uses.
+///
+Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, ConstantFP *C,
+                                   Instruction *InsertBefore) {
+  assert(isFMulOrFDivWithConstant(FMulOrDiv) && "FMulOrDiv is invalid");
+
+  Value *Opnd0 = FMulOrDiv->getOperand(0);
+  Value *Opnd1 = FMulOrDiv->getOperand(1);
+  ConstantFP *C0 = dyn_cast<ConstantFP>(Opnd0);
+  ConstantFP *C1 = dyn_cast<ConstantFP>(Opnd1);
+  // Emitting a new FDiv only pays off if the old FMulOrDiv dies; otherwise
+  // the cheap FMul by C would be traded for a second, expensive FDiv.
+  bool CanAddFDiv = FMulOrDiv->hasOneUse();
+  BinaryOperator *R = 0;
+
+  if (FMulOrDiv->getOpcode() == Instruction::FMul) {
+    // (X * C0) * C => X * (C0*C)
+    Constant *F = ConstantExpr::getFMul(C1 ? C1 : C0, C);
+    if (isNormalFp(cast<ConstantFP>(F)))
+      R = BinaryOperator::CreateFMul(C1 ? Opnd0 : Opnd1, F);
+  } else if (C0) {
+    // (C0 / X) * C => (C0 * C) / X
+    ConstantFP *F = cast<ConstantFP>(ConstantExpr::getFMul(C0, C));
+    if (CanAddFDiv && isNormalFp(F))
+      R = BinaryOperator::CreateFDiv(F, Opnd1);
+  } else {
+    // (X / C1) * C => X * (C/C1) if C/C1 is a normal FP value
+    ConstantFP *MulC = cast<ConstantFP>(ConstantExpr::getFDiv(C, C1));
+    if (isNormalFp(MulC)) {
+      R = BinaryOperator::CreateFMul(Opnd0, MulC);
+    } else if (CanAddFDiv) {
+      // Otherwise (X / C1) * C => X / (C1/C) if C1/C is a normal FP value
+      ConstantFP *DivC = cast<ConstantFP>(ConstantExpr::getFDiv(C1, C));
+      if (isNormalFp(DivC))
+        R = BinaryOperator::CreateFDiv(Opnd0, DivC);
+    }
+  }
+
+  // The new instruction is a fast-math reassociation of FMulOrDiv * C.
+  if (R) {
+    R->setHasUnsafeAlgebra(true);
+    InsertNewInstWith(R, *InsertBefore);
+  }
+
+  return R;
+}
+
 Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (isa<Constant>(Op0))
+    std::swap(Op0, Op1);
+
   if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), TD))
     return ReplaceInstUsesWith(I, V);
 
@@ -308,6 +388,53 @@
     if (isa<PHINode>(Op0))
       if (Instruction *NV = FoldOpIntoPhi(I))
         return NV;
+
+    ConstantFP *C = dyn_cast<ConstantFP>(Op1);
+    if (C && I.hasUnsafeAlgebra() && C->getValueAPF().isNormal()) {
+      // Let MDC denote an expression in one of these forms:
+      // X * C, C/X, X/C, where C is a constant.
+      //
+      // Try to simplify "MDC * Constant"
+      if (isFMulOrFDivWithConstant(Op0)) {
+        Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I);
+        if (V)
+          return ReplaceInstUsesWith(I, V);
+      }
+
+      // (MDC +/- C1) * C2 => (MDC * C2) +/- (C1 * C2)
+      Instruction *FAddSub = dyn_cast<Instruction>(Op0);
+      if (FAddSub &&
+          (FAddSub->getOpcode() == Instruction::FAdd ||
+           FAddSub->getOpcode() == Instruction::FSub)) {
+        Value *Opnd0 = FAddSub->getOperand(0);
+        Value *Opnd1 = FAddSub->getOperand(1);
+        ConstantFP *C0 = dyn_cast<ConstantFP>(Opnd0);
+        ConstantFP *C1 = dyn_cast<ConstantFP>(Opnd1);
+        bool Swap = false;
+        if (C0) {
+          std::swap(C0, C1); std::swap(Opnd0, Opnd1); Swap = true; 
+        }
+
+        if (C1 && C1->getValueAPF().isNormal() &&
+            isFMulOrFDivWithConstant(Opnd0)) {
+          Value *M0 = ConstantExpr::getFMul(C1, C);
+          Value *M1 = isNormalFp(cast<ConstantFP>(M0)) ? 
+                      foldFMulConst(cast<Instruction>(Opnd0), C, &I) :
+                      0;
+          if (M0 && M1) {
+            if (Swap && FAddSub->getOpcode() == Instruction::FSub)
+              std::swap(M0, M1);
+
+            Value *R = (FAddSub->getOpcode() == Instruction::FAdd) ?
+                        BinaryOperator::CreateFAdd(M0, M1) :
+                        BinaryOperator::CreateFSub(M0, M1);
+            Instruction *RI = cast<Instruction>(R);
+            RI->setHasUnsafeAlgebra(true);
+            return RI;
+          }
+        }
+      }
+    }
   }
 
   if (Value *Op0v = dyn_castFNegVal(Op0))     // -X * -Y = X*Y
Index: lib/Transforms/InstCombine/InstCombine.h
===================================================================
--- lib/Transforms/InstCombine/InstCombine.h	(revision 171432)
+++ lib/Transforms/InstCombine/InstCombine.h	(working copy)
@@ -114,6 +114,8 @@
   Instruction *visitSub(BinaryOperator &I);
   Instruction *visitFSub(BinaryOperator &I);
   Instruction *visitMul(BinaryOperator &I);
+  Value *foldFMulConst(Instruction *FMulOrDiv, ConstantFP *C,
+                       Instruction *InsertBefore); // May return null.
   Instruction *visitFMul(BinaryOperator &I);
   Instruction *visitURem(BinaryOperator &I);
   Instruction *visitSRem(BinaryOperator &I);
Index: lib/Support/APFloat.cpp
===================================================================
--- lib/Support/APFloat.cpp	(revision 171432)
+++ lib/Support/APFloat.cpp	(working copy)
@@ -697,6 +697,13 @@
 }
 
 bool
+APFloat::isDenormal() const {
+  // Normal category, minimum exponent, and bit (precision - 1) clear.
+  return isNormal() && exponent == semantics->minExponent &&
+         !APInt::tcExtractBit(significandParts(), semantics->precision - 1);
+}
+
+bool
 APFloat::bitwiseIsEqual(const APFloat &rhs) const {
   if (this == &rhs)
     return true;


More information about the llvm-commits mailing list