[llvm] [RISCV] Combine `(setcc (riscv_selectcc A, B, ...), Y)` to just `(setcc A, B)` when possible (PR #90538)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 30 10:59:26 PDT 2024


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/90538

>From d61f5c62fc869938587bd1151ff1669783c1c17c Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 29 Apr 2024 17:52:35 -0700
Subject: [PATCH 1/4] [RISCV] Combine `(setcc (riscv_selectcc A, B, ...), Y)`
 to just `(setcc A, B)` when possible

Given `(seteq (riscv_selectcc LHS, RHS, CC, X, Y), X)`, we can turn it
into `(setCC LHS, RHS)`.

I think we can generalize this to ISD::SELECT_CC as well.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 73 ++++++++++++++++++---
 llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll | 10 +--
 2 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 454b486b797b1b..43ecf3e3a7fb3a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
@@ -13678,9 +13679,69 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
+  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
   EVT VT = N->getValueType(0);
   EVT OpVT = N0.getValueType();
+  SDLoc DL(N);
+
+  // Both rules are looking for an equality compare.
+  if (!isIntEqualitySetCC(Cond))
+    return SDValue();
+
+  // Rule 1
+  using namespace SDPatternMatch;
+  auto getSelectCCPattern = [](SDValue Candidate, bool Inverse,
+                               SDValue &Select) -> auto {
+    if (Inverse)
+      return m_AllOf(
+          m_OneUse(m_Node(RISCVISD::SELECT_CC, m_Value(), m_Value(), m_Value(),
+                          /*TrueVal=*/m_Value(),
+                          /*FalseVal=*/m_Specific(Candidate))),
+          m_Value(Select));
+    else
+      return m_AllOf(
+          m_OneUse(m_Node(RISCVISD::SELECT_CC, m_Value(), m_Value(), m_Value(),
+                          /*TrueVal=*/m_Specific(Candidate),
+                          /*FalseVal=*/m_Value())),
+          m_Value(Select));
+  };
+
+  auto buildSetCC = [&](SDValue Select, bool Inverse) -> SDValue {
+    ISD::CondCode NewCC = cast<CondCodeSDNode>(Select->getOperand(2))->get();
+    if (Inverse)
+      NewCC = ISD::getSetCCInverse(NewCC, OpVT);
+    return DAG.getNode(
+        ISD::SETCC, DL, VT,
+        {Select->getOperand(0), Select->getOperand(1), DAG.getCondCode(NewCC)},
+        N->getFlags());
+  };
 
+  SDValue SelectVal;
+  if (sd_match(N0, getSelectCCPattern(N1, false, SelectVal)) ||
+      sd_match(N1, getSelectCCPattern(N0, false, SelectVal))) {
+    if (Cond == ISD::SETEQ) {
+      // (seteq (SELECT_CC LHS, RHS, CC, N1, X), N1) => (setCC LHS, RHS)
+      // (seteq N0, (SELECT_CC LHS, RHS, CC, N0, X)) => (setCC LHS, RHS)
+      return buildSetCC(SelectVal, false);
+    } else {
+      // (setne (SELECT_CC LHS, RHS, CC, N1, X), N1) => (setInvCC LHS, RHS)
+      // (setne N0, (SELECT_CC LHS, RHS, CC, N0, X)) => (setInvCC LHS, RHS)
+      return buildSetCC(SelectVal, true);
+    }
+  } else if (sd_match(N0, getSelectCCPattern(N1, true, SelectVal)) ||
+             sd_match(N1, getSelectCCPattern(N0, true, SelectVal))) {
+    if (Cond == ISD::SETEQ) {
+      // (seteq (SELECT_CC LHS, RHS, CC, X, N1), N1) => (setInvCC LHS, RHS)
+      // (seteq N0, (SELECT_CC LHS, RHS, CC, X, N0)) => (setInvCC LHS, RHS)
+      return buildSetCC(SelectVal, true);
+    } else {
+      // (setne (SELECT_CC LHS, RHS, CC, X, N1), N1) => (setCC LHS, RHS)
+      // (setne N0, (SELECT_CC LHS, RHS, CC, X, N0)) => (setCC LHS, RHS)
+      return buildSetCC(SelectVal, false);
+    }
+  }
+
+  // Rule 2
   if (OpVT != MVT::i64 || !Subtarget.is64Bit())
     return SDValue();
 
@@ -13695,11 +13756,6 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
       N0.getConstantOperandVal(1) != UINT64_C(0xffffffff))
     return SDValue();
 
