[llvm-commits] [PATCH][instcombine, Fastmath] Enhancement to FP multiplication

Shuxin Yang shuxin.llvm at gmail.com
Tue Jan 8 16:01:30 PST 2013


Hi,

    This patch implements the following InstCombine rules:

    1. (select cond ? V1 : V2) * X => select cond ? V1*X : V2*X
        if both V1*X and V2*X can be simplified.

       This is a generalization of "X * (select c ? 1.0 : 0.0) => select c ? X : 0.0".

    2. (-0.0 - X) * Y => -0.0 - (X * Y)
       if the expression "-0.0 - X" has only one use. This rule applies in
       both strict and relaxed FP modes.

       The purpose is to hoist the minus sign as high as possible in an
       attempt to reveal optimization opportunities in the enclosing
       super-expressions.

    3. (X*Y) * X => (X*X) * Y if X != Y and the expression is flagged with
       "UnsafeAlgebra".

      The purpose of this transformation is two-fold:
      a) to form a power expression (of X).
      b) to potentially shorten the critical path: after the transformation,
         the latency of the instruction Y is amortized by the computation of
         X*X, so Y is in a "less critical" position than it was before the
         transformation.

    These rules are implemented in a two-iteration loop (see below),
    obviating the need to examine the symmetric cases separately and, to
    some extent, reducing clutter.

Thank you for your review.

Shuxin
-------------- next part --------------
Index: test/Transforms/InstCombine/fast-math.ll
===================================================================
--- test/Transforms/InstCombine/fast-math.ll	(revision 171814)
+++ test/Transforms/InstCombine/fast-math.ll	(working copy)
@@ -138,7 +138,7 @@
   %add = fadd double %mul, %y
   ret double %add
 ; CHECK: @select1
-; CHECK: select i1 %tobool, double %x, double 0.000000e+00
+; CHECK: select nnan nsz i1 %tobool, double %x, double 0.000000e+00
 }
 
 define double @select2(i32 %cond, double %x, double %y) {
@@ -148,7 +148,7 @@
   %add = fadd double %mul, %y
   ret double %add
 ; CHECK: @select2
-; CHECK: select i1 %tobool, double 0.000000e+00, double %x
+; CHECK: select nnan nsz i1 %tobool, double 0.000000e+00, double %x
 }
 
 define double @select3(i32 %cond, double %x, double %y) {
@@ -161,6 +161,18 @@
 ; CHECK: fmul nnan nsz double %cond1, %x
 }
 
+define double @select4(i32 %cond, double %x, double %y) {
+   %tobool = icmp ne i32 %cond, 0
+   %tobool2 = icmp ne i32 %cond, 1
+   %v0 = select i1 %tobool2, double %x, double %y ; arms %x/%y: multiplying them by %mulopnd simplifies nothing
+   %mulopnd = select i1 %tobool, double 1.000000e+00, double 0.000000e+00 ; arms 1.0/0.0: times %v0 gives %v0 and 0.0 (nnan nsz)
+   %mul = fmul nnan nsz double %v0, %mulopnd
+   %add = fadd double %mul, %y
+   ret double %add
+; CHECK: @select4
+; CHECK: select nnan nsz i1 %tobool, double %v0, double 0.000000e+00
+}
+
 ; =========================================================================
 ;
 ;   Testing-cases about fmul begin
@@ -243,5 +255,22 @@
 ; CHECK: fdiv fast float %f1, 0x47E8000000000000
 }
 
