[llvm] [NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16 ranges (PR #143016)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 5 11:28:17 PDT 2025


https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/143016

None

>From 70fd5a4d20c8742628377cb74e053abccf78bd9f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 3 Jun 2025 20:25:06 +0000
Subject: [PATCH] [NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16
 ranges

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  76 ++++++++-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   |   7 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     |  19 +++
 llvm/test/CodeGen/NVPTX/trunc-sat.ll        | 177 ++++++++++++++++++++
 4 files changed, 278 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/trunc-sat.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d6a134d9abafd..10c56deb4642d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetCallingConv.h"
@@ -74,6 +75,7 @@
 #define DEBUG_TYPE "nvptx-lower"
 
 using namespace llvm;
+using namespace llvm::SDPatternMatch;
 
 static cl::opt<bool> sched4reg(
     "nvptx-sched4reg",
@@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
 
+  setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16,
+                     Legal);
+  setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8,
+                     Custom);
+
   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
   // For others we will expand to a SHL/SRA pair.
   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
@@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN,
+                       ISD::SMAX});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::PseudoUseParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
+    MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8)
+    MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8)
     MAKE_CASE(NVPTXISD::RETURN)
     MAKE_CASE(NVPTXISD::CallSeqBegin)
     MAKE_CASE(NVPTXISD::CallSeqEnd)
@@ -5667,6 +5677,49 @@ static SDValue combineADDRSPACECAST(SDNode *N,
   return SDValue();
 }
 
+static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  // Fold smin(smax(x, Lo), Hi) / smax(smin(x, Hi), Lo) into truncssat + ext.
+  EVT VT = N->getValueType(0);
+  if (!(VT == MVT::i32 || VT == MVT::i64 || VT == MVT::i16))
+    return SDValue();
+
+  SDValue Val;
+  APInt Ceil, Floor;
+  if (!(sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)),
+                           m_ConstInt(Ceil))) ||
+        sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)),
+                           m_ConstInt(Floor)))))
+    return SDValue();
+
+  const unsigned BitWidth = VT.getSizeInBits();
+  SDLoc DL(N);
+  auto MatchTruncSat = [&](MVT DestVT) {
+    const unsigned DestBitWidth = DestVT.getSizeInBits();
+    bool IsSigned;
+    if (Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) &&
+        Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth))
+      IsSigned = true;
+    else if (Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) &&
+             Floor == APInt::getMinValue(BitWidth))
+      IsSigned = false;
+    else
+      return SDValue();
+    // Both opcodes clamp Val as a *signed* value, matching smin/smax.
+    unsigned Opcode = IsSigned ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U;
+    SDValue Trunc = DCI.DAG.getNode(Opcode, DL, DestVT, Val);
+    return DCI.DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT);
+  };
+
+  if (VT != MVT::i16)
+    if (SDValue Res = MatchTruncSat(MVT::i16))
+      return Res;
+  // i8 is illegal: only form i8 nodes before type legalization.
+  if (DCI.isBeforeLegalize())
+    if (SDValue Res = MatchTruncSat(MVT::i8))
+      return Res;
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5685,6 +5738,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case ISD::UREM:
     case ISD::SREM:
       return PerformREMCombine(N, DCI, OptLevel);
+    case ISD::SMIN:
+    case ISD::SMAX:
+      return combineMINMAX(N, DCI);
     case ISD::SETCC:
       return PerformSETCCCombine(N, DCI, STI.getSmVersion());
     case NVPTXISD::StoreRetval:
@@ -6045,6 +6101,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
   Results.push_back(NewValue.getValue(3));
 }
 
+static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG,
+                                SmallVectorImpl<SDValue> &Results) {
+  // Lower an i8 truncssat via the i16-typed NVPTX node, then narrow to i8.
+  const SDLoc DL(N);
+  const bool Signed = N->getOpcode() == ISD::TRUNCATE_SSAT_S;
+  const SDValue Cvt = DAG.getNode(
+      Signed ? NVPTXISD::TRUNCATE_SSAT_S_I8 : NVPTXISD::TRUNCATE_SSAT_U_I8, DL,
+      MVT::i16, N->getOperand(0));
+  // That node already sign/zero extends its i8 result back to i16.
+  const SDValue Ext = DAG.getNode(Signed ? ISD::AssertSext : ISD::AssertZext,
+                                  DL, MVT::i16, Cvt, DAG.getValueType(MVT::i8));
+  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Ext));
+}
+
 void NVPTXTargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
