[llvm] Zjaffal/fix fold fdiv sqrt (PR #81970)

Zain Jaffal via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 15 22:16:31 PST 2024


https://github.com/zjaffal created https://github.com/llvm/llvm-project/pull/81970

This patch fixes the issues introduced in https://github.com/llvm/llvm-project/commit/bb5c3899d1936ebdf7ebf5ca4347ee2e057bee7f. 

I moved the check that the `sqrt` operand is an `fdiv` instruction ahead of the fast-math-flag checks, which resolves the crash in the following reproducer:

```
float a, b;
double sqrt();
void c() { b = a / sqrt(a); }
```

>From f17218d79a28c2c9e10c6dd1c57b1a7f7472ddc7 Mon Sep 17 00:00:00 2001
From: Zain Jaffal <zain at jjaffal.com>
Date: Thu, 15 Feb 2024 23:08:04 +0000
Subject: [PATCH 1/2] Revert "Revert "[InstCombine] Optimise x / sqrt(y / z)
 with fast-math pattern. (#76737)""

This reverts commit f022aaf4e722eae9d0feaf7715a5d8960f4d017b.
---
 .../InstCombine/InstCombineMulDivRem.cpp      | 30 +++++++++++++++++++
 llvm/test/Transforms/InstCombine/fdiv-sqrt.ll | 18 +++++------
 2 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 0bd4b6d1a835af..82530b3df5d32e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1706,6 +1706,33 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
   return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
 }
 
+/// Convert div to mul if we have an sqrt divisor iff sqrt's operand is a fdiv
+/// instruction.
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
+                                        InstCombiner::BuilderTy &Builder) {
+  // X / sqrt(Y / Z) -->  X * sqrt(Z / Y)
+  if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
+    return nullptr;
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  auto *II = dyn_cast<IntrinsicInst>(Op1);
+  if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
+      !II->hasAllowReassoc() || !II->hasAllowReciprocal())
+    return nullptr;
+
+  Value *Y, *Z;
+  auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
+  if (!DivOp || !DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() ||
+      !DivOp->hasOneUse())
+    return nullptr;
+  if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
+    Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+    Value *NewSqrt =
+        Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II);
+    return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   Module *M = I.getModule();
 
@@ -1813,6 +1840,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
     return Mul;
 
+  if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
+    return Mul;
+
   // pow(X, Y) / X --> pow(X, Y-1)
   if (I.hasAllowReassoc() &&
       match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
index 346271be7da761..361837e90ed46d 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -6,9 +6,9 @@ declare double @llvm.sqrt.f64(double)
 define double @sqrt_div_fast(double %x, double %y, double %z) {
 ; CHECK-LABEL: @sqrt_div_fast(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT:    [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[DIV1]]
 ;
 entry:
@@ -36,9 +36,9 @@ entry:
 define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) {
 ; CHECK-LABEL: @sqrt_div_reassoc_arcp(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT:    [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[DIV1]]
 ;
 entry:
@@ -96,9 +96,9 @@ entry:
 define double @sqrt_div_arcp_missing(double %x, double %y, double %z) {
 ; CHECK-LABEL: @sqrt_div_arcp_missing(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT:    [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv reassoc double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[DIV1]]
 ;
 entry:

>From 22a64c438c209bd5f919466d9cc24dcb3b1ae044 Mon Sep 17 00:00:00 2001
From: Zain Jaffal <zain at jjaffal.com>
Date: Fri, 16 Feb 2024 06:10:40 +0000
Subject: [PATCH 2/2] [InstCombine] check if operand is div in fold
 FDivSqrtDivisor

This change resolves the crash introduced in bb5c389.
---
 .../InstCombine/InstCombineMulDivRem.cpp        | 17 +++++++++--------
 llvm/test/Transforms/InstCombine/fdiv-sqrt.ll   | 17 +++++++++++++++++
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 82530b3df5d32e..912d9ac404052a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1721,16 +1721,17 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
 
   Value *Y, *Z;
   auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
-  if (!DivOp || !DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() ||
+  if (!DivOp)
+    return nullptr;
+  if (!match(DivOp, m_FDiv(m_Value(Y), m_Value(Z))))
+    return nullptr;
+  if (!DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() ||
       !DivOp->hasOneUse())
     return nullptr;
-  if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
-    Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
-    Value *NewSqrt =
-        Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II);
-    return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
-  }
-  return nullptr;
+  Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+  Value *NewSqrt =
+      Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II);
+  return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
 }
 
 Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
index 361837e90ed46d..58cc7c297e90a0 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -173,3 +173,20 @@ entry:
   ret double %div1
 }
 
+; Function Attrs: nounwind ssp uwtable(sync)
+define float @sqrt_non_div_operator(float %a) {
+; CHECK-LABEL: @sqrt_non_div_operator(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CONV:%.*]] = fpext float [[A:%.*]] to double
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[CONV]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[CONV]], [[SQRT]]
+; CHECK-NEXT:    [[CONV2:%.*]] = fptrunc double [[DIV]] to float
+; CHECK-NEXT:    ret float [[CONV2]]
+;
+entry:
+  %conv = fpext float %a to double
+  %sqrt = call fast double @llvm.sqrt.f64(double %conv)
+  %div = fdiv fast double %conv, %sqrt
+  %conv2 = fptrunc double %div to float
+  ret float %conv2
+}



More information about the llvm-commits mailing list