+; (X*Y) * X => (X*X) * Y, enabled by the 'fast' flag on the outer fmul
+define float @fmul6(float %f1, float %f2) {
+  %mul = fmul float %f1, %f2
+  %mul1 = fmul fast float %mul, %f1
+  ret float %mul1
+; CHECK: @fmul6
+; CHECK: fmul fast float %f1, %f1 
+}
 
 
+; "(X*Y) * X => (X*X) * Y" must not fire when "X*Y" has other uses (%mul also feeds the fadd)
+define float @fmul7(float %f1, float %f2) {
+  %mul = fmul float %f1, %f2
+  %mul1 = fmul fast float %mul, %f1
+  %add = fadd float %mul1, %mul
+  ret float %add
+; CHECK: @fmul7
+; CHECK: fmul fast float %mul, %f1
+}
Index: test/Transforms/InstCombine/fmul.ll
===================================================================
--- test/Transforms/InstCombine/fmul.ll	(revision 0)
+++ test/Transforms/InstCombine/fmul.ll	(revision 0)
@@ -0,0 +1,42 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+; (-0.0 - X) * C => X * -C: the negation is folded into the constant
+define float @test1(float %x) {
+  %sub = fsub float -0.000000e+00, %x
+  %mul = fmul float %sub, 2.0e+1
+  ret float %mul
+
+; CHECK: @test1
+; CHECK: fmul float %x, -2.000000e+01
+}
+
+; (-0.0 - X) * (-0.0 - Y) => X * Y: the two negations cancel (no fast-math flags needed)
+define float @test2(float %x, float %y) {
+  %sub1 = fsub float -0.000000e+00, %x
+  %sub2 = fsub float -0.000000e+00, %y
+  %mul = fmul float %sub1, %sub2
+  ret float %mul
+; CHECK: @test2
+; CHECK: fmul float %x, %y
+}
+
+; (-0.0 - X) * Y => -0.0 - (X * Y): the negation is hoisted above the multiply (no fast-math flags needed)
+define float @test3(float %x, float %y) {
+  %sub1 = fsub float -0.000000e+00, %x
+  %mul = fmul float %sub1, %y
+  ret float %mul
+; CHECK: @test3
+; CHECK: %1 = fmul float %x, %y
+; CHECK: %mul = fsub float -0.000000e+00, %1
+}
+
+; "(-0.0 - X) * Y => -0.0 - (X * Y)" must not fire when "-0.0 - X" has
+; multiple uses (%sub1 feeds both fmuls), so the original fsub survives.
+define float @test4(float %x, float %y) {
+  %sub1 = fsub float -0.000000e+00, %x
+  %mul = fmul float %sub1, %y
+  %mul2 = fmul float %mul, %sub1
+  ret float %mul2
+; CHECK: @test4
+; CHECK: fsub float -0.000000e+00, %x
+}
Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp	(revision 171814)
+++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp	(working copy)
@@ -438,9 +438,6 @@
     }
   }
 
-  if (Value *Op0v = dyn_castFNegVal(Op0))     // -X * -Y = X*Y
-    if (Value *Op1v = dyn_castFNegVal(Op1))
-      return BinaryOperator::CreateFMul(Op0v, Op1v);
 
   // Under unsafe algebra do:
   // X * log2(0.5*Y) = X*log2(Y) - X
@@ -469,36 +466,102 @@
     }
   }

