[llvm] r219944 - fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)

Sanjay Patel spatel at rotateright.com
Thu Oct 16 11:48:17 PDT 2014


Author: spatel
Date: Thu Oct 16 13:48:17 2014
New Revision: 219944

URL: http://llvm.org/viewvc/llvm-project?rev=219944&view=rev
Log:
fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)

If a square root call has an FP multiplication argument that can be reassociated,
then we can hoist a repeated factor out of the square root call and into a fabs().

In the simplest case, this:

   y = sqrt(x * x);

becomes this:

   y = fabs(x);

This patch relies on an earlier optimization in instcombine or reassociate to put the
multiplication tree into a canonical form, so we don't have to search over
every permutation of the multiplication tree.

Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to
use the function-level "unsafe-fp-math" attribute to enable this optimization.
This needs to be fixed both for the intrinsics and in the backend.

Differential Revision: http://reviews.llvm.org/D5787


Modified:
    llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
    llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
    llvm/trunk/test/Transforms/InstCombine/fast-math.ll

Modified: llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h?rev=219944&r1=219943&r2=219944&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h (original)
+++ llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h Thu Oct 16 13:48:17 2014
@@ -93,6 +93,7 @@ private:
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
+  Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
   Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
 
   // Integer Library Call Optimizations

Modified: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp?rev=219944&r1=219943&r2=219944&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp Thu Oct 16 13:48:17 2014
@@ -27,12 +27,14 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetLibraryInfo.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 static cl::opt<bool>
     ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(C
   return Ret;
 }
 
+// Simplify a sqrt libcall or llvm.sqrt intrinsic. If the function has
+// unsafe-fp-math="true" and the argument is a fast fmul, hoist a repeated
+// factor: sqrt(x * x * y) -> fabs(x) * sqrt(y).
+Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+  Value *Ret = nullptr;
+  if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
+      TLI->has(LibFunc::sqrtf)) {
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  }
+
+  // FIXME: For finer-grain optimization, we need intrinsics to have the same
+  // fast-math flag decorations that are applied to FP instructions. For now,
+  // we have to rely on the function-level unsafe-fp-math attribute to do this
+  // optimization because there's no other way to express that the sqrt can be
+  // reassociated. A missing attribute means the function did not opt in, so
+  // only proceed when the attribute is present and set to "true".
+  Function *F = CI->getParent()->getParent();
+  if (!F->hasFnAttribute("unsafe-fp-math") ||
+      F->getFnAttribute("unsafe-fp-math").getValueAsString() != "true")
+    return Ret;
+  Value *Op = CI->getArgOperand(0);
+  if (Instruction *I = dyn_cast<Instruction>(Op)) {
+    if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
+      // We're looking for a repeated factor in a multiplication tree,
+      // so we can do this fold: sqrt(x * x) -> fabs(x);
+      // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
+      Value *Op0 = I->getOperand(0);
+      Value *Op1 = I->getOperand(1);
+      Value *RepeatOp = nullptr;
+      Value *OtherOp = nullptr;
+      if (Op0 == Op1) {
+        // Simple match: the operands of the multiply are identical.
+        RepeatOp = Op0;
+      } else {
+        // Look for a more complicated pattern: one of the operands is itself
+        // a multiply, so search for a common factor in that multiply.
+        // Note: We don't bother looking any deeper than this first level or for
+        // variations of this pattern because instcombine's visitFMUL and/or the
+        // reassociation pass should give us this form. The inner multiply is
+        // reassociated too, so it must also allow unsafe algebra.
+        Value *OtherMul0, *OtherMul1;
+        Instruction *InnerMul = dyn_cast<Instruction>(Op0);
+        if (InnerMul && InnerMul->hasUnsafeAlgebra() &&
+            match(InnerMul, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1))) &&
+            OtherMul0 == OtherMul1) {
+          // Matched: sqrt((x * x) * z)
+          RepeatOp = OtherMul0;
+          OtherOp = Op1;
+        }
+      }
+      if (RepeatOp) {
+        // Fast math flags for any created instructions should match the sqrt
+        // and multiply.
+        // FIXME: We're not checking the sqrt because it doesn't have
+        // fast-math-flags (see earlier comment).
+        IRBuilder<>::FastMathFlagGuard Guard(B);
+        B.SetFastMathFlags(I->getFastMathFlags());
+        // If we found a repeated factor, hoist it out of the square root and
+        // replace it with the fabs of that factor.
+        Module *M = Callee->getParent();
+        Type *ArgType = Op->getType();
+        Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
+        Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
+        if (OtherOp) {
+          // If we found a non-repeated factor, we still need to get its square
+          // root. We then multiply that by the value that was simplified out
+          // of the square root calculation.
+          Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
+          Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
+          return B.CreateFMul(FabsCall, SqrtCall);
+        }
+        return FabsCall;
+      }
+    }
+  }
+  return Ret;
+}
+
 static bool isTrigLibCall(CallInst *CI);
 static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
                              bool UseFloat, Value *&Sin, Value *&Cos,
@@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(C
       return optimizeExp2(CI, Builder);
     case Intrinsic::fabs:
       return optimizeFabs(CI, Builder);
+    case Intrinsic::sqrt:
+      return optimizeSqrt(CI, Builder);
     default:
       return nullptr;
     }
