[llvm] r325960 - [InstCombine] allow fmul-sqrt folds with less than full -ffast-math

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 23 13:16:12 PST 2018


Author: spatel
Date: Fri Feb 23 13:16:12 2018
New Revision: 325960

URL: http://llvm.org/viewvc/llvm-project?rev=325960&view=rev
Log:
[InstCombine] allow fmul-sqrt folds with less than full -ffast-math

Also, add a Builder method for intrinsics to reduce code duplication for clients.

Modified:
    llvm/trunk/include/llvm/IR/IRBuilder.h
    llvm/trunk/lib/IR/IRBuilder.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll

Modified: llvm/trunk/include/llvm/IR/IRBuilder.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IRBuilder.h?rev=325960&r1=325959&r2=325960&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IRBuilder.h (original)
+++ llvm/trunk/include/llvm/IR/IRBuilder.h Fri Feb 23 13:16:12 2018
@@ -669,6 +669,13 @@ public:
                                   Value *LHS, Value *RHS,
                                   const Twine &Name = "");
 
+  /// Create a call to intrinsic \p ID with 1 or more operands, assuming the
+  /// intrinsic and all operands share the type of the first operand. If
+  /// \p FMFSource is provided, copy its fast-math-flags to the new call.
+  CallInst *CreateIntrinsic(Intrinsic::ID ID, ArrayRef<Value *> Args,
+                            Instruction *FMFSource = nullptr,
+                            const Twine &Name = "");
+
   /// Create call to the minnum intrinsic.
   CallInst *CreateMinNum(Value *LHS, Value *RHS, const Twine &Name = "") {
     return CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS, Name);

Modified: llvm/trunk/lib/IR/IRBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/IR/IRBuilder.cpp?rev=325960&r1=325959&r2=325960&view=diff
==============================================================================
--- llvm/trunk/lib/IR/IRBuilder.cpp (original)
+++ llvm/trunk/lib/IR/IRBuilder.cpp Fri Feb 23 13:16:12 2018
@@ -59,8 +59,11 @@ Value *IRBuilderBase::getCastedInt8PtrVa
 
 static CallInst *createCallHelper(Value *Callee, ArrayRef<Value *> Ops,
                                   IRBuilderBase *Builder,
-                                  const Twine& Name="") {
+                                  const Twine &Name = "",
+                                  Instruction *FMFSource = nullptr) {
   CallInst *CI = CallInst::Create(Callee, Ops, Name);
+  if (FMFSource)
+    CI->copyFastMathFlags(FMFSource);
   Builder->GetInsertBlock()->getInstList().insert(Builder->GetInsertPoint(),CI);
   Builder->SetInstDebugLocation(CI);
   return CI;  
@@ -646,7 +649,18 @@ CallInst *IRBuilderBase::CreateGCRelocat
 CallInst *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID,
                                                Value *LHS, Value *RHS,
                                                const Twine &Name) {
-  Module *M = BB->getParent()->getParent();
-  Function *Fn =  Intrinsic::getDeclaration(M, ID, { LHS->getType() });
+  Module *M = BB->getModule();
+  Function *Fn = Intrinsic::getDeclaration(M, ID, { LHS->getType() });
   return createCallHelper(Fn, { LHS, RHS }, this, Name);
 }
+
+CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID,
+                                         ArrayRef<Value *> Args,
+                                         Instruction *FMFSource,
+                                         const Twine &Name) {
+  assert(!Args.empty() && "Expected at least one argument to intrinsic");
+  Module *M = BB->getModule();
+  Function *Fn = Intrinsic::getDeclaration(M, ID, { Args.front()->getType() });
+  return createCallHelper(Fn, Args, this, Name, FMFSource);
+}
+

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp?rev=325960&r1=325959&r2=325960&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp Fri Feb 23 13:16:12 2018
@@ -662,24 +662,17 @@ Instruction *InstCombiner::visitFMul(Bin
     }
   }
 
-  // sqrt(a) * sqrt(b) -> sqrt(a * b)
-  if (AllowReassociate && Op0->hasOneUse() && Op1->hasOneUse()) {
-    Value *Opnd0 = nullptr;
-    Value *Opnd1 = nullptr;
-    if (match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(Opnd0))) &&
-        match(Op1, m_Intrinsic<Intrinsic::sqrt>(m_Value(Opnd1)))) {
-      BuilderTy::FastMathFlagGuard Guard(Builder);
-      Builder.setFastMathFlags(I.getFastMathFlags());
-      Value *FMulVal = Builder.CreateFMul(Opnd0, Opnd1);
-      Value *Sqrt = Intrinsic::getDeclaration(I.getModule(), 
-                                              Intrinsic::sqrt, I.getType());
-      Value *SqrtCall = Builder.CreateCall(Sqrt, FMulVal);
-      return replaceInstUsesWith(I, SqrtCall);
-    }
+  // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
+  Value *X, *Y;
+  if (I.hasAllowReassoc() &&
+      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) &&
+      match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) {
+    Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+    Value *Sqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, { XY }, &I);
+    return replaceInstUsesWith(I, Sqrt);
   }
 
   // -X * -Y --> X * Y