@@ -6062,6 +6132,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   case ISD::CopyFromReg:
     ReplaceCopyFromReg_128(N, DAG, Results);
     return;
+  case ISD::TRUNCATE_SSAT_U:
+  case ISD::TRUNCATE_SSAT_S:
+    replaceTruncateSSat(N, DAG, Results);
+    return;
   }
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 8d71022a1f102..1bd8ffb65c501 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -72,6 +72,13 @@ enum NodeType : unsigned {
   /// converting it to a vector.
   UNPACK_VECTOR,
 
+  /// These nodes are equivalent to the corresponding ISD nodes except that
+  /// they truncate to an i8 output and then sign or zero extend that value back
+  /// to i16. This is a workaround for the fact that NVPTX does not consider
+  /// i8 to be a legal type. TODO: consider making i8 legal and removing these.
+  TRUNCATE_SSAT_U_I8,
+  TRUNCATE_SSAT_S_I8,
+
   FCOPYSIGN,
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b646d39194c7e..e26a05a112e09 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2649,6 +2649,25 @@ def : Pat<(i1  (trunc i32:$a)), (SETP_b32ri (ANDb32ri $a, 1), 0, CmpNE)>;
 // truncate i16
 def : Pat<(i1 (trunc i16:$a)), (SETP_b16ri (ANDb16ri $a, 1), 0, CmpNE)>;
 
+// truncate ssat: the source operand is always treated as a signed value.
+def SDTTruncSatI8Op : SDTypeProfile<1, 1, [SDTCisInt<1>, SDTCisVT<0, i16>]>;
+def truncssat_s_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_S_I8", SDTTruncSatI8Op>;
+def truncssat_u_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_U_I8", SDTTruncSatI8Op>;
+
+def : Pat<(i16 (truncssat_s i32:$a)), (CVT_s16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_s i64:$a)), (CVT_s16_s64 $a, CvtSAT)>;
+
+def : Pat<(i16 (truncssat_u i32:$a)), (CVT_u16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_u i64:$a)), (CVT_u16_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_s_i8 i16:$a), (CVT_s8_s16 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i32:$a), (CVT_s8_s32 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i64:$a), (CVT_s8_s64 $a, CvtSAT)>;
+// Signed source: a negative input must clamp to 0, not to 255.
+def : Pat<(truncssat_u_i8 i16:$a), (CVT_u8_s16 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i32:$a), (CVT_u8_s32 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i64:$a), (CVT_u8_s64 $a, CvtSAT)>;
+
 // sext_inreg
 def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
 def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>;
diff --git a/llvm/test/CodeGen/NVPTX/trunc-sat.ll b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
new file mode 100644
index 0000000000000..e77e7691c9cb7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+
+define i64 @trunc_ssat_i64_u16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0];
+; CHECK-NEXT:    cvt.sat.u16.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 65535)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0];
+; CHECK-NEXT:    cvt.sat.u16.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 65535)
+  ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_s16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0];
+; CHECK-NEXT:    cvt.sat.s16.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -32768)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 32767)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0];
+; CHECK-NEXT:    cvt.sat.s16.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -32768)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 32767)
+  ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_u8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0];
+; CHECK-NEXT:    cvt.sat.u8.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0];
+; CHECK-NEXT:    cvt.sat.u8.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)
+  ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_u8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_u8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0];
+; CHECK-NEXT:    cvt.sat.u8.s16 %rs2, %rs1;
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %v1 = call i16 @llvm.smax.i16(i16 %a, i16 0)
+  %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255)
+  ret i16 %v2
+}
+
+define i64 @trunc_ssat_i64_s8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0];
+; CHECK-NEXT:    cvt.sat.s8.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -128)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 127)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0];
+; CHECK-NEXT:    cvt.sat.s8.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -128)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 127)
+  ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_s8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_s8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0];
+; CHECK-NEXT:    cvt.sat.s8.s16 %rs2, %rs1;
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %v1 = call i16 @llvm.smax.i16(i16 %a, i16 -128)
+  %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 127)
+  ret i16 %v2
+}
+



More information about the llvm-commits mailing list