[llvm] [InstCombine] Fix constant swap case of fcmp + fadd + sel xfrm (PR #119419)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 10 09:30:00 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Rajat Bajpai (rajatbajpai)

<details>
<summary>Changes</summary>

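The fcmp + fadd + sel => fcmp + sel + fadd xfrm performs an incorrect transformation when the select's true and false operands are swapped, i.e. when the constant is the true value and the fadd is the false value. In that case the new select was still created with its operands in the unswapped order. This change fixes that by swapping the operands of the new select for the swapped pattern.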

---
Full diff: https://github.com/llvm/llvm-project/pull/119419.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+25-18) 
- (modified) llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll (+8-8) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c7a0c35d099cc4..dcdb59f6e95368 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3723,22 +3723,9 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
   if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs())
     return nullptr;
 
-  // select((fcmp Pred, X, 0), (fadd X, C), C)
-  //      => fadd((select (fcmp Pred, X, 0), X, 0), C)
-  //
-  // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
-  Instruction *FAdd;
-  Constant *C;
-  Value *X, *Z;
-  CmpInst::Predicate Pred;
-
-  // Note: OneUse check for `Cmp` is necessary because it makes sure that other
-  // InstCombine folds don't undo this transformation and cause an infinite
-  // loop. Furthermore, it could also increase the operation count.
-  if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
-                          m_OneUse(m_Instruction(FAdd)), m_Constant(C))) ||
-      match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
-                          m_Constant(C), m_OneUse(m_Instruction(FAdd))))) {
+  auto TryFoldIntoAddConstant =
+      [&Builder, &SI](CmpInst::Predicate Pred, Value *X, Value *Z,
+                      Instruction *FAdd, Constant *C, bool Swapped) -> Value * {
     // Only these relational predicates can be transformed into maxnum/minnum
     // intrinsic.
     if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP()))
@@ -3747,7 +3734,8 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
     if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
       return nullptr;
 
-    Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI);
+    Value *NewSelect = Builder.CreateSelect(SI.getCondition(), Swapped ? Z : X,
+                                            Swapped ? X : Z, "", &SI);
     NewSelect->takeName(&SI);
 
     Value *NewFAdd = Builder.CreateFAdd(NewSelect, C);
@@ -3762,7 +3750,26 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
     cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF);
 
     return NewFAdd;
-  }
+  };
+
+  // select((fcmp Pred, X, 0), (fadd X, C), C)
+  //      => fadd((select (fcmp Pred, X, 0), X, 0), C)
+  //
+  // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
+  Instruction *FAdd;
+  Constant *C;
+  Value *X, *Z;
+  CmpInst::Predicate Pred;
+
+  // Note: OneUse check for `Cmp` is necessary because it makes sure that other
+  // InstCombine folds don't undo this transformation and cause an infinite
+  // loop. Furthermore, it could also increase the operation count.
+  if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
+                          m_OneUse(m_Instruction(FAdd)), m_Constant(C))))
+    return TryFoldIntoAddConstant(Pred, X, Z, FAdd, C, false);
+  else if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
+                               m_Constant(C), m_OneUse(m_Instruction(FAdd)))))
+    return TryFoldIntoAddConstant(Pred, X, Z, FAdd, C, true);
 
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll b/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll
index 0d0af91608e7ac..15fad55db8df12 100644
--- a/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll
@@ -19,7 +19,7 @@ define float @test_fcmp_ogt_fadd_select_constant(float %in) {
 define float @test_fcmp_ogt_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_ogt_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00)
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00)
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -87,7 +87,7 @@ define float @test_fcmp_olt_fadd_select_constant(float %in) {
 define float @test_fcmp_olt_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_olt_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00)
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00)
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -155,7 +155,7 @@ define float @test_fcmp_oge_fadd_select_constant(float %in) {
 define float @test_fcmp_oge_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_oge_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00)
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00)
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -223,7 +223,7 @@ define float @test_fcmp_ole_fadd_select_constant(float %in) {
 define float @test_fcmp_ole_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_ole_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00)
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00)
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -293,7 +293,7 @@ define float @test_fcmp_ugt_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_ugt_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
 ; CHECK-NEXT:    [[CMP1_INV:%.*]] = fcmp ole float [[IN]], 0.000000e+00
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]]
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -366,7 +366,7 @@ define float @test_fcmp_uge_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_uge_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
 ; CHECK-NEXT:    [[CMP1_INV:%.*]] = fcmp olt float [[IN]], 0.000000e+00
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]]
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -439,7 +439,7 @@ define float @test_fcmp_ult_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_ult_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
 ; CHECK-NEXT:    [[CMP1_INV:%.*]] = fcmp oge float [[IN]], 0.000000e+00
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]]
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;
@@ -512,7 +512,7 @@ define float @test_fcmp_ule_fadd_select_constant_swapped(float %in) {
 ; CHECK-LABEL: define float @test_fcmp_ule_fadd_select_constant_swapped(
 ; CHECK-SAME: float [[IN:%.*]]) {
 ; CHECK-NEXT:    [[CMP1_INV:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
-; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]]
+; CHECK-NEXT:    [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00
 ; CHECK-NEXT:    [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00
 ; CHECK-NEXT:    ret float [[ADD_NEW]]
 ;

``````````

</details>


https://github.com/llvm/llvm-project/pull/119419


More information about the llvm-commits mailing list