[llvm] [SimplifyLibCalls] Merge sqrt into the power of exp (PR #79146)

Anton Sidorenko via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 06:51:38 PST 2024


https://github.com/asi-sc created https://github.com/llvm/llvm-project/pull/79146

Under fast-math flags it's possible to convert `sqrt(exp(X))` into `exp(X * 0.5)` (and likewise for `exp2` and `exp10`). I expect this transformation to always be profitable, since it replaces a `sqrt` call with a single multiplication. This is similar to an existing optimization in GCC.

>From d148d81877505ad24a53458e834a78f99ca32971 Mon Sep 17 00:00:00 2001
From: Anton Sidorenko <anton.sidorenko at syntacore.com>
Date: Tue, 9 Jan 2024 16:23:54 +0300
Subject: [PATCH 1/2] [SimplifyLibCalls] Precommit test for sqrt(exp(X)) ->
 exp(X * 0.5) transformation

---
 llvm/test/Transforms/InstCombine/sqrt.ll | 109 +++++++++++++++++++++++
 1 file changed, 109 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/sqrt.ll b/llvm/test/Transforms/InstCombine/sqrt.ll
index 004df3e30c72a1..e396e4dc4d94f3 100644
--- a/llvm/test/Transforms/InstCombine/sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/sqrt.ll
@@ -88,7 +88,116 @@ define float @sqrt_call_fabs_f32(float %x) {
   ret float %sqrt
 }
 
+define double @sqrt_exp(double %x) {
+; CHECK-LABEL: @sqrt_exp(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call fast double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_2(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @sqrt(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @exp(double %mul)
+  %res = call fast double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp2(double %x) {
+; CHECK-LABEL: @sqrt_exp2(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp2(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @sqrt(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @exp2(double %mul)
+  %res = call fast double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp10(double %x) {
+; CHECK-LABEL: @sqrt_exp10(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp10(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @sqrt(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @exp10(double %mul)
+  %res = call fast double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_nofast_1(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_1(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul double %x, 10.0
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call fast double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_nofast_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_2(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call double @llvm.exp.f64(double %mul)
+  %res = call fast double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_nofast_3(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_3(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_noconst(double %x, double %y) {
+; CHECK-LABEL: @sqrt_exp_noconst(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, %y
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
 declare i32 @foo(double)
 declare double @sqrt(double) readnone
 declare float @sqrtf(float)
 declare float @llvm.fabs.f32(float)
+declare double @llvm.exp.f64(double)
+declare double @llvm.sqrt.f64(double)
+declare double @exp(double)
+declare double @exp2(double)
+declare double @exp10(double)

>From 3f95c7b8dcc1fdfe7926cfdb53590dd2c96640ba Mon Sep 17 00:00:00 2001
From: Anton Sidorenko <anton.sidorenko at syntacore.com>
Date: Wed, 27 Dec 2023 20:12:34 +0300
Subject: [PATCH 2/2] [SimplifyLibCalls] Merge sqrt into the power of exp

Fold sqrt(exp(X)) -> exp(X * 0.5), and likewise for exp2 and exp10. This is similar to an existing optimization in GCC.
---
 .../llvm/Transforms/Utils/SimplifyLibCalls.h  |  1 +
 .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 67 +++++++++++++++++++
 llvm/test/Transforms/InstCombine/sqrt.ll      | 38 +++++------
 3 files changed, 86 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..1aad0b2988451c 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -201,6 +201,7 @@ class LibCallSimplifier {
   Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
+  Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
   Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d9..047a793f349e7e 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
   return Ret;
 }
 
+// sqrt(exp(X)) -> exp(X * 0.5), and likewise for exp2 and exp10, whether the
+// exp call is a libcall or an llvm.exp/llvm.exp2/llvm.exp10 intrinsic.
+// CI is the sqrt call. Both calls must be 'fast' and CI must be the exp
+// call's only user, because the exp call's operand is rewritten in place.
+// The exp call feeds CI directly, so its type already equals CI's operand
+// type; the exp flavour therefore need not be matched against the sqrt
+// flavour (half- and fp128-typed calls fold too). Returns the rewritten exp
+// call, or nullptr if the pattern does not match.
+Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
+  if (!CI->isFast())
+    return nullptr;
+
+  auto *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
+  if (!Arg || !Arg->isFast() || !Arg->hasOneUse())
+    return nullptr;
+
+  bool IsExpCall = false;
+  switch (Arg->getIntrinsicID()) {
+  case Intrinsic::exp:
+  case Intrinsic::exp2:
+  case Intrinsic::exp10:
+    IsExpCall = true;
+    break;
+  case Intrinsic::not_intrinsic: {
+    // A plain call qualifies only if TLI recognizes it as an exp libcall.
+    LibFunc ArgLb = NotLibFunc;
+    if (!TLI->getLibFunc(*Arg, ArgLb))
+      break;
+    switch (ArgLb) {
+    case LibFunc_exp:
+    case LibFunc_expf:
+    case LibFunc_expl:
+    case LibFunc_exp2:
+    case LibFunc_exp2f:
+    case LibFunc_exp2l:
+    case LibFunc_exp10:
+    case LibFunc_exp10f:
+    case LibFunc_exp10l:
+      IsExpCall = true;
+      break;
+    default:
+      break;
+    }
+    break;
+  }
+  default:
+    break;
+  }
+  if (!IsExpCall)
+    return nullptr;
+
+  // Emit X * 0.5 right before the exp call so it dominates that use; the
+  // multiply takes its fast-math flags from the sqrt call.
+  IRBuilderBase::InsertPointGuard Guard(B);
+  B.SetInsertPoint(Arg);
+  Value *ExpOperand = Arg->getOperand(0);
+  Value *Half = ConstantFP::get(ExpOperand->getType(), 0.5);
+  Value *FMul = B.CreateFMulFMF(ExpOperand, Half, CI, "merged.sqrt");
+
+  // CI is Arg's only user, so rewriting Arg in place is safe.
+  Arg->setOperand(0, FMul);
+  return Arg;
+}
+
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
@@ -2553,6 +2617,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   if (!CI->isFast())
     return Ret;
 
+  if (Value *Opt = mergeSqrtToExp(CI, B))
+    return Opt;
+
   Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0));
   if (!I || I->getOpcode() != Instruction::FMul || !I->isFast())
     return Ret;
diff --git a/llvm/test/Transforms/InstCombine/sqrt.ll b/llvm/test/Transforms/InstCombine/sqrt.ll
index e396e4dc4d94f3..b9cdbb6f6910c0 100644
--- a/llvm/test/Transforms/InstCombine/sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/sqrt.ll
@@ -90,10 +90,9 @@ define float @sqrt_call_fabs_f32(float %x) {
 
 define double @sqrt_exp(double %x) {
 ; CHECK-LABEL: @sqrt_exp(
-; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
-; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
-; CHECK-NEXT:    [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
-; CHECK-NEXT:    ret double [[RES]]
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
 ;
   %mul = fmul fast double %x, 10.0
   %e = call fast double @llvm.exp.f64(double %mul)
@@ -103,10 +102,9 @@ define double @sqrt_exp(double %x) {
 
 define double @sqrt_exp_2(double %x) {
 ; CHECK-LABEL: @sqrt_exp_2(
-; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
-; CHECK-NEXT:    [[E:%.*]] = call fast double @exp(double [[MUL]])
-; CHECK-NEXT:    [[RES:%.*]] = call fast double @sqrt(double [[E]])
-; CHECK-NEXT:    ret double [[RES]]
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
 ;
   %mul = fmul fast double %x, 10.0
   %e = call fast double @exp(double %mul)
@@ -116,10 +114,9 @@ define double @sqrt_exp_2(double %x) {
 
 define double @sqrt_exp2(double %x) {
 ; CHECK-LABEL: @sqrt_exp2(
-; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
-; CHECK-NEXT:    [[E:%.*]] = call fast double @exp2(double [[MUL]])
-; CHECK-NEXT:    [[RES:%.*]] = call fast double @sqrt(double [[E]])
-; CHECK-NEXT:    ret double [[RES]]
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp2(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
 ;
   %mul = fmul fast double %x, 10.0
   %e = call fast double @exp2(double %mul)
@@ -129,10 +126,9 @@ define double @sqrt_exp2(double %x) {
 
 define double @sqrt_exp10(double %x) {
 ; CHECK-LABEL: @sqrt_exp10(
-; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
-; CHECK-NEXT:    [[E:%.*]] = call fast double @exp10(double [[MUL]])
-; CHECK-NEXT:    [[RES:%.*]] = call fast double @sqrt(double [[E]])
-; CHECK-NEXT:    ret double [[RES]]
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp10(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
 ;
   %mul = fmul fast double %x, 10.0
   %e = call fast double @exp10(double %mul)
@@ -142,10 +138,9 @@ define double @sqrt_exp10(double %x) {
 
 define double @sqrt_exp_nofast_1(double %x) {
 ; CHECK-LABEL: @sqrt_exp_nofast_1(
-; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[X:%.*]], 1.000000e+01
-; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
-; CHECK-NEXT:    [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
-; CHECK-NEXT:    ret double [[RES]]
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
 ;
   %mul = fmul double %x, 10.0
   %e = call fast double @llvm.exp.f64(double %mul)
@@ -153,6 +148,7 @@ define double @sqrt_exp_nofast_1(double %x) {
   ret double %res
 }
 
+; Negative test: the exp call lacks fast-math flags.
 define double @sqrt_exp_nofast_2(double %x) {
 ; CHECK-LABEL: @sqrt_exp_nofast_2(
 ; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
@@ -166,6 +162,7 @@ define double @sqrt_exp_nofast_2(double %x) {
   ret double %res
 }
 
+; Negative test: the sqrt call lacks fast-math flags.
 define double @sqrt_exp_nofast_3(double %x) {
 ; CHECK-LABEL: @sqrt_exp_nofast_3(
 ; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
@@ -179,6 +176,7 @@ define double @sqrt_exp_nofast_3(double %x) {
   ret double %res
 }
 
+; Negative test: the sqrt call lacks fast-math flags (a non-constant multiplier alone would not block the fold).
 define double @sqrt_exp_noconst(double %x, double %y) {
 ; CHECK-LABEL: @sqrt_exp_noconst(
 ; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], [[Y:%.*]]



More information about the llvm-commits mailing list