[llvm] [NVPTX] MAD combine through CVT (PR #150477)
Justin Fargnoli via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 24 10:27:09 PDT 2025
https://github.com/justinfargnoli created https://github.com/llvm/llvm-project/pull/150477
`(add (sext/zext (mul a, b)), c) -> (mad.wide.{s32,u32} a, b, c)`
>From 1c1cd4a744cec9b9cb04a26162ddff1330507454 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 24 Jul 2025 17:19:53 +0000
Subject: [PATCH] Initial commit
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 37 ++++++-
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 2 +
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 20 ++++
llvm/test/CodeGen/NVPTX/combine-ext-mad.ll | 117 ++++++++++++++++++++
4 files changed, 171 insertions(+), 5 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/combine-ext-mad.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..fb9f4c844a1a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1101,6 +1101,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::SETP_BF16X2)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+  MAKE_CASE(NVPTXISD::MAD_WIDE_SIGNED)
+  MAKE_CASE(NVPTXISD::MAD_WIDE_UNSIGNED)
MAKE_CASE(NVPTXISD::BrxEnd)
MAKE_CASE(NVPTXISD::BrxItem)
MAKE_CASE(NVPTXISD::BrxStart)
@@ -4885,6 +4887,51 @@ static bool isConstZero(const SDValue &Operand) {
return Const && Const->getZExtValue() == 0;
}
+/// PerformMADCombineWithOperands - Try to fold an i64 ADD of an extended i32
+/// multiply (N0) and another operand (N1) into a widening multiply-add:
+///   (add (sext (mul nsw a, b)), c) -> (mad.wide.s32 a, b, c)
+///   (add (zext (mul nuw a, b)), c) -> (mad.wide.u32 a, b, c)
+///   (add (anyext (mul a, b)), c)   -> (mad.wide.u32 a, b, c)
+/// mad.wide adds c to the full 64-bit product of a and b, whereas
+/// (ext (mul a, b)) extends the product truncated to 32 bits. The two agree
+/// only when the i32 multiply cannot wrap in the signedness of the extension,
+/// hence the nsw/nuw requirement. ANY_EXTEND leaves the high bits undefined,
+/// so any high bits, including those of the full product, are acceptable.
+static SDValue
+PerformMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                              TargetLowering::DAGCombinerInfo &DCI) {
+  assert(N->getOpcode() == ISD::ADD);
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i64)
+    return SDValue();
+
+  unsigned ExtOpc = N0.getOpcode();
+  if (ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND &&
+      ExtOpc != ISD::ANY_EXTEND)
+    return SDValue();
+
+  SDValue Mul = N0.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL || Mul.getValueType() != MVT::i32)
+    return SDValue();
+
+  // Only fold when the extend and the narrow multiply have no other users;
+  // otherwise the multiply would be computed twice.
+  if (!N0.hasOneUse() || !Mul.hasOneUse())
+    return SDValue();
+
+  // Reject products that may wrap in the signedness the extension assumes.
+  SDNodeFlags Flags = Mul->getFlags();
+  if (ExtOpc == ISD::SIGN_EXTEND && !Flags.hasNoSignedWrap())
+    return SDValue();
+  if (ExtOpc == ISD::ZERO_EXTEND && !Flags.hasNoUnsignedWrap())
+    return SDValue();
+
+  unsigned Opcode = ExtOpc == ISD::SIGN_EXTEND ? NVPTXISD::MAD_WIDE_SIGNED
+                                               : NVPTXISD::MAD_WIDE_UNSIGNED;
+  return DCI.DAG.getNode(Opcode, SDLoc(N), VT, Mul.getOperand(0),
+                         Mul.getOperand(1), N1);
+}
+
/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
/// operands N0 and N1. This is a helper for PerformADDCombine that is
/// called with the default operands, and if that fails, with commuted
@@ -4905,6 +4952,9 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
// -> (select cond, c, (add (mul a, b), c))
//
if (N0.getOpcode() == ISD::SELECT) {
+    // The select fold below only handles scalar i32 adds.
+    if (VT != MVT::i32)
+      return SDValue();
unsigned ZeroOpNum;
if (isConstZero(N0->getOperand(1)))
ZeroOpNum = 1;
@@ -4926,6 +4976,9 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
((ZeroOpNum == 1) ? MAD : N1));
}
+  if (SDValue V = PerformMADCombineWithOperands(N, N0, N1, DCI))
+    return V;
+
return SDValue();
}
@@ -5274,11 +5327,6 @@ static SDValue PerformADDCombine(SDNode *N,
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
- // Skip non-integer, non-scalar case
- EVT VT = N0.getValueType();
- if (VT.isVector() || VT != MVT::i32)
- return SDValue();
-
// First try with the default operand order.
if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
return Result;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index bc3548c0272bb..39c3787641ad0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -48,6 +48,8 @@ enum NodeType : unsigned {
FSHR_CLAMP,
MUL_WIDE_SIGNED,
MUL_WIDE_UNSIGNED,
+  MAD_WIDE_SIGNED,
+  MAD_WIDE_UNSIGNED,
SETP_F16X2,
SETP_BF16X2,
BFI,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a5bb83dfadb84..4ef650f6d7397 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -990,6 +990,27 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
}
+// mad.wide.{s32,u32} d, a, b, c: d = c + (full 64-bit product of 32-bit a, b).
+def SDTMadWide : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0, 3>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
+def mad_wide_signed : SDNode<"NVPTXISD::MAD_WIDE_SIGNED", SDTMadWide>;
+def mad_wide_unsigned : SDNode<"NVPTXISD::MAD_WIDE_UNSIGNED", SDTMadWide>;
+
+multiclass MAD_WIDE<string PtxSuffix, SDNode Op, ValueType BigVT, NVPTXRegClass BigReg, ValueType SmallVT, NVPTXRegClass SmallReg, Operand SmallImm> {
+  def rrr:
+    BasicNVPTXInst<(outs BigReg:$dst),
+                   (ins SmallReg:$a, SmallReg:$b, BigReg:$c),
+                   "mad.wide." # PtxSuffix,
+                   [(set BigVT:$dst, (Op SmallVT:$a, SmallVT:$b, BigVT:$c))]>;
+  def rir:
+    BasicNVPTXInst<(outs BigReg:$dst),
+                   (ins SmallReg:$a, SmallImm:$b, BigReg:$c),
+                   "mad.wide." # PtxSuffix,
+                   [(set BigVT:$dst, (Op SmallVT:$a, imm:$b, BigVT:$c))]>;
+}
+
+defm MAD_WIDE_UNSIGNED_32 : MAD_WIDE<"u32", mad_wide_unsigned, i64, B64, i32, B32, i32imm>;
+defm MAD_WIDE_SIGNED_32 : MAD_WIDE<"s32", mad_wide_signed, i64, B64, i32, B32, i32imm>;
+
foreach t = [I16RT, I32RT, I64RT] in {
def NEG_S # t.Size :
BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
diff --git a/llvm/test/CodeGen/NVPTX/combine-ext-mad.ll b/llvm/test/CodeGen/NVPTX/combine-ext-mad.ll
new file mode 100644
index 0000000000000..c6d656ef65725
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-ext-mad.ll
@@ -0,0 +1,145 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i64 @t1(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [t1_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [t1_param_1];
+; CHECK-NEXT: ld.param.b64 %rd1, [t1_param_2];
+; CHECK-NEXT: mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %mul = mul nsw i32 %a, %b
+ %sext = sext i32 %mul to i64
+ %add = add i64 %c, %sext
+ ret i64 %add
+}
+
+define i64 @t2(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [t2_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [t2_param_1];
+; CHECK-NEXT: ld.param.b64 %rd1, [t2_param_2];
+; CHECK-NEXT: mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %mul = mul nsw i32 %a, %b
+ %sext = sext i32 %mul to i64
+ %add = add i64 %sext, %c
+ ret i64 %add
+}
+
+define i64 @t3(i32 %a, i32 %b) {
+; CHECK-LABEL: t3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [t3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [t3_param_1];
+; CHECK-NEXT: mov.b64 %rd1, 1;
+; CHECK-NEXT: mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %mul = mul nsw i32 %a, %b
+ %sext = sext i32 %mul to i64
+ %add = add i64 1, %sext
+ ret i64 %add
+}
+
+define i64 @t4(i32 %a, i64 %c) {
+; CHECK-LABEL: t4(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [t4_param_0];
+; CHECK-NEXT: ld.param.b64 %rd1, [t4_param_1];
+; CHECK-NEXT: mad.wide.s32 %rd2, %r1, 3, %rd1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %mul = mul nsw i32 %a, 3
+ %sext = sext i32 %mul to i64
+ %add = add i64 %c, %sext
+ ret i64 %add
+}
+
+define i64 @t5(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t5(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [t5_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [t5_param_1];
+; CHECK-NEXT: ld.param.b64 %rd1, [t5_param_2];
+; CHECK-NEXT: mad.wide.u32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %mul = mul nuw i32 %a, %b
+ %zext = zext i32 %mul to i64
+ %add = add i64 %c, %zext
+ ret i64 %add
+}
+
+define i64 @t6(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t6(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [t6_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [t6_param_1];
+; CHECK-NEXT: ld.param.b64 %rd1, [t6_param_2];
+; CHECK-NEXT: mad.wide.u32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %mul = mul nuw i32 %a, %b
+ %zext = zext i32 %mul to i64
+ %add = add i64 %c, %zext
+ ret i64 %add
+}
+
+; The fold must not fire when the i32 multiply may wrap: mad.wide adds the
+; full 64-bit product, not the extension of the truncated 32-bit product.
+define i64 @t7(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t7(
+; CHECK-NOT: mad.wide
+; CHECK: mul.lo.s32
+; CHECK: cvt.s64.s32
+; CHECK: add.s64
+; CHECK-NOT: mad.wide
+ %mul = mul i32 %a, %b
+ %sext = sext i32 %mul to i64
+ %add = add i64 %c, %sext
+ ret i64 %add
+}
+
+define i64 @t8(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t8(
+; CHECK-NOT: mad.wide
+; CHECK: mul.lo.s32
+; CHECK: cvt.u64.u32
+; CHECK: add.s64
+; CHECK-NOT: mad.wide
+ %mul = mul i32 %a, %b
+ %zext = zext i32 %mul to i64
+ %add = add i64 %c, %zext
+ ret i64 %add
+}
More information about the llvm-commits
mailing list