[llvm] InstCombine: support floating-point equivalences (PR #114975)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 04:27:14 PST 2024


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/114975

Since cd16b07 (IR: introduce CmpInst::isEquivalence), CmpInst has an isEquivalence routine that foldSelectValueEquivalence can use to decide whether a comparison establishes equivalence of its operands. Use it there, which extends the fold from integer comparisons to floating-point comparisons as well.

>From 23394814d500ca03086bb8c57c0b458ac86575ba Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 5 Nov 2024 11:52:41 +0000
Subject: [PATCH] InstCombine: support floating-point equivalences

Since cd16b07 (IR: introduce CmpInst::isEquivalence), CmpInst has an
isEquivalence routine that foldSelectValueEquivalence can use to decide
whether a comparison establishes equivalence of its operands. Use it
there, extending the fold from integer to floating-point comparisons.
---
 .../InstCombine/InstCombineInternal.h         |  2 +-
 .../InstCombine/InstCombineSelect.cpp         | 21 +++---
 .../InstCombine/select-binop-cmp.ll           | 65 +++++++++----------
 .../InstCombine/select-value-equivalence.ll   | 16 ++---
 4 files changed, 51 insertions(+), 53 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index adbd9186c59c5a..9588930d7658c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -750,7 +750,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                             Value *A, Value *B, Instruction &Outer,
                             SelectPatternFlavor SPF2, Value *C);
   Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
-  Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI);
+  Instruction *foldSelectValueEquivalence(SelectInst &SI, CmpInst &CI);
   bool replaceInInstruction(Value *V, Value *Old, Value *New,
                             unsigned Depth = 0);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 999ad1adff20b8..008e9638d27cee 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1327,16 +1327,16 @@ bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
 /// We can't replace %sel with %add unless we strip away the flags.
 /// TODO: Wrapping flags could be preserved in some cases with better analysis.
 Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
