[llvm] [InstCombine] Optimise x / sqrt(y / z) with fast-math pattern. (PR #76737)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 2 09:16:33 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Zain Jaffal (zjaffal)
<details>
<summary>Changes</summary>
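Replace the pattern `x / sqrt(y / z)` with `x * sqrt(z / y)` when the
fast-math flags (reassoc and arcp) allow it, so the outer division becomes a multiplication.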
---
Full diff: https://github.com/llvm/llvm-project/pull/76737.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+28)
- (added) llvm/test/Transforms/InstCombine/fdiv-sqrt.ll (+85)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f0ea3d9fcad5df..172ce18d003aa8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1701,6 +1701,31 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
}
+/// Fold X / sqrt(Y / Z) --> X * sqrt(Z / Y), replacing the outer fdiv with an
+/// fmul. The fdivs and the sqrt must all allow reassociation and reciprocals.
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
+                                        InstCombiner::BuilderTy &Builder) {
+  if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
+    return nullptr;
+  // The divisor must be a sqrt call whose only user is this fdiv.
+  auto *II = dyn_cast<IntrinsicInst>(I.getOperand(1));
+  if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
+      !II->hasAllowReassoc() || !II->hasAllowReciprocal())
+    return nullptr;
+  // Match the inner fdiv before reading its fast-math flags: the flag
+  // accessors assert when the instruction is not a floating-point operation.
+  Value *Y, *Z;
+  auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
+  if (!DivOp || !match(DivOp, m_FDiv(m_Value(Y), m_Value(Z))) ||
+      !DivOp->hasOneUse() || !DivOp->hasAllowReassoc() ||
+      !DivOp->hasAllowReciprocal())
+    return nullptr;
+  Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+  Value *NewSqrt = Builder.CreateIntrinsic(II->getIntrinsicID(),
+                                           II->getType(), {SwapDiv}, II);
+  return BinaryOperator::CreateFMulFMF(I.getOperand(0), NewSqrt, &I);
+}
+
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
Module *M = I.getModule();
@@ -1808,6 +1833,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
return Mul;
+ if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
+ return Mul;
+
// pow(X, Y) / X --> pow(X, Y-1)
if (I.hasAllowReassoc() &&
match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
new file mode 100644
index 00000000000000..3f41b0f24ae040
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -0,0 +1,85 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare double @llvm.sqrt.f64(double)
+
+define double @sqrt_div_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_fast(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]]
+; CHECK-NEXT: ret double [[DIV1]]
+;
+entry:
+ %div = fdiv fast double %y, %z
+ %sqrt = call fast double @llvm.sqrt.f64(double %div)
+ %div1 = fdiv fast double %x, %sqrt
+ ret double %div1
+}
+
+define double @sqrt_div(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[DIV:%.*]] = fdiv double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT: [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT: [[DIV1:%.*]] = fdiv double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: ret double [[DIV1]]
+;
+entry:
+ %div = fdiv double %y, %z
+ %sqrt = call double @llvm.sqrt.f64(double %div)
+ %div1 = fdiv double %x, %sqrt
+ ret double %div1
+}
+
+define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_reassoc_arcp(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
+; CHECK-NEXT: ret double [[DIV1]]
+;
+entry:
+ %div = fdiv reassoc arcp double %y, %z
+ %sqrt = call reassoc arcp double @llvm.sqrt.f64(double %div)
+ %div1 = fdiv reassoc arcp double %x, %sqrt
+ ret double %div1
+}
+
+declare void @use(double)
+define double @sqrt_div_fast_multiple_uses_1(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_fast_multiple_uses_1(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT: call void @use(double [[DIV]])
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT: [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: ret double [[DIV1]]
+;
+entry:
+ %div = fdiv fast double %y, %z
+ call void @use(double %div)
+ %sqrt = call fast double @llvm.sqrt.f64(double %div)
+ %div1 = fdiv fast double %x, %sqrt
+ ret double %div1
+}
+
+define double @sqrt_div_fast_multiple_uses_2(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_fast_multiple_uses_2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT: call void @use(double [[SQRT]])
+; CHECK-NEXT: [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: ret double [[DIV1]]
+;
+entry:
+ %div = fdiv fast double %y, %z
+ %sqrt = call fast double @llvm.sqrt.f64(double %div)
+ call void @use(double %sqrt)
+ %div1 = fdiv fast double %x, %sqrt
+ ret double %div1
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/76737
More information about the llvm-commits
mailing list