-  // Looking for an equality compare.
-  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
-  if (!isIntEqualitySetCC(Cond))
-    return SDValue();
-
   // Don't do this if the sign bit is provably zero, it will be turned back into
   // an AND.
   APInt SignMask = APInt::getOneBitSet(64, 31);
@@ -13708,16 +13764,15 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
 
   const APInt &C1 = N1C->getAPIntValue();
 
-  SDLoc dl(N);
   // If the constant is larger than 2^32 - 1 it is impossible for both sides
   // to be equal.
   if (C1.getActiveBits() > 32)
-    return DAG.getBoolConstant(Cond == ISD::SETNE, dl, VT, OpVT);
+    return DAG.getBoolConstant(Cond == ISD::SETNE, DL, VT, OpVT);
 
   SDValue SExtOp = DAG.getNode(ISD::SIGN_EXTEND_INREG, N, OpVT,
                                N0.getOperand(0), DAG.getValueType(MVT::i32));
-  return DAG.getSetCC(dl, VT, SExtOp, DAG.getConstant(C1.trunc(32).sext(64),
-                                                      dl, OpVT), Cond);
+  return DAG.getSetCC(DL, VT, SExtOp,
+                      DAG.getConstant(C1.trunc(32).sext(64), DL, OpVT), Cond);
 }
 
 static SDValue
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll b/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll
index 8b368bfaab08ee..c3c600dde943e6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll
@@ -155,14 +155,8 @@ define i1 @nxv2i32_cmp_evl(<vscale x 2 x i32> %src, <vscale x 2 x i1> %m, i32 %e
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
 ; RV32-NEXT:    vmsne.vi v8, v8, 0, v0.t
-; RV32-NEXT:    vfirst.m a2, v8, v0.t
-; RV32-NEXT:    mv a1, a0
-; RV32-NEXT:    bltz a2, .LBB6_2
-; RV32-NEXT:  # %bb.1:
-; RV32-NEXT:    mv a1, a2
-; RV32-NEXT:  .LBB6_2:
-; RV32-NEXT:    xor a0, a1, a0
-; RV32-NEXT:    seqz a0, a0
+; RV32-NEXT:    vfirst.m a0, v8, v0.t
+; RV32-NEXT:    slti a0, a0, 0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: nxv2i32_cmp_evl:

>From 29b7faf7f0dcb37899af288fcf2d5ed227c8510c Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Tue, 30 Apr 2024 09:37:04 -0700
Subject: [PATCH 2/4] Address review comments

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 44 +++++++++++----------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 43ecf3e3a7fb3a..b86847435c1a6e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13690,35 +13690,39 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
 
   // Rule 1
   using namespace SDPatternMatch;
-  auto getSelectCCPattern = [](SDValue Candidate, bool Inverse,
-                               SDValue &Select) -> auto {
+  auto matchSelectCC = [](SDValue Op, SDValue Candidate, bool Inverse,
+                          SDValue &Select) -> bool {
+    SDValue NegCandidate;
     if (Inverse)
-      return m_AllOf(
-          m_OneUse(m_Node(RISCVISD::SELECT_CC, m_Value(), m_Value(), m_Value(),
-                          /*TrueVal=*/m_Value(),
-                          /*FalseVal=*/m_Specific(Candidate))),
-          m_Value(Select));
+      return sd_match(
+                 Op,
+                 m_AllOf(m_OneUse(m_Node(RISCVISD::SELECT_CC, m_Value(),
+                                         m_Value(), m_Value(),
+                                         /*TrueVal=*/m_Value(NegCandidate),
+                                         /*FalseVal=*/m_Specific(Candidate))),
+                         m_Value(Select))) &&
+             NegCandidate != Candidate;
     else
-      return m_AllOf(
-          m_OneUse(m_Node(RISCVISD::SELECT_CC, m_Value(), m_Value(), m_Value(),
-                          /*TrueVal=*/m_Specific(Candidate),
-                          /*FalseVal=*/m_Value())),
-          m_Value(Select));
+      return sd_match(
+                 Op,
+                 m_AllOf(m_OneUse(m_Node(RISCVISD::SELECT_CC, m_Value(),
+                                         m_Value(), m_Value(),
+                                         /*TrueVal=*/m_Specific(Candidate),
+                                         /*FalseVal=*/m_Value(NegCandidate))),
+                         m_Value(Select))) &&
+             NegCandidate != Candidate;
   };
 
   auto buildSetCC = [&](SDValue Select, bool Inverse) -> SDValue {
     ISD::CondCode NewCC = cast<CondCodeSDNode>(Select->getOperand(2))->get();
     if (Inverse)
       NewCC = ISD::getSetCCInverse(NewCC, OpVT);
-    return DAG.getNode(
-        ISD::SETCC, DL, VT,
-        {Select->getOperand(0), Select->getOperand(1), DAG.getCondCode(NewCC)},
-        N->getFlags());
+    return DAG.getSetCC(DL, VT, Select->getOperand(0), Select->getOperand(1), NewCC);
   };
 
   SDValue SelectVal;
-  if (sd_match(N0, getSelectCCPattern(N1, false, SelectVal)) ||
-      sd_match(N1, getSelectCCPattern(N0, false, SelectVal))) {
+  if (matchSelectCC(N0, N1, false, SelectVal) ||
+      matchSelectCC(N1, N0, false, SelectVal)) {
     if (Cond == ISD::SETEQ) {
       // (seteq (SELECT_CC LHS, RHS, CC, N1, X), N1) => (setCC LHS, RHS)
       // (seteq N0, (SELECT_CC LHS, RHS, CC, N0, X)) => (setCC LHS, RHS)
@@ -13728,8 +13732,8 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
       // (setne N0, (SELECT_CC LHS, RHS, CC, N0, X)) => (setInvCC LHS, RHS)
       return buildSetCC(SelectVal, true);
     }
-  } else if (sd_match(N0, getSelectCCPattern(N1, true, SelectVal)) ||
-             sd_match(N1, getSelectCCPattern(N0, true, SelectVal))) {
+  } else if (matchSelectCC(N0, N1, true, SelectVal) ||
+             matchSelectCC(N1, N0, true, SelectVal)) {
     if (Cond == ISD::SETEQ) {
       // (seteq (SELECT_CC LHS, RHS, CC, X, N1), N1) => (setInvCC LHS, RHS)
       // (seteq N0, (SELECT_CC LHS, RHS, CC, X, N0)) => (setInvCC LHS, RHS)

>From 94165912491c1adeee410c0a03b6cdb70cd43122 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Tue, 30 Apr 2024 10:53:39 -0700
Subject: [PATCH 3/4] Add documentation comments and more tests

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 22 ++++++----
 llvm/test/CodeGen/RISCV/setcc-optimize.ll   | 47 +++++++++++++++++++++
 2 files changed, 61 insertions(+), 8 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/setcc-optimize.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b86847435c1a6e..a5c006c1670cd5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13671,6 +13671,12 @@ static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &D
   return true;
 }
 
+// We're performing 2 combinations here:
+// === Rule 1 ===
+// Given (seteq (riscv_selectcc LHS, RHS, CC, X, Y), X), we can replace it with
+// (setCC LHS, RHS). Similar replacements are done for `setne` too.
+//
+// === Rule 2 ===
 // Replace (seteq (i64 (and X, 0xffffffff)), C1) with
 // (seteq (i64 (sext_inreg (X, i32)), C1')) where C1' is C1 sign extended from
 // bit 31. Same for setne. C1' may be cheaper to materialize and the sext_inreg
@@ -13721,27 +13727,27 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
   };
 
   SDValue SelectVal;
-  if (matchSelectCC(N0, N1, false, SelectVal) ||
-      matchSelectCC(N1, N0, false, SelectVal)) {
+  if (matchSelectCC(N0, N1, /*Inverse=*/false, SelectVal) ||
+      matchSelectCC(N1, N0, /*Inverse=*/false, SelectVal)) {
     if (Cond == ISD::SETEQ) {
       // (seteq (SELECT_CC LHS, RHS, CC, N1, X), N1) => (setCC LHS, RHS)
       // (seteq N0, (SELECT_CC LHS, RHS, CC, N0, X)) => (setCC LHS, RHS)
-      return buildSetCC(SelectVal, false);
+      return buildSetCC(SelectVal, /*Inverse=*/false);
     } else {
       // (setne (SELECT_CC LHS, RHS, CC, N1, X), N1) => (setInvCC LHS, RHS)
       // (setne N0, (SELECT_CC LHS, RHS, CC, N0, X)) => (setInvCC LHS, RHS)
-      return buildSetCC(SelectVal, true);
+      return buildSetCC(SelectVal, /*Inverse=*/true);
     }
-  } else if (matchSelectCC(N0, N1, true, SelectVal) ||
-             matchSelectCC(N1, N0, true, SelectVal)) {
+  } else if (matchSelectCC(N0, N1, /*Inverse=*/true, SelectVal) ||
+             matchSelectCC(N1, N0, /*Inverse=*/true, SelectVal)) {
     if (Cond == ISD::SETEQ) {
       // (seteq (SELECT_CC LHS, RHS, CC, X, N1), N1) => (setInvCC LHS, RHS)
       // (seteq N0, (SELECT_CC LHS, RHS, CC, X, N0)) => (setInvCC LHS, RHS)
-      return buildSetCC(SelectVal, true);
+      return buildSetCC(SelectVal, /*Inverse=*/true);
     } else {
       // (setne (SELECT_CC LHS, RHS, CC, X, N1), N1) => (setCC LHS, RHS)
       // (setne N0, (SELECT_CC LHS, RHS, CC, X, N0)) => (setCC LHS, RHS)
-      return buildSetCC(SelectVal, false);
+      return buildSetCC(SelectVal, /*Inverse=*/false);
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/setcc-optimize.ll b/llvm/test/CodeGen/RISCV/setcc-optimize.ll
new file mode 100644
index 00000000000000..07060ea4320d1f
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/setcc-optimize.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 < %s | FileCheck %s
+
+define i1 @eq(i32 %a, i32 %b, i32 %c, i32 %d) {
+; CHECK-LABEL: eq:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    xori a0, a0, 1
+; CHECK-NEXT:    ret
+  %p = icmp uge i32 %a, %b
+  %s = select i1 %p, i32 %c, i32 %d
+  %r = icmp eq i32 %s, %c
+  ret i1 %r
+}
+
+define i1 @eq_inv(i32 %a, i32 %b, i32 %c, i32 %d) {
+; CHECK-LABEL: eq_inv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %p = icmp uge i32 %a, %b
+  %s = select i1 %p, i32 %d, i32 %c
+  %r = icmp eq i32 %s, %c
+  ret i1 %r
+}
+
+define i1 @ne(i32 %a, i32 %b, i32 %c, i32 %d) {
+; CHECK-LABEL: ne:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %p = icmp uge i32 %a, %b
+  %s = select i1 %p, i32 %c, i32 %d
+  %r = icmp ne i32 %s, %c
+  ret i1 %r
+}
+
+define i1 @ne_inv(i32 %a, i32 %b, i32 %c, i32 %d) {
+; CHECK-LABEL: ne_inv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %p = icmp ugt i32 %a, %b
+  %s = select i1 %p, i32 %d, i32 %c
+  %r = icmp ne i32 %s, %c
+  ret i1 %r
+}

>From 3fa9c7e17494e420b87ddd2856a223b67f0ed84c Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Tue, 30 Apr 2024 10:58:55 -0700
Subject: [PATCH 4/4] fixup! Add documentation comments and more tests

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a5c006c1670cd5..88c88d6b01e42d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13723,7 +13723,8 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
     ISD::CondCode NewCC = cast<CondCodeSDNode>(Select->getOperand(2))->get();
     if (Inverse)
       NewCC = ISD::getSetCCInverse(NewCC, OpVT);
-    return DAG.getSetCC(DL, VT, Select->getOperand(0), Select->getOperand(1), NewCC);
+    return DAG.getSetCC(DL, VT, Select->getOperand(0), Select->getOperand(1),
+                        NewCC);
   };
 
   SDValue SelectVal;



More information about the llvm-commits mailing list