@@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(C
     case LibFunc::fabs:
     case LibFunc::fabsl:
       return optimizeFabs(CI, Builder);
+    case LibFunc::sqrtf:
+    case LibFunc::sqrt:
+    case LibFunc::sqrtl:
+      return optimizeSqrt(CI, Builder);
     case LibFunc::ffs:
     case LibFunc::ffsl:
     case LibFunc::ffsll:
@@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(C
     case LibFunc::logb:
     case LibFunc::sin:
     case LibFunc::sinh:
-    case LibFunc::sqrt:
     case LibFunc::tan:
     case LibFunc::tanh:
       if (UnsafeFPShrink && hasFloatVersion(FuncName))

Modified: llvm/trunk/test/Transforms/InstCombine/fast-math.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/fast-math.ll?rev=219944&r1=219943&r2=219944&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/fast-math.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/fast-math.ll Thu Oct 16 13:48:17 2014
@@ -530,3 +530,173 @@ define float @fact_div6(float %x) {
 ; CHECK: fact_div6
 ; CHECK: %t3 = fsub fast float %t1, %t2
 }
+
+; =========================================================================
+;
+;   Test-cases for square root
+;
+; =========================================================================
+
+; A squared factor fed into a square root intrinsic should be hoisted out of
+; it as a fabs() value. The fold requires 'fast' on the fmul and, because
+; intrinsics don't currently carry IR-level fast-math flags, the function-level
+; "unsafe-fp-math"="true" attribute as well. If intrinsics gain fast-math
+; flags, these tests can drop the attribute and just specify 'fast' on the
+; sqrt.
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+
+declare double @llvm.sqrt.f64(double)
+
+define double @sqrt_intrinsic_arg_squared(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul)
+  ret double %sqrt
+; Expected: sqrt(x * x) becomes fabs(x), with no sqrt call left.
+; CHECK-LABEL: sqrt_intrinsic_arg_squared(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: ret double %fabs
+}
+
+; Check all 6 operand orderings of a 3-way multiplication tree in which
+; one factor (x) is repeated; each should fold to fabs(x) * sqrt(y).
+
+define double @sqrt_intrinsic_three_args1(double %x, double %y) #0 {
+  %mul = fmul fast double %y, %x
+  %mul2 = fmul fast double %mul, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt((y * x) * x) becomes fabs(x) * sqrt(y).
+; CHECK-LABEL: sqrt_intrinsic_three_args1(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args2(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %y
+  %mul2 = fmul fast double %mul, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt((x * y) * x) becomes fabs(x) * sqrt(y).
+; CHECK-LABEL: sqrt_intrinsic_three_args2(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args3(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %mul, %y
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt((x * x) * y), the directly matched form, -> fabs(x) * sqrt(y).
+; CHECK-LABEL: sqrt_intrinsic_three_args3(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args4(double %x, double %y) #0 {
+  %mul = fmul fast double %y, %x
+  %mul2 = fmul fast double %x, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt(x * (y * x)) becomes fabs(x) * sqrt(y).
+; CHECK-LABEL: sqrt_intrinsic_three_args4(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args5(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %y
+  %mul2 = fmul fast double %x, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt(x * (x * y)) becomes fabs(x) * sqrt(y).
+; CHECK-LABEL: sqrt_intrinsic_three_args5(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args6(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %y, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt(y * (x * x)) becomes fabs(x) * sqrt(y).
+; CHECK-LABEL: sqrt_intrinsic_three_args6(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_arg_4th(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %mul, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+; Expected: sqrt((x * x) * (x * x)) reduces to x * x; no fabs or sqrt remains.
+; CHECK-LABEL: sqrt_intrinsic_arg_4th(
+; CHECK-NEXT: %mul = fmul fast double %x, %x
+; CHECK-NEXT: ret double %mul
+}
+
+define double @sqrt_intrinsic_arg_5th(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %mul, %x
+  %mul3 = fmul fast double %mul2, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul3)
+  ret double %sqrt
+; Expected: sqrt(((x * x) * x) * (x * x)) becomes (x * x) * sqrt(x).
+; CHECK-LABEL: sqrt_intrinsic_arg_5th(
+; CHECK-NEXT: %mul = fmul fast double %x, %x
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
+; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+; Check that calls to the libm functions sqrtf, sqrt, and sqrtl get the same fold.
+
+declare float @sqrtf(float)
+declare double @sqrt(double)
+declare fp128 @sqrtl(fp128)
+
+define float @sqrt_call_squared_f32(float %x) #0 {
+  %mul = fmul fast float %x, %x
+  %sqrt = call float @sqrtf(float %mul)
+  ret float %sqrt
+; Expected: the sqrtf(x * x) libcall becomes llvm.fabs.f32(x).
+; CHECK-LABEL: sqrt_call_squared_f32(
+; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
+; CHECK-NEXT: ret float %fabs
+}
+
+define double @sqrt_call_squared_f64(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %sqrt = call double @sqrt(double %mul)
+  ret double %sqrt
+; Expected: the sqrt(x * x) libcall becomes llvm.fabs.f64(x).
+; CHECK-LABEL: sqrt_call_squared_f64(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: ret double %fabs
+}
+
+define fp128 @sqrt_call_squared_f128(fp128 %x) #0 {
+  %mul = fmul fast fp128 %x, %x
+  %sqrt = call fp128 @sqrtl(fp128 %mul)
+  ret fp128 %sqrt
+; Expected: the sqrtl(x * x) libcall becomes llvm.fabs.f128(x).
+; CHECK-LABEL: sqrt_call_squared_f128(
+; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
+; CHECK-NEXT: ret fp128 %fabs
+}
+





More information about the llvm-commits mailing list