[llvm] [InstCombine] optimize exp(exp(x)) / exp(x) with fast-math (PR #66177)

Zain Jaffal via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 13 00:12:26 PDT 2023


https://github.com/zjaffal created https://github.com/llvm/llvm-project/pull/66177:

This patch replaces exp(exp(x)) / exp(x) with exp(exp(x) - x) under fast-math, which avoids the fdiv instruction.

Closes #65608 

>From 72e81ef37bde27d7e78e2231326648995c3c0daf Mon Sep 17 00:00:00 2001
From: Zain Jaffal <z_jaffal at apple.com>
Date: Wed, 13 Sep 2023 10:05:25 +0300
Subject: [PATCH] [InstCombine] optimize exp(exp(x)) / exp(x) with fast-math

This patch replaces exp(exp(x)) / exp(x) with exp(exp(x) - x) under
fast-math, which avoids the fdiv instruction.
---
 .../InstCombine/InstCombineMulDivRem.cpp      | 12 ++++
 llvm/test/Transforms/InstCombine/fdiv-exp.ll  | 55 +++++++++++++++++++
 2 files changed, 67 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/fdiv-exp.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 86cacf979839126..692fcf01455835a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1748,6 +1748,18 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
 
   if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
     return Mul;
+  // exp(exp(X)) / exp(X) --> exp(exp(X) - X). Needs reassoc; exp(exp(X))
+  // must die with the fdiv, and exp(X) may feed only it and this fdiv.
+  Value *ExpX;
+  if (I.hasAllowReassoc() &&
+      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::exp>(m_Value(ExpX)))) &&
+      match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
+      match(ExpX, m_Intrinsic<Intrinsic::exp>(m_Specific(X))) &&
+      Op1->hasNUses(2)) {
+    Value *Sub = Builder.CreateFSubFMF(ExpX, X, &I);
+    Value *NewExp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, Sub, &I);
+    return replaceInstUsesWith(I, NewExp);
+  }
 
   // pow(X, Y) / X --> pow(X, Y-1)
   if (I.hasAllowReassoc() &&
diff --git a/llvm/test/Transforms/InstCombine/fdiv-exp.ll b/llvm/test/Transforms/InstCombine/fdiv-exp.ll
new file mode 100644
index 000000000000000..3ea281cd4bb3af5
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fdiv-exp.ll
@@ -0,0 +1,55 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define double @fdiv_exp(double %x) {
+; CHECK-LABEL: define double @fdiv_exp
+; CHECK-SAME: (double [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[X]])
+; CHECK-NEXT:    [[TMP0:%.*]] = fsub fast double [[EXP_X]], [[X]]
+; CHECK-NEXT:    [[DIV:%.*]] = call fast double @llvm.exp.f64(double [[TMP0]])
+; CHECK-NEXT:    ret double [[DIV]]
+;
+entry:
+  %exp_x = call fast double @llvm.exp.f64(double %x)
+  %exp_exp_x = call fast double @llvm.exp.f64(double %exp_x)
+  %div = fdiv fast double %exp_exp_x, %exp_x
+  ret double %div
+}
+; Negative test: exp(x) has a use besides exp(exp(x)) and the fdiv, so no fold.
+define double @fdiv_exp_multiple_uses(double %x) {
+; CHECK-LABEL: define double @fdiv_exp_multiple_uses
+; CHECK-SAME: (double [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[X]])
+; CHECK-NEXT:    [[EXP_EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[EXP_X]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT:    call void @use(double [[EXP_X]])
+; CHECK-NEXT:    ret double [[DIV]]
+;
+entry:
+  %exp_x = call fast double @llvm.exp.f64(double %x)
+  %exp_exp_x = call fast double @llvm.exp.f64(double %exp_x)
+  %div = fdiv fast double %exp_exp_x, %exp_x
+  call void @use(double %exp_x)
+  ret double %div
+}
+; Commuted operands: exp(x) / exp(exp(x)) is expected to become exp(x - exp(x)).
+define double @fdiv_exp_swapped(double %x) {
+; CHECK-LABEL: define double @fdiv_exp_swapped
+; CHECK-SAME: (double [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[X]])
+; CHECK-NEXT:    [[TMP0:%.*]] = fsub fast double [[X]], [[EXP_X]]
+; CHECK-NEXT:    [[DIV:%.*]] = call fast double @llvm.exp.f64(double [[TMP0]])
+; CHECK-NEXT:    ret double [[DIV]]
+;
+entry:
+  %exp_x = call fast double @llvm.exp.f64(double %x)
+  %exp_exp_x = call fast double @llvm.exp.f64(double %exp_x)
+  %div = fdiv fast double %exp_x, %exp_exp_x
+  ret double %div
+}
+
+declare double @llvm.exp.f64(double)
+declare void @use(double)



More information about the llvm-commits mailing list