[llvm] [NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16 ranges (PR #143016)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 5 11:28:17 PDT 2025
https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/143016
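This teaches the NVPTX backend to fold smin/smax clamps to the i8 and i16 value ranges into saturating cvt instructions. A new SMIN/SMAX DAG combine matches smin(smax(x, Floor), Ceil) (in either nesting order) against the signed and unsigned i8/i16 ranges and rewrites the clamp to ISD::TRUNCATE_SSAT_S or ISD::TRUNCATE_SSAT_U. The i16 forms are legal and select directly to cvt.sat; since i8 is not a legal type on NVPTX, the i8 forms are custom-lowered through the new NVPTXISD::TRUNCATE_SSAT_S_I8 / TRUNCATE_SSAT_U_I8 nodes, which produce an i16 result that is then truncated back.

For example, the clamp to [0, 255] in the added trunc_ssat_i32_u8 test,

  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)

now selects to

  cvt.sat.u8.u32 %rs1, %r1;
  cvt.u32.u16    %r2, %rs1;

rather than a separate max/min pair.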
From 70fd5a4d20c8742628377cb74e053abccf78bd9f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 3 Jun 2025 20:25:06 +0000
Subject: [PATCH] [NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16
ranges
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 76 ++++++++-
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 7 +
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 19 +++
llvm/test/CodeGen/NVPTX/trunc-sat.ll | 177 ++++++++++++++++++++
4 files changed, 278 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/NVPTX/trunc-sat.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d6a134d9abafd..10c56deb4642d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -74,6 +75,7 @@
#define DEBUG_TYPE "nvptx-lower"
using namespace llvm;
+using namespace llvm::SDPatternMatch;
static cl::opt<bool> sched4reg(
"nvptx-sched4reg",
@@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::BR_CC, VT, Expand);
}
+ setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16,
+ Legal);
+ setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8,
+ Custom);
+
// Some SIGN_EXTEND_INREG can be done using cvt instruction.
// For others we will expand to a SHL/SRA pair.
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
@@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN,
+ ISD::SMAX});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::PseudoUseParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
+ MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8)
+ MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8)
MAKE_CASE(NVPTXISD::RETURN)
MAKE_CASE(NVPTXISD::CallSeqBegin)
MAKE_CASE(NVPTXISD::CallSeqEnd)
@@ -5667,6 +5677,49 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}
+static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+
+ EVT VT = N->getValueType(0);
+ if (!(VT == MVT::i32 || VT == MVT::i64 || VT == MVT::i16))
+ return SDValue();
+
+ SDValue Val;
+ APInt Ceil, Floor;
+ if (!(sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)),
+ m_ConstInt(Ceil))) ||
+ sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)),
+ m_ConstInt(Floor)))))
+ return SDValue();
+
+ const unsigned BitWidth = VT.getSizeInBits();
+ SDLoc DL(N);
+  auto MatchTruncSat = [&](MVT DestVT) {
+ const unsigned DestBitWidth = DestVT.getSizeInBits();
+ bool IsSigned;
+ if (Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) &&
+ Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth))
+ IsSigned = true;
+ else if (Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) &&
+ Floor == APInt::getMinValue(BitWidth))
+ IsSigned = false;
+ else
+ return SDValue();
+
+ unsigned Opcode = IsSigned ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U;
+ SDValue Trunc = DCI.DAG.getNode(Opcode, DL, DestVT, Val);
+ return DCI.DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT);
+ };
+
+ if (VT != MVT::i16)
+    if (auto Res = MatchTruncSat(MVT::i16))
+ return Res;
+
+  if (auto Res = MatchTruncSat(MVT::i8))
+ return Res;
+
+ return SDValue();
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5685,6 +5738,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::UREM:
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
+ case ISD::SMIN:
+ case ISD::SMAX:
+ return combineMINMAX(N, DCI);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case NVPTXISD::StoreRetval:
@@ -6045,6 +6101,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}
+static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG,
+ SmallVectorImpl<SDValue> &Results) {
+ SDLoc DL(N);
+
+ const bool IsSigned = N->getOpcode() == ISD::TRUNCATE_SSAT_S;
+ const unsigned Opcode =
+ IsSigned ? NVPTXISD::TRUNCATE_SSAT_S_I8 : NVPTXISD::TRUNCATE_SSAT_U_I8;
+ SDValue NewTrunc = DAG.getNode(Opcode, DL, MVT::i16, N->getOperand(0));
+ SDValue Assert = DAG.getNode(IsSigned ? ISD::AssertSext : ISD::AssertZext, DL,
+ MVT::i16, NewTrunc, DAG.getValueType(MVT::i8));
+
+ Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Assert));
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
@@ -6062,6 +6132,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
+ case ISD::TRUNCATE_SSAT_U:
+ case ISD::TRUNCATE_SSAT_S:
+ replaceTruncateSSat(N, DAG, Results);
+ return;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 8d71022a1f102..1bd8ffb65c501 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -72,6 +72,13 @@ enum NodeType : unsigned {
/// converting it to a vector.
UNPACK_VECTOR,
+ /// These nodes are equivalent to the corresponding ISD nodes except that
+ /// they truncate to an i8 output and then sign or zero extend that value back
+ /// to i16. This is a workaround for the fact that NVPTX does not consider
+ /// i8 to be a legal type. TODO: consider making i8 legal and removing these.
+ TRUNCATE_SSAT_U_I8,
+ TRUNCATE_SSAT_S_I8,
+
FCOPYSIGN,
DYNAMIC_STACKALLOC,
STACKRESTORE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b646d39194c7e..e26a05a112e09 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2649,6 +2649,25 @@ def : Pat<(i1 (trunc i32:$a)), (SETP_b32ri (ANDb32ri $a, 1), 0, CmpNE)>;
// truncate i16
def : Pat<(i1 (trunc i16:$a)), (SETP_b16ri (ANDb16ri $a, 1), 0, CmpNE)>;
+// truncate ssat
+def SDTTruncSatI8Op : SDTypeProfile<1, 1, [SDTCisInt<1>, SDTCisVT<0, i16>]>;
+def truncssat_s_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_S_I8", SDTTruncSatI8Op>;
+def truncssat_u_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_U_I8", SDTTruncSatI8Op>;
+
+def : Pat<(i16 (truncssat_s i32:$a)), (CVT_s16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_s i64:$a)), (CVT_s16_s64 $a, CvtSAT)>;
+
+def : Pat<(i16 (truncssat_u i32:$a)), (CVT_u16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_u i64:$a)), (CVT_u16_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_s_i8 i16:$a), (CVT_s8_s16 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i32:$a), (CVT_s8_s32 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i64:$a), (CVT_s8_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_u_i8 i16:$a), (CVT_u8_u16 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i32:$a), (CVT_u8_u32 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i64:$a), (CVT_u8_u64 $a, CvtSAT)>;
+
// sext_inreg
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>;
diff --git a/llvm/test/CodeGen/NVPTX/trunc-sat.ll b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
new file mode 100644
index 0000000000000..e77e7691c9cb7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target triple = "nvptx-unknown-cuda"
+
+
+define i64 @trunc_ssat_i64_u16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0];
+; CHECK-NEXT: cvt.sat.u16.s64 %rs1, %rd1;
+; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 65535)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0];
+; CHECK-NEXT: cvt.sat.u16.s32 %rs1, %r1;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 65535)
+ ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_s16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0];
+; CHECK-NEXT: cvt.sat.s16.s64 %rs1, %rd1;
+; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -32768)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 32767)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0];
+; CHECK-NEXT: cvt.sat.s16.s32 %rs1, %r1;
+; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -32768)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 32767)
+ ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_u8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0];
+; CHECK-NEXT: cvt.sat.u8.u64 %rs1, %rd1;
+; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0];
+; CHECK-NEXT: cvt.sat.u8.u32 %rs1, %r1;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)
+ ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_u8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_u8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0];
+; CHECK-NEXT: cvt.sat.u8.u16 %rs2, %rs1;
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %v1 = call i16 @llvm.smax.i16(i16 %a, i16 0)
+ %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255)
+ ret i16 %v2
+}
+
+define i64 @trunc_ssat_i64_s8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0];
+; CHECK-NEXT: cvt.sat.s8.s64 %rs1, %rd1;
+; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -128)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 127)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0];
+; CHECK-NEXT: cvt.sat.s8.s32 %rs1, %r1;
+; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -128)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 127)
+ ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_s8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_s8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0];
+; CHECK-NEXT: cvt.sat.s8.s16 %rs2, %rs1;
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %v1 = call i16 @llvm.smax.i16(i16 %a, i16 -128)
+ %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 127)
+ ret i16 %v2
+}
+
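For reference, the PTX checked above can be reproduced locally with the same invocation as the test's RUN line, e.g.:

  llc -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs < llvm/test/CodeGen/NVPTX/trunc-sat.ll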
More information about the llvm-commits mailing list