[llvm] [NVPTX] MAD combine through CVT (PR #150477)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 24 10:27:09 PDT 2025


https://github.com/justinfargnoli created https://github.com/llvm/llvm-project/pull/150477

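Combine `(add (cvt (mul a, b)), c)` into `(mad.wide a, b, c)`: an i32 multiply whose result is extended to i64 (a `cvt` in PTX) and then added to an i64 value is selected as a single `mad.wide.s32` or `mad.wide.u32`.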

>From 1c1cd4a744cec9b9cb04a26162ddff1330507454 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 24 Jul 2025 17:19:53 +0000
Subject: [PATCH] [NVPTX] MAD combine through CVT

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  37 ++++++-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   |   2 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     |  20 ++++
 llvm/test/CodeGen/NVPTX/combine-ext-mad.ll  | 117 ++++++++++++++++++++
 4 files changed, 171 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/combine-ext-mad.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..fb9f4c844a1a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1101,6 +1101,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::SETP_BF16X2)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+    MAKE_CASE(NVPTXISD::MAD_WIDE_UNSIGNED)
+    MAKE_CASE(NVPTXISD::MAD_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::BrxEnd)
     MAKE_CASE(NVPTXISD::BrxItem)
     MAKE_CASE(NVPTXISD::BrxStart)
@@ -4885,6 +4887,48 @@ static bool isConstZero(const SDValue &Operand) {
   return Const && Const->getZExtValue() == 0;
 }
 
+/// Try to fold (add (ext (mul a, b)), c) -> (mad.wide a, b, c), where a and b
+/// are i32 and c and the add are i64.
+///
+/// mad.wide multiplies its 32-bit operands at full 64-bit precision, whereas
+/// (ext (mul a, b)) extends a product that has already been truncated to
+/// 32 bits. The two only agree when the i32 multiply cannot wrap:
+///   - sign_extend requires an nsw multiply and becomes mad.wide.s32;
+///   - zero_extend requires an nuw multiply and becomes mad.wide.u32;
+///   - any_extend leaves the high 32 bits unspecified, so the high half
+///     produced by mad.wide.u32 is acceptable without any wrap flag.
+static SDValue
+PerformMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                              TargetLowering::DAGCombinerInfo &DCI) {
+  assert(N->getOpcode() == ISD::ADD);
+  if (N->getValueType(0) != MVT::i64)
+    return SDValue();
+
+  const unsigned ExtOpc = N0.getOpcode();
+  if (ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND &&
+      ExtOpc != ISD::ANY_EXTEND)
+    return SDValue();
+
+  // Only fuse when the extension is the multiply's sole user; otherwise the
+  // i32 mul stays live for its other users and the product is computed twice.
+  SDValue Mul = N0.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL || Mul.getValueType() != MVT::i32 ||
+      !Mul.hasOneUse())
+    return SDValue();
+
+  const SDNodeFlags Flags = Mul->getFlags();
+  if (ExtOpc == ISD::SIGN_EXTEND && !Flags.hasNoSignedWrap())
+    return SDValue();
+  if (ExtOpc == ISD::ZERO_EXTEND && !Flags.hasNoUnsignedWrap())
+    return SDValue();
+
+  const unsigned Opcode = ExtOpc == ISD::SIGN_EXTEND
+                              ? NVPTXISD::MAD_WIDE_SIGNED
+                              : NVPTXISD::MAD_WIDE_UNSIGNED;
+  return DCI.DAG.getNode(Opcode, SDLoc(N), MVT::i64, Mul.getOperand(0),
+                         Mul.getOperand(1), N1);
+}
+
 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
 /// operands N0 and N1.  This is a helper for PerformADDCombine that is
 /// called with the default operands, and if that fails, with commuted
@@ -4905,6 +4949,9 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   //   -> (select cond, c, (add (mul a, b), c))
   //
   if (N0.getOpcode() == ISD::SELECT) {
+    // The select fold below only handles scalar i32 adds.
+    if (VT != MVT::i32)
+      return SDValue();
     unsigned ZeroOpNum;
     if (isConstZero(N0->getOperand(1)))
       ZeroOpNum = 1;
@@ -4926,6 +4973,9 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
                              ((ZeroOpNum == 1) ? MAD : N1));
   }
 
+  if (SDValue V = PerformMADCombineWithOperands(N, N0, N1, DCI))
+    return V;
+
   return SDValue();
 }
 
@@ -5274,11 +5324,6 @@ static SDValue PerformADDCombine(SDNode *N,
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
-  // Skip non-integer, non-scalar case
-  EVT VT = N0.getValueType();
-  if (VT.isVector() || VT != MVT::i32)
-    return SDValue();
-
   // First try with the default operand order.
   if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
     return Result;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index bc3548c0272bb..39c3787641ad0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -48,6 +48,8 @@ enum NodeType : unsigned {
   FSHR_CLAMP,
   MUL_WIDE_SIGNED,
   MUL_WIDE_UNSIGNED,
+  MAD_WIDE_UNSIGNED,
+  MAD_WIDE_SIGNED,
   SETP_F16X2,
   SETP_BF16X2,
   BFI,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a5bb83dfadb84..4ef650f6d7397 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -990,6 +990,26 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
 defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
 }
 
+def SDTMadWide : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0, 3>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
+def mad_wide_signed : SDNode<"NVPTXISD::MAD_WIDE_SIGNED", SDTMadWide>;
+def mad_wide_unsigned : SDNode<"NVPTXISD::MAD_WIDE_UNSIGNED", SDTMadWide>;
+
+multiclass MAD_WIDE<string PtxSuffix, SDNode Op, ValueType BigVT, NVPTXRegClass BigReg, ValueType SmallVT, NVPTXRegClass SmallReg, Operand SmallImm> {
+  def rrr:
+    BasicNVPTXInst<(outs BigReg:$dst),
+              (ins SmallReg:$a, SmallReg:$b, BigReg:$c),
+              "mad.wide." # PtxSuffix,
+              [(set BigVT:$dst, (Op SmallVT:$a, SmallVT:$b, BigVT:$c))]>;
+  def rir:
+    BasicNVPTXInst<(outs BigReg:$dst),
+              (ins SmallReg:$a, SmallImm:$b, BigReg:$c),
+              "mad.wide." # PtxSuffix,
+              [(set BigVT:$dst, (Op SmallVT:$a, imm:$b, BigVT:$c))]>;
+}
+
+defm MAD_WIDE_UNSIGNED_32 : MAD_WIDE<"u32", mad_wide_unsigned, i64, Int64Regs, i32, Int32Regs, i32imm>;
+defm MAD_WIDE_SIGNED_32 : MAD_WIDE<"s32", mad_wide_signed, i64, Int64Regs, i32, Int32Regs, i32imm>;
+
 foreach t = [I16RT, I32RT, I64RT] in {
   def NEG_S # t.Size :
     BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
diff --git a/llvm/test/CodeGen/NVPTX/combine-ext-mad.ll b/llvm/test/CodeGen/NVPTX/combine-ext-mad.ll
new file mode 100644
index 0000000000000..c6d656ef65725
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-ext-mad.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+; (add (ext (mul a, b)), c) may only become mad.wide when the i32 multiply
+; cannot wrap: sext requires an nsw multiply and zext an nuw multiply.
+
+define i64 @t1(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [t1_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [t1_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd1, [t1_param_2];
+; CHECK-NEXT:    mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %mul = mul nsw i32 %a, %b
+  %sext = sext i32 %mul to i64
+  %add = add i64 %c, %sext
+  ret i64 %add
+}
+
+define i64 @t2(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [t2_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [t2_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd1, [t2_param_2];
+; CHECK-NEXT:    mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %mul = mul nsw i32 %a, %b
+  %sext = sext i32 %mul to i64
+  %add = add i64 %sext, %c
+  ret i64 %add
+}
+
+define i64 @t3(i32 %a, i32 %b) {
+; CHECK-LABEL: t3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [t3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [t3_param_1];
+; CHECK-NEXT:    mov.b64 %rd1, 1;
+; CHECK-NEXT:    mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %mul = mul nsw i32 %a, %b
+  %sext = sext i32 %mul to i64
+  %add = add i64 1, %sext
+  ret i64 %add
+}
+
+define i64 @t4(i32 %a, i64 %c) {
+; CHECK-LABEL: t4(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [t4_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [t4_param_1];
+; CHECK-NEXT:    mad.wide.s32 %rd2, %r1, 3, %rd1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %mul = mul nsw i32 %a, 3
+  %sext = sext i32 %mul to i64
+  %add = add i64 %c, %sext
+  ret i64 %add
+}
+
+define i64 @t5(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t5(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [t5_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [t5_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd1, [t5_param_2];
+; CHECK-NEXT:    mad.wide.u32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %mul = mul nuw i32 %a, %b
+  %zext = zext i32 %mul to i64
+  %add = add i64 %c, %zext
+  ret i64 %add
+}
+
+define i64 @t6(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t6(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [t6_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [t6_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd1, [t6_param_2];
+; CHECK-NEXT:    mad.wide.u32 %rd2, %r1, %r2, %rd1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %mul = mul nuw i32 %a, %b
+  %zext = zext i32 %mul to i64
+  %add = add i64 %zext, %c
+  ret i64 %add
+}
+
+; Negative tests: without nsw/nuw the i32 multiply may wrap, so ext(mul a, b)
+; differs from the full 64-bit product that mad.wide computes.
+define i64 @t7(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t7(
+; CHECK-NOT:     mad.wide
+; CHECK:         ret;
+  %mul = mul i32 %a, %b
+  %sext = sext i32 %mul to i64
+  %add = add i64 %c, %sext
+  ret i64 %add
+}
+
+define i64 @t8(i32 %a, i32 %b, i64 %c) {
+; CHECK-LABEL: t8(
+; CHECK-NOT:     mad.wide
+; CHECK:         ret;
+  %mul = mul i32 %a, %b
+  %zext = zext i32 %mul to i64
+  %add = add i64 %c, %zext
+  ret i64 %add
+}



More information about the llvm-commits mailing list