[llvm] [NVPTX] Fix lowering of i1 SETCC (PR #115035)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 7 11:37:45 PST 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/115035

>From 9535ee8c27ef687828fbfa530faf3c13bef88ef1 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 7 Nov 2024 19:37:27 +0000
Subject: [PATCH] [NVPTX] Fix lowering of i1 SETCC - address comments

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |   5 +-
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |   4 +-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  47 +++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |   2 +
 llvm/test/CodeGen/NVPTX/i1-icmp.ll            | 193 ++++++++++++++++++
 5 files changed, 247 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/i1-icmp.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 42232bd195a651..46c0b4de59d39a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18738,8 +18738,9 @@ SDValue DAGCombiner::rebuildSetCC(SDValue N) {
       if (LegalTypes)
         SetCCVT = getSetCCResultType(SetCCVT);
       // Replace the uses of XOR with SETCC
-      return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
-                          Equal ? ISD::SETEQ : ISD::SETNE);
+      const ISD::CondCode CC = Equal ? ISD::SETEQ : ISD::SETNE;
+      if (!LegalOperations || TLI.isCondCodeLegal(CC, Op0.getSimpleValueType()))
+        return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1, CC);
     }
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f21233abfa4f5d..e165b40195e1cf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -4946,7 +4946,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         APInt C = C1 - 1;
         ISD::CondCode NewCC = (Cond == ISD::SETGE) ? ISD::SETGT : ISD::SETUGT;
         if ((DCI.isBeforeLegalizeOps() ||
-             isCondCodeLegal(NewCC, VT.getSimpleVT())) &&
+             isCondCodeLegal(NewCC, OpVT.getSimpleVT())) &&
             (!N1C->isOpaque() || (C.getBitWidth() <= 64 &&
                                   isLegalICmpImmediate(C.getSExtValue())))) {
           return DAG.getSetCC(dl, VT, N0,
@@ -4966,7 +4966,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         APInt C = C1 + 1;
         ISD::CondCode NewCC = (Cond == ISD::SETLE) ? ISD::SETLT : ISD::SETULT;
         if ((DCI.isBeforeLegalizeOps() ||
-             isCondCodeLegal(NewCC, VT.getSimpleVT())) &&
+             isCondCodeLegal(NewCC, OpVT.getSimpleVT())) &&
             (!N1C->isOpaque() || (C.getBitWidth() <= 64 &&
                                   isLegalICmpImmediate(C.getSExtValue())))) {
           return DAG.getSetCC(dl, VT, N0,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3bf0ecfe2cc92..7b71273c5aeee1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -668,6 +668,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setTruncStoreAction(VT, MVT::i1, Expand);
   }
 
+  setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
+                     ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
+                     ISD::SETGE, ISD::SETLE},
+                    MVT::i1, Custom);
+
   // expand extload of vector of integers.
   setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
                    MVT::v2i8, Expand);
@@ -2666,6 +2671,46 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
   }
 }
 
+// Lowers SETCC on i1 operands to bitwise logic; other types yield SDValue().
+SDValue NVPTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
+  SDValue L = Op->getOperand(0);
+  SDValue R = Op->getOperand(1);
+
+  if (L.getValueType() != MVT::i1)
+    return SDValue(); // non-i1 operands are not handled here
+
+  SDLoc DL(Op);
+  SDValue Ret;
+  switch (cast<CondCodeSDNode>(Op->getOperand(2))->get()) {
+  default:
+    llvm_unreachable("Unknown integer setcc!");
+  case ISD::SETEQ: // X == Y   -->  ~(X^Y)
+    Ret = DAG.getNOT(DL, DAG.getNode(ISD::XOR, DL, MVT::i1, L, R), MVT::i1);
+    break;
+  case ISD::SETNE: // X != Y   -->  (X^Y)
+    Ret = DAG.getNode(ISD::XOR, DL, MVT::i1, L, R);
+    break;
+  case ISD::SETGT:  // X >s Y   -->  X == 0 & Y == 1  -->  ~X & Y
+  case ISD::SETULT: // X <u Y   -->  X == 0 & Y == 1  -->  ~X & Y
+    Ret = DAG.getNode(ISD::AND, DL, MVT::i1, R, DAG.getNOT(DL, L, MVT::i1));
+    break;
+  case ISD::SETLT:  // X <s Y   -->  X == 1 & Y == 0  -->  ~Y & X
+  case ISD::SETUGT: // X >u Y   -->  X == 1 & Y == 0  -->  ~Y & X
+    Ret = DAG.getNode(ISD::AND, DL, MVT::i1, L, DAG.getNOT(DL, R, MVT::i1));
+    break;
+  case ISD::SETULE: // X <=u Y  -->  X == 0 | Y == 1  -->  ~X | Y
+  case ISD::SETGE:  // X >=s Y  -->  X == 0 | Y == 1  -->  ~X | Y
+    Ret = DAG.getNode(ISD::OR, DL, MVT::i1, R, DAG.getNOT(DL, L, MVT::i1));
+    break;
+  case ISD::SETUGE: // X >=u Y  -->  X == 1 | Y == 0  -->  ~Y | X
+  case ISD::SETLE:  // X <=s Y  -->  X == 1 | Y == 0  -->  ~Y | X
+    Ret = DAG.getNode(ISD::OR, DL, MVT::i1, L, DAG.getNOT(DL, R, MVT::i1));
+    break;
+  }
+
+  return DAG.getZExtOrTrunc(Ret, DL, Op.getValueType());
+}
+
 /// If the types match, convert the generic copysign to the NVPTXISD version,
 /// otherwise bail ensuring that mismatched cases are properly expaned.
 SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
