[llvm] [InstCombine] Optimise x / sqrt(y / z) with fast-math pattern. (PR #76737)
Zain Jaffal via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 8 05:42:32 PST 2024
https://github.com/zjaffal updated https://github.com/llvm/llvm-project/pull/76737
>From f106f580c23702c1c0543cdcafd8fe46bdf3cc96 Mon Sep 17 00:00:00 2001
From: Zain Jaffal <zain at jjaffal.com>
Date: Tue, 2 Jan 2024 17:14:21 +0000
Subject: [PATCH 1/4] [InstCombine] Optimise x / sqrt(y / z) with fast-math
pattern.
Replace the pattern with the equivalent
x * sqrt(z / y), turning the outer fdiv into an fmul.
---
.../InstCombine/InstCombineMulDivRem.cpp | 28 +++++++++++++++++
llvm/test/Transforms/InstCombine/fdiv-sqrt.ll | 30 +++++++++----------
2 files changed, 43 insertions(+), 15 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f9cee9dfcfadae..89592929f3c896 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1709,6 +1709,31 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
}
+/// Convert an fdiv to an fmul when the divisor is a sqrt whose operand is
+/// itself an fdiv instruction.
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ // X / sqrt(Y / Z) --> X * sqrt(Z / Y)
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ auto *II = dyn_cast<IntrinsicInst>(Op1);
+ if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
+ !I.hasAllowReassoc() || !I.hasAllowReciprocal())
+ return nullptr;
+
+ Value *Y, *Z;
+ auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
+ if (!DivOp || !DivOp->hasOneUse() || !DivOp->hasAllowReassoc() ||
+ !I.hasAllowReciprocal())
+ return nullptr;
+ if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
+ Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+ Value *NewSqrt = Builder.CreateIntrinsic(II->getIntrinsicID(),
+ II->getType(), {SwapDiv}, II);
+ return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
+ }
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
Module *M = I.getModule();
@@ -1816,6 +1841,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
return Mul;
+ if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
+ return Mul;
+
// pow(X, Y) / X --> pow(X, Y-1)
if (I.hasAllowReassoc() &&
match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
index 346271be7da761..0eafdfea1c1519 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -6,9 +6,9 @@ declare double @llvm.sqrt.f64(double)
define double @sqrt_div_fast(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_fast(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -36,9 +36,9 @@ entry:
define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_reassoc_arcp(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -66,9 +66,9 @@ entry:
define double @sqrt_div_reassoc_missing2(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_reassoc_missing2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -96,9 +96,9 @@ entry:
define double @sqrt_div_arcp_missing(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_arcp_missing(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -111,9 +111,9 @@ entry:
define double @sqrt_div_arcp_missing2(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_arcp_missing2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call reassoc double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call reassoc double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
>From 303ce4ccebf27c37763bef3e8f082c833cafe8e7 Mon Sep 17 00:00:00 2001
From: Zain Jaffal <zain at jjaffal.com>
Date: Tue, 9 Jan 2024 11:04:11 +0000
Subject: [PATCH 2/4] [InstCombine] Make sure all instructions have `arcp` and
`reassoc` flags
---
.../Transforms/InstCombine/InstCombineMulDivRem.cpp | 8 +++++---
llvm/test/Transforms/InstCombine/fdiv-sqrt.ll | 12 ++++++------
2 files changed, 11 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 89592929f3c896..3d761d37d508ef 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1714,10 +1714,12 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
// X / sqrt(Y / Z) --> X * sqrt(Z / Y)
+ if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
+ return nullptr;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
auto *II = dyn_cast<IntrinsicInst>(Op1);
if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
- !I.hasAllowReassoc() || !I.hasAllowReciprocal())
+ !II->hasAllowReassoc() || !II->hasAllowReciprocal())
return nullptr;
Value *Y, *Z;
@@ -1727,8 +1729,8 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
return nullptr;
if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
- Value *NewSqrt = Builder.CreateIntrinsic(II->getIntrinsicID(),
- II->getType(), {SwapDiv}, II);
+ Value *NewSqrt =
+ Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II);
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
}
return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
index 0eafdfea1c1519..361837e90ed46d 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -66,9 +66,9 @@ entry:
define double @sqrt_div_reassoc_missing2(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_reassoc_missing2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[TMP1:%.*]] = call arcp double @llvm.sqrt.f64(double [[TMP0]])
-; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
+; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT: [[SQRT:%.*]] = call arcp double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -111,9 +111,9 @@ entry:
define double @sqrt_div_arcp_missing2(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_arcp_missing2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[TMP1:%.*]] = call reassoc double @llvm.sqrt.f64(double [[TMP0]])
-; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
+; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT: [[SQRT:%.*]] = call reassoc double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
>From 40ab48407db4b8a4b624eb6033483e08d29c2c36 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <arsenm2 at gmail.com>
Date: Thu, 8 Feb 2024 16:51:03 +0530
Subject: [PATCH 3/4] Reorder checks
Co-authored-by: Zain Jaffal <zain at jjaffal.com>
---
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 3d761d37d508ef..e07c96764ce90a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1724,7 +1724,7 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
Value *Y, *Z;
auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
- if (!DivOp || !DivOp->hasOneUse() || !DivOp->hasAllowReassoc() ||
+ if (!DivOp || !DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() || !DivOp->hasOneUse())
!I.hasAllowReciprocal())
return nullptr;
if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
>From ab33697d651bfd33129c3228ce528e5321907175 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <arsenm2 at gmail.com>
Date: Thu, 8 Feb 2024 17:00:13 +0530
Subject: [PATCH 4/4] Fix duplicated condition and formatting
---
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index e07c96764ce90a..c1c16ea84f732d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1724,8 +1724,8 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
Value *Y, *Z;
auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
- if (!DivOp || !DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() || !DivOp->hasOneUse())
- !I.hasAllowReciprocal())
+ if (!DivOp || !DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() ||
+ !DivOp->hasOneUse())
return nullptr;
if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
More information about the llvm-commits
mailing list