-  // X * cond ? 1.0 : 0.0 => cond ? X : 0.0
-  if (I.hasNoNaNs() && I.hasNoSignedZeros()) {
-    Value *V0 = I.getOperand(0);
-    Value *V1 = I.getOperand(1);
-    Value *Cond, *SLHS, *SRHS;
-    bool Match = false;
+  // Handle the symmetric cases in a 2-iteration loop: every rule below only
+  // matches its pattern on Opnd0, and the operands are swapped in between.
+  Value *Opnd0 = Op0;
+  Value *Opnd1 = Op1;
+  FastMathFlags FMF = I.getFastMathFlags();
+  for (int i = 0; i < 2; i++) {

-    if (match(V0, m_Select(m_Value(Cond), m_Value(SLHS), m_Value(SRHS)))) {
-      Match = true;
-    } else if (match(V1, m_Select(m_Value(Cond), m_Value(SLHS), 
-                     m_Value(SRHS)))) {
-      Match = true;
-      std::swap(V0, V1);
+    // X * (cond ? V1 : V2) => cond ? X*V1 : X*V2
+    //  if both X*V1 and X*V2 can be simplified.
+    //  E.g. X * (cond ? 1.0 : 0.0) => cond ? X : 0.0
+    {
+      Value *Cond, *SLHS, *SRHS;
+      if (match(Opnd0,
+                m_Select(m_Value(Cond), m_Value(SLHS), m_Value(SRHS)))) {
+        Value *SimpL = SimplifyFMulInst(SLHS, Opnd1, FMF);
+        Value *SimpR = 0;
+        if (SimpL && (SimpR = SimplifyFMulInst(SRHS, Opnd1, FMF))) {
+          Value *T = Builder->CreateSelect(Cond, SimpL, SimpR);
+          // CreateSelect constant-folds when Cond and both arms are
+          // constants; only an actual instruction can carry the flags.
+          if (FMF.any())
+            if (Instruction *TI = dyn_cast<Instruction>(T))
+              TI->setFastMathFlags(FMF);
+          return ReplaceInstUsesWith(I, T);
+        }
+      }
     }

-    if (Match) {
-      ConstantFP *C0 = dyn_cast<ConstantFP>(SLHS);
-      ConstantFP *C1 = dyn_cast<ConstantFP>(SRHS);
+    if (Value *N0 = dyn_castFNegVal(Opnd0)) {
+      Value *N1 = dyn_castFNegVal(Opnd1);

-      if (C0 && C1 &&
-          ((C0->isZero() && C1->isExactlyValue(1.0)) ||
-           (C1->isZero() && C0->isExactlyValue(1.0)))) {
-        Value *T;
-        if (C0->isZero())
-          T = Builder->CreateSelect(Cond, SLHS, V1);
-        else
-          T = Builder->CreateSelect(Cond, V1, SRHS);
-        return ReplaceInstUsesWith(I, T);
+      // -X * -Y => X*Y
+      if (N1) {
+        Instruction *R = BinaryOperator::CreateFMul(N0, N1);
+        if (FMF.any())
+          R->setFastMathFlags(FMF);
+        return R;
+      }
+
+      if (Opnd0->hasOneUse()) {
+        // -X * Y => -(X*Y) (Promote negation as high as possible)
+        Value *T = Builder->CreateFMul(N0, Opnd1);
+        // The builder constant-folds X*Y when both are constants; only an
+        // actual instruction can take a debug location or flags.
+        if (Instruction *TI = dyn_cast<Instruction>(T)) {
+          TI->setDebugLoc(I.getDebugLoc());
+          if (FMF.any())
+            TI->setFastMathFlags(FMF);
+        }
+        Instruction *Neg = BinaryOperator::CreateFNeg(T);
+        if (FMF.any())
+          Neg->setFastMathFlags(FMF);
+        return Neg;
       }
     }
+
+    // (X*Y) * X => (X*X) * Y where Y != X
+    //  The purpose is two-fold:
+    //   1) to form a power expression (of X).
+    //   2) potentially shorten the critical path: After transformation, the
+    //  latency of the instruction Y is amortized by the expression of X*X,
+    //  and therefore Y is in a "less critical" position compared to what it
+    //  was before the transformation.
+    //
+    if (FMF.unsafeAlgebra()) {
+      Value *Opnd0_0, *Opnd0_1;
+      if (Opnd0->hasOneUse() &&
+          match(Opnd0, m_FMul(m_Value(Opnd0_0), m_Value(Opnd0_1)))) {
+        Value *Y = 0;
+        if (Opnd0_0 == Opnd1 && Opnd0_1 != Opnd1)
+          Y = Opnd0_1;
+        else if (Opnd0_1 == Opnd1 && Opnd0_0 != Opnd1)
+          Y = Opnd0_0;
+
+        if (Y) {
+          // X*X is constant-folded by the builder when X is a Constant
+          // (e.g. (Z * C) * C), so T is not necessarily an Instruction.
+          Value *T = Builder->CreateFMul(Opnd1, Opnd1);
+          if (Instruction *TI = dyn_cast<Instruction>(T)) {
+            TI->setHasUnsafeAlgebra(true);
+            TI->setDebugLoc(I.getDebugLoc());
+          }
+
+          Instruction *R = BinaryOperator::CreateFMul(T, Y);
+          R->setHasUnsafeAlgebra(true);
+          return R;
+        }
+      }
+    }
+
+    // Retry with the operands swapped, unless Op1 is a constant: in that
+    // case the rules are not tried with the constant in the Opnd0 position.
+    if (!isa<Constant>(Op1))
+      std::swap(Opnd0, Opnd1);
+    else
+      break;
   }

   return Changed ? &I : 0;


More information about the llvm-commits mailing list