[llvm] [InstCombineCompares] Replace the sqrt in if-condition (PR #91707)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 10 00:40:27 PDT 2024


https://github.com/whokeke created https://github.com/llvm/llvm-project/pull/91707

Add the transform `fcmp sqrt(X), Y => fcmp X, copysign(Y*Y, Y)` to help sink the sqrt, and add an option `replace-sqrt-compare-by-square` (off by default) to control it.

>From 511dfd3de2be1ee0216c11a3658b8153863534bb Mon Sep 17 00:00:00 2001
From: hukeke <hukeke2 at huawei.com>
Date: Fri, 10 May 2024 11:35:03 +0800
Subject: [PATCH] [InstCombineCompares] Replace the sqrt in if-condition

Add the transform fcmp sqrt(X), Y => fcmp X, copysign(Y*Y, Y) to help
sink the sqrt, and add an option `replace-sqrt-compare-by-square`
(off by default) to control it.
---
 .../InstCombine/InstCombineCompares.cpp       | 49 +++++++++++
 .../test/Transforms/InstCombine/fcmp-fsqrt.ll | 87 +++++++++++++++++++
 2 files changed, 136 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/fcmp-fsqrt.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e1a3194a1beb7..ed0a11a56ab2c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -38,6 +38,11 @@ using namespace PatternMatch;
 // How many times is a select replaced by one of its operands?
 STATISTIC(NumSel, "Number of select opts");
 
+// Gates the sqrt-in-fcmp fold in visitFCmpInst; off by default.
+static cl::opt<bool> ReplaceSqrtInIfCondition(
+    "replace-sqrt-compare-by-square", cl::init(false), cl::Hidden,
+    cl::desc("Replace sqrt in a fast fcmp by squaring the other operand, "
+             "e.g. fcmp sqrt(X), Y => fcmp X, copysign(Y*Y, Y)"));
 
 /// Compute Result = In1+In2, returning true if the result overflowed for this
 /// type.
@@ -8068,6 +8073,50 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
         return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I);
   }
 
+  /// Try to sink a sqrt whose result only some paths need, e.g.:
+  /// \code
+  /// tmp = sqrt(input)
+  /// if (tmp > cond)
+  ///   branch_without_tmp
+  /// else
+  ///   branch_with_tmp
+  /// \endcode
+  /// becomes, once a later pass sinks the now compare-free sqrt:
+  /// \code
+  /// if (input > copysign(cond*cond, cond))
+  ///   branch_without_tmp
+  /// else
+  ///   tmp = sqrt(input)
+  ///   branch_with_tmp
+  /// \endcode
+  /// Gated on full fast-math: Y*Y is not bit-exact with the sqrt it replaces.
+  /// NOTE(review): for huge finite Y, Y*Y overflows to inf, and the ninf flag
+  /// copied from I then makes the folded compare poison; consider clearing it.
+  if (I.isFast() && ReplaceSqrtInIfCondition) {
+    // fcmp sqrt(X), sqrt(Y) => fcmp X, Y
+    if (match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))
+      return new FCmpInst(Pred, X, Y, "", &I);
+
+    // fcmp sqrt(X), Y => fcmp X, copysign(Y*Y, Y)
+    // copysign keeps a negative Y below every non-negative X, just as the
+    // original compare has sqrt(X) >= 0 > Y.
+    if (match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) {
+      Value *YY = Builder.CreateFMulFMF(Op1, Op1, &I);
+      Value *SYY =
+          Builder.CreateBinaryIntrinsic(Intrinsic::copysign, YY, Op1, &I);
+      return new FCmpInst(Pred, X, SYY, "", &I);
+    }
+
+    // fcmp X, sqrt(Y) => fcmp copysign(X*X, X), Y
+    if (match(Op1, m_Intrinsic<Intrinsic::sqrt>(m_Value(Y)))) {
+      Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
+      Value *SXX =
+          Builder.CreateBinaryIntrinsic(Intrinsic::copysign, XX, Op0, &I);
+      return new FCmpInst(Pred, SXX, Y, "", &I);
+    }
+  }
+
   // fcmp (fadd X, 0.0), Y --> fcmp X, Y
   if (match(Op0, m_FAdd(m_Value(X), m_AnyZeroFP())))
     return new FCmpInst(Pred, X, Op1, "", &I);