@@ -2919,6 +2964,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
+  case ISD::SETCC:
+    return LowerSETCC(Op, DAG);
   case ISD::SHL_PARTS:
     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRA_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index c8b589ae39413e..b1bb9090464ac4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -628,6 +628,8 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerSETCC(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/i1-icmp.ll b/llvm/test/CodeGen/NVPTX/i1-icmp.ll
new file mode 100644
index 00000000000000..db9ae6541b87ae
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i1-icmp.ll
@@ -0,0 +1,193 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
+
+target triple = "nvptx-nvidia-cuda"
+
+define i32 @icmp_i1_eq(i32 %a, i32 %b) {
+; CHECK-LABEL: icmp_i1_eq(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<4>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [icmp_i1_eq_param_0];
+; CHECK-NEXT:    setp.gt.s32 %p1, %r1, 1;
+; CHECK-NEXT:    ld.param.u32 %r2, [icmp_i1_eq_param_1];
+; CHECK-NEXT:    setp.gt.s32 %p2, %r2, 1;
+; CHECK-NEXT:    xor.pred %p3, %p1, %p2;
+; CHECK-NEXT:    @%p3 bra $L__BB0_2;
+; CHECK-NEXT:  // %bb.1: // %bb1
+; CHECK-NEXT:    mov.b32 %r4, 1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+; CHECK-NEXT:  $L__BB0_2: // %bb2
+; CHECK-NEXT:    mov.b32 %r3, 127;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %p1 = icmp sgt i32 %a, 1
+  %p2 = icmp sgt i32 %b, 1
+  %c = icmp eq i1 %p1, %p2
+  br i1 %c, label %bb1, label %bb2
+bb1:
+  ret i32 1
+bb2:
+  ret i32 127
+}
+
+define i32 @icmp_i1_ne(i32 %a, i32 %b) {
+; CHECK-LABEL: icmp_i1_ne(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<5>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [icmp_i1_ne_param_0];
+; CHECK-NEXT:    setp.gt.s32 %p1, %r1, 1;
+; CHECK-NEXT:    ld.param.u32 %r2, [icmp_i1_ne_param_1];
+; CHECK-NEXT:    setp.gt.s32 %p2, %r2, 1;
+; CHECK-NEXT:    xor.pred %p3, %p1, %p2;
+; CHECK-NEXT:    not.pred %p4, %p3;
+; CHECK-NEXT:    @%p4 bra $L__BB1_2;
+; CHECK-NEXT:  // %bb.1: // %bb1
+; CHECK-NEXT:    mov.b32 %r4, 1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+; CHECK-NEXT:  $L__BB1_2: // %bb2
+; CHECK-NEXT:    mov.b32 %r3, 127;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %p1 = icmp sgt i32 %a, 1
+  %p2 = icmp sgt i32 %b, 1
+  %c = icmp ne i1 %p1, %p2
+  br i1 %c, label %bb1, label %bb2
+bb1:
+  ret i32 1
+bb2:
+  ret i32 127
+}
+
+define i32 @icmp_i1_sgt(i32 %a, i32 %b) {
+; CHECK-LABEL: icmp_i1_sgt(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<4>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [icmp_i1_sgt_param_0];
+; CHECK-NEXT:    setp.gt.s32 %p1, %r1, 1;
+; CHECK-NEXT:    ld.param.u32 %r2, [icmp_i1_sgt_param_1];
+; CHECK-NEXT:    setp.lt.s32 %p2, %r2, 2;
+; CHECK-NEXT:    or.pred %p3, %p1, %p2;
+; CHECK-NEXT:    @%p3 bra $L__BB2_2;
+; CHECK-NEXT:  // %bb.1: // %bb1
+; CHECK-NEXT:    mov.b32 %r4, 1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+; CHECK-NEXT:  $L__BB2_2: // %bb2
+; CHECK-NEXT:    mov.b32 %r3, 127;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %p1 = icmp sgt i32 %a, 1
+  %p2 = icmp sgt i32 %b, 1
+  %c = icmp sgt i1 %p1, %p2
+  br i1 %c, label %bb1, label %bb2
+bb1:
+  ret i32 1
+bb2:
+  ret i32 127
+}
+
+define i32 @icmp_i1_slt(i32 %a, i32 %b) {
+; CHECK-LABEL: icmp_i1_slt(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<4>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [icmp_i1_slt_param_0];
+; CHECK-NEXT:    setp.lt.s32 %p1, %r1, 2;
+; CHECK-NEXT:    ld.param.u32 %r2, [icmp_i1_slt_param_1];
+; CHECK-NEXT:    setp.gt.s32 %p2, %r2, 1;
+; CHECK-NEXT:    or.pred %p3, %p2, %p1;
+; CHECK-NEXT:    @%p3 bra $L__BB3_2;
+; CHECK-NEXT:  // %bb.1: // %bb1
+; CHECK-NEXT:    mov.b32 %r4, 1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+; CHECK-NEXT:  $L__BB3_2: // %bb2
+; CHECK-NEXT:    mov.b32 %r3, 127;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %p1 = icmp sgt i32 %a, 1
+  %p2 = icmp sgt i32 %b, 1
+  %c = icmp slt i1 %p1, %p2
+  br i1 %c, label %bb1, label %bb2
+bb1:
+  ret i32 1
+bb2:
+  ret i32 127
+}
+
+define i32 @icmp_i1_sge(i32 %a, i32 %b) {
+; CHECK-LABEL: icmp_i1_sge(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<4>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [icmp_i1_sge_param_0];
+; CHECK-NEXT:    setp.gt.s32 %p1, %r1, 1;
+; CHECK-NEXT:    ld.param.u32 %r2, [icmp_i1_sge_param_1];
+; CHECK-NEXT:    setp.lt.s32 %p2, %r2, 2;
+; CHECK-NEXT:    and.pred %p3, %p1, %p2;
+; CHECK-NEXT:    @%p3 bra $L__BB4_2;
+; CHECK-NEXT:  // %bb.1: // %bb1
+; CHECK-NEXT:    mov.b32 %r4, 1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+; CHECK-NEXT:  $L__BB4_2: // %bb2
+; CHECK-NEXT:    mov.b32 %r3, 127;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %p1 = icmp sgt i32 %a, 1
+  %p2 = icmp sgt i32 %b, 1
+  %c = icmp sge i1 %p1, %p2
+  br i1 %c, label %bb1, label %bb2
+bb1:
+  ret i32 1
+bb2:
+  ret i32 127
+}
+
+define i32 @icmp_i1_sle(i32 %a, i32 %b) {
+; CHECK-LABEL: icmp_i1_sle(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<4>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [icmp_i1_sle_param_0];
+; CHECK-NEXT:    setp.lt.s32 %p1, %r1, 2;
+; CHECK-NEXT:    ld.param.u32 %r2, [icmp_i1_sle_param_1];
+; CHECK-NEXT:    setp.gt.s32 %p2, %r2, 1;
+; CHECK-NEXT:    and.pred %p3, %p2, %p1;
+; CHECK-NEXT:    @%p3 bra $L__BB5_2;
+; CHECK-NEXT:  // %bb.1: // %bb1
+; CHECK-NEXT:    mov.b32 %r4, 1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+; CHECK-NEXT:  $L__BB5_2: // %bb2
+; CHECK-NEXT:    mov.b32 %r3, 127;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %p1 = icmp sgt i32 %a, 1
+  %p2 = icmp sgt i32 %b, 1
+  %c = icmp sle i1 %p1, %p2
+  br i1 %c, label %bb1, label %bb2
+bb1:
+  ret i32 1
+bb2:
+  ret i32 127
+}
+



More information about the llvm-commits mailing list