-                                                          ICmpInst &Cmp) {
-  if (!Cmp.isEquality())
-    return nullptr;
-
-  // Canonicalize the pattern to ICMP_EQ by swapping the select operands.
+                                                          CmpInst &Cmp) {
+  // Canonicalize the pattern to an equivalence on the predicate by swapping the
+  // select operands.
   Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
   bool Swapped = false;
-  if (Cmp.getPredicate() == ICmpInst::ICMP_NE) {
+  if (Cmp.isEquivalence(/*Invert=*/true)) {
     std::swap(TrueVal, FalseVal);
     Swapped = true;
+  } else if (!Cmp.isEquivalence()) {
+    return nullptr;
   }
 
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
@@ -1347,7 +1347,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // would lead to an infinite replacement cycle.
     // If we will be able to evaluate f(Y) to a constant, we can allow undef,
     // otherwise Y cannot be undef as we might pick different values for undef
-    // in the icmp and in f(Y).
+    // in the cmp and in f(Y).
     if (TrueVal == OldOp)
       return nullptr;
 
@@ -1901,9 +1901,6 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
 /// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
                                                       ICmpInst *ICI) {
-  if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
-    return NewSel;
-
   if (Value *V =
           canonicalizeSPF(*ICI, SI.getTrueValue(), SI.getFalseValue(), *this))
     return replaceInstUsesWith(SI, V);
@@ -3794,6 +3791,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     return Fabs;
 
   // See if we are selecting two values based on a comparison of the two values.
+  if (CmpInst *CI = dyn_cast<CmpInst>(CondVal))
+    if (Instruction *NewSel = foldSelectValueEquivalence(SI, *CI))
+      return NewSel;
+
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
       return Result;
diff --git a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
index cd8c29ba4cd819..e563cafbc7e4f2 100644
--- a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
@@ -706,11 +706,10 @@ define i32 @select_lshr_icmp_const_different_values(i32 %x, i32 %y) {
   ret i32 %C
 }
 
-; Invalid identity constant for FP op
-define float @select_fadd_fcmp_bad(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fadd_fcmp_bad(
+define float @select_fadd_fcmp_equiv(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fadd_fcmp_equiv(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp oeq float [[X:%.*]], -1.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fadd nsz float [[X]], [[Z:%.*]]
+; CHECK-NEXT:    [[B:%.*]] = fadd nsz float [[Z:%.*]], -1.000000e+00
 ; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B]], float [[Y:%.*]]
 ; CHECK-NEXT:    ret float [[C]]
 ;
@@ -720,6 +719,19 @@ define float @select_fadd_fcmp_bad(float %x, float %y, float %z) {
   ret float %C
 }
 
+define float @select_fadd_fcmp_equiv2(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fadd_fcmp_equiv2(
+; CHECK-NEXT:    [[A:%.*]] = fcmp une float [[X:%.*]], -1.000000e+00
+; CHECK-NEXT:    [[B:%.*]] = fadd nsz float [[Z:%.*]], -1.000000e+00
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[Y:%.*]], float [[B]]
+; CHECK-NEXT:    ret float [[C]]
+;
+  %A = fcmp une float %x, -1.0
+  %B = fadd nsz float %x, %z
+  %C = select i1 %A, float %y, float %B
+  ret float %C
+}
+
 ; Invalid comparison type
 define float @select_fadd_fcmp_bad_2(float %x, float %y, float %z) {
 ; CHECK-LABEL: @select_fadd_fcmp_bad_2(
@@ -893,24 +905,10 @@ define float @select_fadd_fcmp_bad_13(float %x, float %y, float %z) {
   ret float %C
 }
 
-; Invalid identity constant for FP op
-define float @select_fadd_fcmp_bad_14(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fadd_fcmp_bad_14(
-; CHECK-NEXT:    [[A:%.*]] = fcmp une float [[X:%.*]], -1.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fadd nsz float [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[Y:%.*]], float [[B]]
-; CHECK-NEXT:    ret float [[C]]
-;
-  %A = fcmp une float %x, -1.0
-  %B = fadd nsz float %x, %z
-  %C = select i1 %A, float %y, float %B
-  ret float %C
-}
-
-define float @select_fmul_fcmp_bad(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fmul_fcmp_bad(
+define float @select_fmul_fcmp_equiv(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fmul_fcmp_equiv(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp oeq float [[X:%.*]], 3.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fmul nsz float [[X]], [[Z:%.*]]
+; CHECK-NEXT:    [[B:%.*]] = fmul nsz float [[Z:%.*]], 3.000000e+00
 ; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B]], float [[Y:%.*]]
 ; CHECK-NEXT:    ret float [[C]]
 ;
@@ -920,11 +918,10 @@ define float @select_fmul_fcmp_bad(float %x, float %y, float %z) {
   ret float %C
 }
 
-define float @select_fmul_fcmp_bad_2(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fmul_fcmp_bad_2(
+define float @select_fmul_fcmp_equiv2(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fmul_fcmp_equiv2(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp oeq float [[X:%.*]], 1.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fmul float [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B]], float [[Y:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B:%.*]], float [[Y:%.*]]
 ; CHECK-NEXT:    ret float [[C]]
 ;
   %A = fcmp oeq float %x, 1.0
@@ -959,10 +956,10 @@ define float @select_fmul_icmp_bad_2(float %x, float %y, float %z, i32 %k) {
   ret float %C
 }
 
-define float @select_fdiv_fcmp_bad(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fdiv_fcmp_bad(
+define float @select_fdiv_fcmp_equiv(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fdiv_fcmp_equiv(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp oeq float [[X:%.*]], 1.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fdiv float [[X]], [[Z:%.*]]
+; CHECK-NEXT:    [[B:%.*]] = fdiv float 1.000000e+00, [[Z:%.*]]
 ; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B]], float [[Y:%.*]]
 ; CHECK-NEXT:    ret float [[C]]
 ;
@@ -972,10 +969,10 @@ define float @select_fdiv_fcmp_bad(float %x, float %y, float %z) {
   ret float %C
 }
 
-define float @select_fdiv_fcmp_bad_2(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fdiv_fcmp_bad_2(
+define float @select_fdiv_fcmp_equiv2(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fdiv_fcmp_equiv2(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp oeq float [[X:%.*]], 3.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fdiv nsz float [[X]], [[Z:%.*]]
+; CHECK-NEXT:    [[B:%.*]] = fdiv nsz float 3.000000e+00, [[Z:%.*]]
 ; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B]], float [[Y:%.*]]
 ; CHECK-NEXT:    ret float [[C]]
 ;
@@ -1001,10 +998,10 @@ define float @select_fsub_fcmp_bad(float %x, float %y, float %z) {
   ret float %C
 }
 
-define float @select_fsub_fcmp_bad_2(float %x, float %y, float %z) {
-; CHECK-LABEL: @select_fsub_fcmp_bad_2(
+define float @select_fsub_fcmp_equiv(float %x, float %y, float %z) {
+; CHECK-LABEL: @select_fsub_fcmp_equiv(
 ; CHECK-NEXT:    [[A:%.*]] = fcmp oeq float [[X:%.*]], 1.000000e+00
-; CHECK-NEXT:    [[B:%.*]] = fsub nsz float [[Z:%.*]], [[X]]
+; CHECK-NEXT:    [[B:%.*]] = fadd nsz float [[Z:%.*]], -1.000000e+00
 ; CHECK-NEXT:    [[C:%.*]] = select i1 [[A]], float [[B]], float [[Y:%.*]]
 ; CHECK-NEXT:    ret float [[C]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/select-value-equivalence.ll b/llvm/test/Transforms/InstCombine/select-value-equivalence.ll
index 28984282511924..79cd0a5e20dfe8 100644
--- a/llvm/test/Transforms/InstCombine/select-value-equivalence.ll
+++ b/llvm/test/Transforms/InstCombine/select-value-equivalence.ll
@@ -90,7 +90,7 @@ define float @select_fcmp_fadd_oeq_not_zero(float %x, float %y) {
 ; CHECK-LABEL: define float @select_fcmp_fadd_oeq_not_zero(
 ; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp oeq float [[Y]], 2.000000e+00
-; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], 2.000000e+00
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[FCMP]], float [[FADD]], float [[X]]
 ; CHECK-NEXT:    ret float [[RETVAL]]
 ;
@@ -104,7 +104,7 @@ define float @select_fcmp_fadd_une_not_zero(float %x, float %y) {
 ; CHECK-LABEL: define float @select_fcmp_fadd_une_not_zero(
 ; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp une float [[Y]], 2.000000e+00
-; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], 2.000000e+00
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[FCMP]], float [[X]], float [[FADD]]
 ; CHECK-NEXT:    ret float [[RETVAL]]
 ;
@@ -118,7 +118,7 @@ define float @select_fcmp_fadd_ueq_nnan_not_zero(float %x, float %y) {
 ; CHECK-LABEL: define float @select_fcmp_fadd_ueq_nnan_not_zero(
 ; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp nnan ueq float [[Y]], 2.000000e+00
-; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], 2.000000e+00
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[FCMP]], float [[FADD]], float [[X]]
 ; CHECK-NEXT:    ret float [[RETVAL]]
 ;
@@ -132,7 +132,7 @@ define float @select_fcmp_fadd_one_nnan_not_zero(float %x, float %y) {
 ; CHECK-LABEL: define float @select_fcmp_fadd_one_nnan_not_zero(
 ; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp nnan one float [[Y]], 2.000000e+00
-; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[X]], 2.000000e+00
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[FCMP]], float [[X]], float [[FADD]]
 ; CHECK-NEXT:    ret float [[RETVAL]]
 ;
@@ -202,7 +202,7 @@ define <2 x float> @select_fcmp_fadd_oeq_not_zero_vec(<2 x float> %x, <2 x float
 ; CHECK-LABEL: define <2 x float> @select_fcmp_fadd_oeq_not_zero_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x float> [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp oeq <2 x float> [[Y]], <float 2.000000e+00, float 2.000000e+00>
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], <float 2.000000e+00, float 2.000000e+00>
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select <2 x i1> [[FCMP]], <2 x float> [[FADD]], <2 x float> [[X]]
 ; CHECK-NEXT:    ret <2 x float> [[RETVAL]]
 ;
@@ -216,7 +216,7 @@ define <2 x float> @select_fcmp_fadd_une_not_zero_vec(<2 x float> %x, <2 x float
 ; CHECK-LABEL: define <2 x float> @select_fcmp_fadd_une_not_zero_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x float> [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp une <2 x float> [[Y]], <float 2.000000e+00, float 2.000000e+00>
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], <float 2.000000e+00, float 2.000000e+00>
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select <2 x i1> [[FCMP]], <2 x float> [[X]], <2 x float> [[FADD]]
 ; CHECK-NEXT:    ret <2 x float> [[RETVAL]]
 ;
@@ -230,7 +230,7 @@ define <2 x float> @select_fcmp_fadd_ueq_nnan_not_zero_vec(<2 x float> %x, <2 x
 ; CHECK-LABEL: define <2 x float> @select_fcmp_fadd_ueq_nnan_not_zero_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x float> [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp nnan ueq <2 x float> [[Y]], <float 2.000000e+00, float 2.000000e+00>
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], <float 2.000000e+00, float 2.000000e+00>
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select <2 x i1> [[FCMP]], <2 x float> [[FADD]], <2 x float> [[X]]
 ; CHECK-NEXT:    ret <2 x float> [[RETVAL]]
 ;
@@ -244,7 +244,7 @@ define <2 x float> @select_fcmp_fadd_one_nnan_not_zero_vec(<2 x float> %x, <2 x
 ; CHECK-LABEL: define <2 x float> @select_fcmp_fadd_one_nnan_not_zero_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x float> [[Y:%.*]]) {
 ; CHECK-NEXT:    [[FCMP:%.*]] = fcmp nnan one <2 x float> [[Y]], <float 2.000000e+00, float 2.000000e+00>
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], [[Y]]
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <2 x float> [[X]], <float 2.000000e+00, float 2.000000e+00>
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select <2 x i1> [[FCMP]], <2 x float> [[X]], <2 x float> [[FADD]]
 ; CHECK-NEXT:    ret <2 x float> [[RETVAL]]
 ;



More information about the llvm-commits mailing list