diff --git a/llvm/test/Transforms/InstCombine/fcmp-fsqrt.ll b/llvm/test/Transforms/InstCombine/fcmp-fsqrt.ll
new file mode 100644
index 0000000000000..919e2ecfff962
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fcmp-fsqrt.ll
@@ -0,0 +1,87 @@
+; Hand-written checks for -replace-sqrt-compare-by-square; shared ones apply to both RUNs.
+; RUN: opt < %s -passes=instcombine -S -replace-sqrt-compare-by-square=true | FileCheck %s --check-prefixes=CHECK,OPT-TRUE
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,OPT-FALSE
+
+define i1 @foo1_fast(float %a, float %b) {
+; OPT-TRUE-LABEL: @foo1_fast(
+; OPT-TRUE-NEXT:    [[FCMP:%.*]] = fcmp fast ogt float %a, %b
+; OPT-TRUE-NEXT:    ret i1 [[FCMP]]
+;
+; OPT-FALSE-LABEL: @foo1_fast(
+; OPT-FALSE-NEXT:    %c = call fast float @llvm.sqrt.f32(float %a)
+; OPT-FALSE-NEXT:    %d = call fast float @llvm.sqrt.f32(float %b)
+; OPT-FALSE-NEXT:    [[FCMP:%.*]] = fcmp fast ogt float %c, %d
+; OPT-FALSE-NEXT:    ret i1 [[FCMP]]
+;
+  %c = call fast float @llvm.sqrt.f32(float %a)
+  %d = call fast float @llvm.sqrt.f32(float %b)
+  %e = fcmp fast ogt float %c, %d
+  ret i1 %e
+}
+
+define i1 @foo2_fast(float %a, float %b) {
+; OPT-TRUE-LABEL: @foo2_fast(
+; OPT-TRUE-NEXT:    [[FMUL:%.*]] = fmul fast float %b, %b
+; OPT-TRUE-NEXT:    [[SIGN:%.*]] = call fast float @llvm.copysign.f32(float [[FMUL]], float %b)
+; OPT-TRUE-NEXT:    [[FCMP:%.*]] = fcmp fast ogt float [[SIGN]], %a
+; OPT-TRUE-NEXT:    ret i1 [[FCMP]]
+;
+; OPT-FALSE-LABEL: @foo2_fast(
+; OPT-FALSE-NEXT:    [[SQRT:%.*]] = call fast float @llvm.sqrt.f32(float %a)
+; OPT-FALSE-NEXT:    [[FCMP:%.*]] = fcmp fast olt float [[SQRT]], %b
+; OPT-FALSE-NEXT:    ret i1 [[FCMP]]
+;
+  %c = call fast float @llvm.sqrt.f32(float %a)
+  %d = fcmp fast ogt float %b, %c
+  ret i1 %d
+}
+
+define i1 @foo3_fast(float %a, float %b) {
+; OPT-TRUE-LABEL: @foo3_fast(
+; OPT-TRUE-NEXT:    [[FMUL:%.*]] = fmul fast float %b, %b
+; OPT-TRUE-NEXT:    [[SIGN:%.*]] = call fast float @llvm.copysign.f32(float [[FMUL]], float %b)
+; OPT-TRUE-NEXT:    [[FCMP:%.*]] = fcmp fast olt float [[SIGN]], %a
+; OPT-TRUE-NEXT:    ret i1 [[FCMP]]
+;
+; OPT-FALSE-LABEL: @foo3_fast(
+; OPT-FALSE-NEXT:    [[SQRT:%.*]] = call fast float @llvm.sqrt.f32(float %a)
+; OPT-FALSE-NEXT:    [[FCMP:%.*]] = fcmp fast ogt float [[SQRT]], %b
+; OPT-FALSE-NEXT:    ret i1 [[FCMP]]
+;
+  %c = call fast float @llvm.sqrt.f32(float %a)
+  %d = fcmp fast ogt float %c, %b
+  ret i1 %d
+}
+
+define i1 @foo1_no_fast(float %a, float %b) {
+; CHECK-LABEL: @foo1_no_fast(
+; CHECK:    [[FCMP:%.*]] = fcmp ogt float %c, %d
+; CHECK-NEXT:    ret i1 [[FCMP]]
+;
+  %c = call float @llvm.sqrt.f32(float %a)
+  %d = call float @llvm.sqrt.f32(float %b)
+  %e = fcmp ogt float %c, %d
+  ret i1 %e
+}
+
+define i1 @foo2_no_fast(float %a, float %b) {
+; CHECK-LABEL: @foo2_no_fast(
+; CHECK:    [[FCMP:%.*]] = fcmp olt float %c, %b
+; CHECK-NEXT:    ret i1 [[FCMP]]
+;
+  %c = call float @llvm.sqrt.f32(float %a)
+  %d = fcmp ogt float %b, %c
+  ret i1 %d
+}
+
+define i1 @foo3_no_fast(float %a, float %b) {
+; CHECK-LABEL: @foo3_no_fast(
+; CHECK:    [[FCMP:%.*]] = fcmp ogt float %c, %b
+; CHECK-NEXT:    ret i1 [[FCMP]]
+;
+  %c = call float @llvm.sqrt.f32(float %a)
+  %d = fcmp ogt float %c, %b
+  ret i1 %d
+}
+
+declare float @llvm.sqrt.f32(float)



More information about the llvm-commits mailing list