[llvm] r219944 - fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)
Sanjay Patel
spatel at rotateright.com
Thu Oct 16 11:48:17 PDT 2014
Author: spatel
Date: Thu Oct 16 13:48:17 2014
New Revision: 219944
URL: http://llvm.org/viewvc/llvm-project?rev=219944&view=rev
Log:
fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)
If a square root call has an FP multiplication argument that can be reassociated,
then we can hoist a repeated factor out of the square root call and into a fabs().
In the simplest case, this:
y = sqrt(x * x);
becomes this:
y = fabs(x);
This patch relies on an earlier optimization in instcombine or reassociate to put the
multiplication tree into a canonical form, so we don't have to search over
every permutation of the multiplication tree.
Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to
use function-level attributes to do this optimization. This needs to be fixed
for both the intrinsics and in the backend.
Differential Revision: http://reviews.llvm.org/D5787
Modified:
llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
llvm/trunk/test/Transforms/InstCombine/fast-math.ll
Modified: llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h?rev=219944&r1=219943&r2=219944&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h (original)
+++ llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h Thu Oct 16 13:48:17 2014
@@ -93,6 +93,7 @@ private:
Value *optimizePow(CallInst *CI, IRBuilder<> &B);
Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
+ Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
// Integer Library Call Optimizations
Modified: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp?rev=219944&r1=219943&r2=219944&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp Thu Oct 16 13:48:17 2014
@@ -27,12 +27,14 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetLibraryInfo.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
using namespace llvm;
+using namespace PatternMatch;
static cl::opt<bool>
ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(C
return Ret;
}
+/// Simplify a sqrt/sqrtf/sqrtl libcall or llvm.sqrt intrinsic call \p CI.
+/// \returns the fold result, else what the UnsafeFPShrink step gave (or null).
+Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+
+  Value *Ret = nullptr;
+  if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
+      TLI->has(LibFunc::sqrtf)) {
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  }
+
+  // FIXME: For finer-grain optimization, we need intrinsics to have the same
+  // fast-math flag decorations that are applied to FP instructions. For now,
+  // we have to rely on the function-level unsafe-fp-math attribute to do this
+  // optimization because there's no other way to express that the sqrt can be
+  // reassociated.
+  // A function without the attribute never asked for unsafe FP math, so a
+  // missing attribute must block the fold just like any value but "true".
+  Function *F = CI->getParent()->getParent();
+  if (!F->hasFnAttribute("unsafe-fp-math") ||
+      F->getFnAttribute("unsafe-fp-math").getValueAsString() != "true")
+    return Ret;
+  Value *Op = CI->getArgOperand(0);
+  if (Instruction *I = dyn_cast<Instruction>(Op)) {
+    if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
+      // We're looking for a repeated factor in a multiplication tree,
+      // so we can do this fold: sqrt(x * x) -> fabs(x);
+      // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
+      Value *Op0 = I->getOperand(0);
+      Value *Op1 = I->getOperand(1);
+      Value *RepeatOp = nullptr;
+      Value *OtherOp = nullptr;
+      if (Op0 == Op1) {
+        // Simple match: the operands of the multiply are identical.
+        RepeatOp = Op0;
+      } else {
+        // Look for a more complicated pattern: one of the operands is itself
+        // a multiply, so search for a common factor in that multiply.
+        // Note: We don't bother looking any deeper than this first level or for
+        // variations of this pattern because instcombine's visitFMUL and/or the
+        // reassociation pass should give us this form.
+        Value *OtherMul0, *OtherMul1;
+        if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
+          // Pattern: sqrt((x * y) * z)
+          if (OtherMul0 == OtherMul1) {
+            // Matched: sqrt((x * x) * z)
+            RepeatOp = OtherMul0;
+            OtherOp = Op1;
+          }
+        }
+      }
+      if (RepeatOp) {
+        // Fast math flags for any created instructions should match the sqrt
+        // and multiply.
+        // FIXME: We're not checking the sqrt because it doesn't have
+        // fast-math-flags (see earlier comment).
+        IRBuilder<>::FastMathFlagGuard Guard(B);
+        B.SetFastMathFlags(I->getFastMathFlags());
+        // If we found a repeated factor, hoist it out of the square root and
+        // replace it with the fabs of that factor.
+        Module *M = Callee->getParent();
+        Type *ArgType = Op->getType();
+        Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
+        Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
+        if (OtherOp) {
+          // If we found a non-repeated factor, we still need to get its square
+          // root. We then multiply that by the value that was simplified out
+          // of the square root calculation.
+          Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
+          Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
+          return B.CreateFMul(FabsCall, SqrtCall);
+        }
+        return FabsCall;
+      }
+    }
+  }
+  return Ret;
+}
+
static bool isTrigLibCall(CallInst *CI);
static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
bool UseFloat, Value *&Sin, Value *&Cos,
@@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(C
return optimizeExp2(CI, Builder);
case Intrinsic::fabs:
return optimizeFabs(CI, Builder);
+ case Intrinsic::sqrt:
+ return optimizeSqrt(CI, Builder);
default:
return nullptr;
}
@@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(C
case LibFunc::fabs:
case LibFunc::fabsl:
return optimizeFabs(CI, Builder);
+ case LibFunc::sqrtf:
+ case LibFunc::sqrt:
+ case LibFunc::sqrtl:
+ return optimizeSqrt(CI, Builder);
case LibFunc::ffs:
case LibFunc::ffsl:
case LibFunc::ffsll:
@@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(C
case LibFunc::logb:
case LibFunc::sin:
case LibFunc::sinh:
- case LibFunc::sqrt:
case LibFunc::tan:
case LibFunc::tanh:
if (UnsafeFPShrink && hasFloatVersion(FuncName))
Modified: llvm/trunk/test/Transforms/InstCombine/fast-math.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/fast-math.ll?rev=219944&r1=219943&r2=219944&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/fast-math.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/fast-math.ll Thu Oct 16 13:48:17 2014
@@ -530,3 +530,173 @@ define float @fact_div6(float %x) {
; CHECK: fact_div6
; CHECK: %t3 = fsub fast float %t1, %t2
}
+
+; =========================================================================
+;
+; Test-cases for square root
+;
+; =========================================================================
+
+; A squared factor fed into a square root intrinsic should be hoisted out
+; as a fabs() value.
+; We have to rely on a function-level attribute to enable this optimization
+; because intrinsics don't currently have access to IR-level fast-math
+; flags. If that changes, we can relax the requirement on all of these
+; tests to just specify 'fast' on the sqrt.
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+
+declare double @llvm.sqrt.f64(double)
+
+define double @sqrt_intrinsic_arg_squared(double %x) #0 {
+ %mul = fmul fast double %x, %x
+ %sqrt = call double @llvm.sqrt.f64(double %mul)
+ ret double %sqrt
+; Simplest case, sqrt(x * x): only a fabs call on %x is expected to remain.
+; CHECK-LABEL: sqrt_intrinsic_arg_squared(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: ret double %fabs
+}
+
+; Check all 6 combinations of a 3-way multiplication tree where
+; one factor is repeated.
+
+define double @sqrt_intrinsic_three_args1(double %x, double %y) #0 {
+ %mul = fmul fast double %y, %x
+ %mul2 = fmul fast double %mul, %x
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; Tree shape (y * x) * x: expect fabs of %x multiplied by sqrt of %y.
+; CHECK-LABEL: sqrt_intrinsic_three_args1(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args2(double %x, double %y) #0 {
+ %mul = fmul fast double %x, %y
+ %mul2 = fmul fast double %mul, %x
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; Tree shape (x * y) * x: expect fabs of %x multiplied by sqrt of %y.
+; CHECK-LABEL: sqrt_intrinsic_three_args2(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args3(double %x, double %y) #0 {
+ %mul = fmul fast double %x, %x
+ %mul2 = fmul fast double %mul, %y
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; Tree shape (x * x) * y: expect fabs of %x multiplied by sqrt of %y.
+; CHECK-LABEL: sqrt_intrinsic_three_args3(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args4(double %x, double %y) #0 {
+ %mul = fmul fast double %y, %x
+ %mul2 = fmul fast double %x, %mul
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; Tree shape x * (y * x): expect fabs of %x multiplied by sqrt of %y.
+; CHECK-LABEL: sqrt_intrinsic_three_args4(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args5(double %x, double %y) #0 {
+ %mul = fmul fast double %x, %y
+ %mul2 = fmul fast double %x, %mul
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; Tree shape x * (x * y): expect fabs of %x multiplied by sqrt of %y.
+; CHECK-LABEL: sqrt_intrinsic_three_args5(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args6(double %x, double %y) #0 {
+ %mul = fmul fast double %x, %x
+ %mul2 = fmul fast double %y, %mul
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; Tree shape y * (x * x): expect fabs of %x multiplied by sqrt of %y.
+; CHECK-LABEL: sqrt_intrinsic_three_args6(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_arg_4th(double %x) #0 {
+ %mul = fmul fast double %x, %x
+ %mul2 = fmul fast double %mul, %mul
+ %sqrt = call double @llvm.sqrt.f64(double %mul2)
+ ret double %sqrt
+; sqrt((x * x) * (x * x)): expect x * x returned, with no sqrt or fabs call left.
+; CHECK-LABEL: sqrt_intrinsic_arg_4th(
+; CHECK-NEXT: %mul = fmul fast double %x, %x
+; CHECK-NEXT: ret double %mul
+}
+
+define double @sqrt_intrinsic_arg_5th(double %x) #0 {
+ %mul = fmul fast double %x, %x
+ %mul2 = fmul fast double %mul, %x
+ %mul3 = fmul fast double %mul2, %mul
+ %sqrt = call double @llvm.sqrt.f64(double %mul3)
+ ret double %sqrt
+; sqrt of x to the 5th power: expect (x * x) multiplied by sqrt of %x.
+; CHECK-LABEL: sqrt_intrinsic_arg_5th(
+; CHECK-NEXT: %mul = fmul fast double %x, %x
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
+; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+; Check that square root calls have the same behavior.
+
+declare float @sqrtf(float)
+declare double @sqrt(double)
+declare fp128 @sqrtl(fp128)
+
+define float @sqrt_call_squared_f32(float %x) #0 {
+ %mul = fmul fast float %x, %x
+ %sqrt = call float @sqrtf(float %mul)
+ ret float %sqrt
+; Library call sqrtf(x * x): expect the float fabs intrinsic on %x instead.
+; CHECK-LABEL: sqrt_call_squared_f32(
+; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
+; CHECK-NEXT: ret float %fabs
+}
+
+define double @sqrt_call_squared_f64(double %x) #0 {
+ %mul = fmul fast double %x, %x
+ %sqrt = call double @sqrt(double %mul)
+ ret double %sqrt
+; Library call sqrt(x * x): expect the double fabs intrinsic on %x instead.
+; CHECK-LABEL: sqrt_call_squared_f64(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: ret double %fabs
+}
+
+define fp128 @sqrt_call_squared_f128(fp128 %x) #0 {
+ %mul = fmul fast fp128 %x, %x
+ %sqrt = call fp128 @sqrtl(fp128 %mul)
+ ret fp128 %sqrt
+; Library call sqrtl(x * x): expect the fp128 fabs intrinsic on %x instead.
+; CHECK-LABEL: sqrt_call_squared_f128(
+; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
+; CHECK-NEXT: ret fp128 %fabs
+}
+
More information about the llvm-commits
mailing list