-  Value *X, *Y;
   if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
     return BinaryOperator::CreateFMulFMF(X, Y, &I);
 

Modified: llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll?rev=325960&r1=325959&r2=325960&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/fmul-sqrt.ll Fri Feb 23 13:16:12 2018
@@ -5,6 +5,7 @@ declare double @llvm.sqrt.f64(double) no
 declare void @use(double)
 
 ; sqrt(a) * sqrt(b) no math flags
+
 define double @sqrt_a_sqrt_b(double %a, double %b) {
 ; CHECK-LABEL: @sqrt_a_sqrt_b(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call double @llvm.sqrt.f64(double [[A:%.*]])
@@ -19,6 +20,7 @@ define double @sqrt_a_sqrt_b(double %a,
 }
 
 ; sqrt(a) * sqrt(b) fast-math, multiple uses
+
 define double @sqrt_a_sqrt_b_multiple_uses(double %a, double %b) {
 ; CHECK-LABEL: @sqrt_a_sqrt_b_multiple_uses(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[A:%.*]])
@@ -35,33 +37,37 @@ define double @sqrt_a_sqrt_b_multiple_us
 }
 
 ; sqrt(a) * sqrt(b) => sqrt(a*b) with fast-math
-define double @sqrt_a_sqrt_b_fast(double %a, double %b) {
-; CHECK-LABEL: @sqrt_a_sqrt_b_fast(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP1]])
+
+define double @sqrt_a_sqrt_b_reassoc(double %a, double %b) {
+; CHECK-LABEL: @sqrt_a_sqrt_b_reassoc(
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc double [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call reassoc double @llvm.sqrt.f64(double [[TMP1]])
 ; CHECK-NEXT:    ret double [[TMP2]]
 ;
-  %1 = call fast double @llvm.sqrt.f64(double %a)
-  %2 = call fast double @llvm.sqrt.f64(double %b)
-  %mul = fmul fast double %1, %2
+  %1 = call double @llvm.sqrt.f64(double %a)
+  %2 = call double @llvm.sqrt.f64(double %b)
+  %mul = fmul reassoc double %1, %2
   ret double %mul
 }
 
-; sqrt(a) * sqrt(b) * sqrt(c) * sqrt(d) => sqrt(a*b*c+d) with fast-math
-define double @sqrt_a_sqrt_b_sqrt_c_sqrt_d_fast(double %a, double %b, double %c, double %d) {
-; CHECK-LABEL: @sqrt_a_sqrt_b_sqrt_c_sqrt_d_fast(
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[C:%.*]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[D:%.*]]
-; CHECK-NEXT:    [[TMP4:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP3]])
+; sqrt(a) * sqrt(b) * sqrt(c) * sqrt(d) => sqrt(a*b*c*d) with 'reassoc'.
+; Only 'reassoc' on the fmuls is required; the other FMF check flag propagation.
+
+define double @sqrt_a_sqrt_b_sqrt_c_sqrt_d_reassoc(double %a, double %b, double %c, double %d) {
+; CHECK-LABEL: @sqrt_a_sqrt_b_sqrt_c_sqrt_d_reassoc(
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc arcp double [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan double [[TMP1]], [[C:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul reassoc ninf double [[TMP2]], [[D:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call reassoc ninf double @llvm.sqrt.f64(double [[TMP3]])
 ; CHECK-NEXT:    ret double [[TMP4]]
 ;
-  %1 = call fast double @llvm.sqrt.f64(double %a)
-  %2 = call fast double @llvm.sqrt.f64(double %b)
-  %mul = fmul fast double %1, %2
-  %3 = call fast double @llvm.sqrt.f64(double %c)
-  %mul1 = fmul fast double %mul, %3
-  %4 = call fast double @llvm.sqrt.f64(double %d)
-  %mul2 = fmul fast double %mul1, %4
+  %1 = call double @llvm.sqrt.f64(double %a)
+  %2 = call double @llvm.sqrt.f64(double %b)
+  %3 = call double @llvm.sqrt.f64(double %c)
+  %4 = call double @llvm.sqrt.f64(double %d)
+  %mul = fmul reassoc arcp double %1, %2
+  %mul1 = fmul reassoc nnan double %mul, %3
+  %mul2 = fmul reassoc ninf double %mul1, %4
   ret double %mul2
 }
+




More information about the llvm-commits mailing list