[llvm] [AMDGPU][SDAG] Try folding "lshr i64 + mad" to "mad_[iu]64_[iu]32" (PR #119218)

Vikram Hegde via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 18 22:01:50 PST 2024


https://github.com/vikramRH updated https://github.com/llvm/llvm-project/pull/119218

>From e733a6da0bd650f7cca8868c01714c11fb32f85c Mon Sep 17 00:00:00 2001
From: vikhegde <vikram.hegde at amd.com>
Date: Mon, 9 Dec 2024 19:25:14 +0530
Subject: [PATCH 1/3] [AMDGPU][SDAG] Try folding "lshr i64 + mad" to
 "mad_[iu]64_[iu]32"

---
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 49 +++++++++++++
 llvm/test/CodeGen/AMDGPU/mad_64_32.ll     | 85 +++++++++++++++++++++++
 2 files changed, 134 insertions(+)

diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index f89fe8faa600ba..c18db9f65dc08e 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -13857,6 +13857,52 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
   return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
 }
 
+// Fold
+//     y = lshr i64 x, 32
+//     res = add (mul i64 y, Constant), x   where "Constant" is a 32 bit
+//     negative value
+// To
+//     res = mad_u64_u32 y.lo ,Constant.lo, x.lo
+static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
+                                 SDValue MulLHS, SDValue MulRHS,
+                                 SDValue AddRHS) {
+
+  if (MulLHS.getValueType() != MVT::i64)
+    return SDValue();
+
+  ConstantSDNode *ConstOp;
+  SDValue ShiftOp;
+  if (MulLHS.getOpcode() == ISD::SRL && MulRHS.getOpcode() == ISD::Constant) {
+    ConstOp = cast<ConstantSDNode>(MulRHS.getNode());
+    ShiftOp = MulLHS;
+  } else if (MulRHS.getOpcode() == ISD::SRL &&
+             MulLHS.getOpcode() == ISD::Constant) {
+    ConstOp = cast<ConstantSDNode>(MulLHS.getNode());
+    ShiftOp = MulRHS;
+  } else
+    return SDValue();
+
+  if (ShiftOp.getOperand(1).getOpcode() != ISD::Constant ||
+      AddRHS != ShiftOp.getOperand(0))
+    return SDValue();
+
+  if (cast<ConstantSDNode>(ShiftOp->getOperand(1))->getAsZExtVal() != 32)
+    return SDValue();
+
+  APInt ConstVal = ConstOp->getAPIntValue();
+  if (!ConstVal.isNegative() || !ConstVal.isSignedIntN(33))
+    return SDValue();
+
+  SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
+  SDValue ConstMul = DAG.getConstant(
+      ConstVal.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
+  AddRHS = DAG.getNode(ISD::AND, SL, MVT::i64, AddRHS,
+                       DAG.getConstant(0x00000000FFFFFFFF, SL, MVT::i64));
+  return getMad64_32(DAG, SL, MVT::i64,
+                     DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, MulLHS), ConstMul,
+                     AddRHS, false);
+}
+
 // Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z) plus high
 // multiplies, if any.
 //
@@ -13915,6 +13961,9 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
   SDValue MulRHS = LHS.getOperand(1);
   SDValue AddRHS = RHS;
 
+  if (SDValue FoldedMAD = tryFoldMADwithSRL(DAG, SL, MulLHS, MulRHS, AddRHS))
+    return FoldedMAD;
+
   // Always check whether operands are small unsigned values, since that
   // knowledge is useful in more cases. Check for small signed values only if
   // doing so can unlock a shorter code sequence.
diff --git a/llvm/test/CodeGen/AMDGPU/mad_64_32.ll b/llvm/test/CodeGen/AMDGPU/mad_64_32.ll
index 33007e5b285d80..d8f6eb266fc6ca 100644
--- a/llvm/test/CodeGen/AMDGPU/mad_64_32.ll
+++ b/llvm/test/CodeGen/AMDGPU/mad_64_32.ll
@@ -1333,5 +1333,91 @@ define i48 @mad_i48_i48(i48 %arg0, i48 %arg1, i48 %arg2) #0 {
   ret i48 %a
 }
 
+; (lshr x, 32) * C + x with C.hi == -1 folds to one v_mad_u64_u32 where
+; available: C = -4294967087 = 0xFFFFFFFF000000D1, so x.hi * 0xd1 + zext(x.lo).
+define i64 @lshr_mad_i64(ptr addrspace(1) %ptr) local_unnamed_addr #0 {
+; CI-LABEL: lshr_mad_i64:
+; CI:       ; %bb.0:
+; CI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CI-NEXT:    s_mov_b32 s6, 0
+; CI-NEXT:    s_mov_b32 s7, 0xf000
+; CI-NEXT:    s_mov_b32 s4, s6
+; CI-NEXT:    s_mov_b32 s5, s6
+; CI-NEXT:    buffer_load_dwordx2 v[0:1], v[0:1], s[4:7], 0 addr64
+; CI-NEXT:    v_mov_b32_e32 v3, 0
+; CI-NEXT:    s_movk_i32 s4, 0xd1
+; CI-NEXT:    s_waitcnt vmcnt(0)
+; CI-NEXT:    v_mov_b32_e32 v2, v0
+; CI-NEXT:    v_mad_u64_u32 v[0:1], s[4:5], v1, s4, v[2:3]
+; CI-NEXT:    s_setpc_b64 s[30:31]
+;
+; SI-LABEL: lshr_mad_i64:
+; SI:       ; %bb.0:
+; SI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; SI-NEXT:    s_mov_b32 s6, 0
+; SI-NEXT:    s_mov_b32 s7, 0xf000
+; SI-NEXT:    s_mov_b32 s4, s6
+; SI-NEXT:    s_mov_b32 s5, s6
+; SI-NEXT:    buffer_load_dwordx2 v[0:1], v[0:1], s[4:7], 0 addr64
+; SI-NEXT:    s_movk_i32 s4, 0xd1
+; SI-NEXT:    s_waitcnt vmcnt(0)
+; SI-NEXT:    v_mul_hi_u32 v2, v1, s4
+; SI-NEXT:    v_mul_lo_u32 v3, v1, s4
+; SI-NEXT:    v_sub_i32_e32 v2, vcc, v2, v1
+; SI-NEXT:    v_add_i32_e32 v0, vcc, v3, v0
+; SI-NEXT:    v_addc_u32_e32 v1, vcc, v2, v1, vcc
+; SI-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX9-LABEL: lshr_mad_i64:
+; GFX9:       ; %bb.0:
+; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX9-NEXT:    global_load_dwordx2 v[0:1], v[0:1], off
+; GFX9-NEXT:    v_mov_b32_e32 v3, 0
+; GFX9-NEXT:    s_movk_i32 s4, 0xd1
+; GFX9-NEXT:    s_waitcnt vmcnt(0)
+; GFX9-NEXT:    v_mov_b32_e32 v2, v0
+; GFX9-NEXT:    v_mad_u64_u32 v[0:1], s[4:5], v1, s4, v[2:3]
+; GFX9-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX1100-LABEL: lshr_mad_i64:
+; GFX1100:       ; %bb.0:
+; GFX1100-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX1100-NEXT:    global_load_b64 v[1:2], v[0:1], off
+; GFX1100-NEXT:    s_waitcnt vmcnt(0)
+; GFX1100-NEXT:    v_dual_mov_b32 v4, 0 :: v_dual_mov_b32 v3, v1
+; GFX1100-NEXT:    s_delay_alu instid0(VALU_DEP_1)
+; GFX1100-NEXT:    v_mad_u64_u32 v[0:1], null, 0xd1, v2, v[3:4]
+; GFX1100-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX1150-LABEL: lshr_mad_i64:
+; GFX1150:       ; %bb.0:
+; GFX1150-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX1150-NEXT:    global_load_b64 v[0:1], v[0:1], off
+; GFX1150-NEXT:    s_waitcnt vmcnt(0)
+; GFX1150-NEXT:    v_dual_mov_b32 v3, 0 :: v_dual_mov_b32 v2, v0
+; GFX1150-NEXT:    s_delay_alu instid0(VALU_DEP_1)
+; GFX1150-NEXT:    v_mad_u64_u32 v[0:1], null, 0xd1, v1, v[2:3]
+; GFX1150-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX12-LABEL: lshr_mad_i64:
+; GFX12:       ; %bb.0:
+; GFX12-NEXT:    s_wait_loadcnt_dscnt 0x0
+; GFX12-NEXT:    s_wait_expcnt 0x0
+; GFX12-NEXT:    s_wait_samplecnt 0x0
+; GFX12-NEXT:    s_wait_bvhcnt 0x0
+; GFX12-NEXT:    s_wait_kmcnt 0x0
+; GFX12-NEXT:    global_load_b64 v[0:1], v[0:1], off
+; GFX12-NEXT:    s_wait_loadcnt 0x0
+; GFX12-NEXT:    v_dual_mov_b32 v3, 0 :: v_dual_mov_b32 v2, v0
+; GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_1)
+; GFX12-NEXT:    v_mad_co_u64_u32 v[0:1], null, 0xd1, v1, v[2:3]
+; GFX12-NEXT:    s_setpc_b64 s[30:31]
+  %x = load i64, ptr addrspace(1) %ptr, align 8
+  %x.hi = lshr i64 %x, 32
+  %mul = mul nsw i64 %x.hi, -4294967087
+  %res = add nsw i64 %mul, %x
+  ret i64 %res
+}
+
 attributes #0 = { nounwind }
 attributes #1 = { nounwind readnone speculatable }

>From 6be9c4728c54e36439156366c9afdda6df4721f9 Mon Sep 17 00:00:00 2001
From: vikhegde <vikram.hegde at amd.com>
Date: Tue, 10 Dec 2024 08:23:57 +0000
Subject: [PATCH 2/3] [AMDGPU] Address review comments (canonicalize mul operands in caller)

---
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 45 ++++++++++-------------
 1 file changed, 19 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index c18db9f65dc08e..693fdfe934d8c3 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -13859,43 +13859,30 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
 
 // Fold
 //     y = lshr i64 x, 32
-//     res = add (mul i64 y, Constant), x   where "Constant" is a 32 bit
-//     negative value
+//     res = add (mul i64 y, Const), x   where "Const" is a 64-bit constant
+//     with Const.hi == -1
 // To
-//     res = mad_u64_u32 y.lo ,Constant.lo, x.lo
+//     res = mad_u64_u32 y.lo ,Const.lo, x.lo
 static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
                                  SDValue MulLHS, SDValue MulRHS,
                                  SDValue AddRHS) {
 
-  if (MulLHS.getValueType() != MVT::i64)
+  if (MulLHS.getValueType() != MVT::i64 || MulLHS.getOpcode() != ISD::SRL)
     return SDValue();
 
-  ConstantSDNode *ConstOp;
-  SDValue ShiftOp;
-  if (MulLHS.getOpcode() == ISD::SRL && MulRHS.getOpcode() == ISD::Constant) {
-    ConstOp = cast<ConstantSDNode>(MulRHS.getNode());
-    ShiftOp = MulLHS;
-  } else if (MulRHS.getOpcode() == ISD::SRL &&
-             MulLHS.getOpcode() == ISD::Constant) {
-    ConstOp = cast<ConstantSDNode>(MulLHS.getNode());
-    ShiftOp = MulRHS;
-  } else
-    return SDValue();
-
-  if (ShiftOp.getOperand(1).getOpcode() != ISD::Constant ||
-      AddRHS != ShiftOp.getOperand(0))
+  if (MulLHS.getOperand(1).getOpcode() != ISD::Constant ||
+      MulLHS.getOperand(0) != AddRHS)
     return SDValue();
 
-  if (cast<ConstantSDNode>(ShiftOp->getOperand(1))->getAsZExtVal() != 32)
+  if (cast<ConstantSDNode>(MulLHS->getOperand(1))->getAsZExtVal() != 32)
     return SDValue();
 
-  APInt ConstVal = ConstOp->getAPIntValue();
-  if (!ConstVal.isNegative() || !ConstVal.isSignedIntN(33))
+  APInt Const = cast<ConstantSDNode>(MulRHS.getNode())->getAPIntValue();
+  if (!Const.isNegative() || !Const.isSignedIntN(33))
     return SDValue();
 
-  SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
-  SDValue ConstMul = DAG.getConstant(
-      ConstVal.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
+  SDValue ConstMul =
+      DAG.getConstant(Const.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
   AddRHS = DAG.getNode(ISD::AND, SL, MVT::i64, AddRHS,
                        DAG.getConstant(0x00000000FFFFFFFF, SL, MVT::i64));
   return getMad64_32(DAG, SL, MVT::i64,
@@ -13961,8 +13948,14 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
   SDValue MulRHS = LHS.getOperand(1);
   SDValue AddRHS = RHS;
 
-  if (SDValue FoldedMAD = tryFoldMADwithSRL(DAG, SL, MulLHS, MulRHS, AddRHS))
-    return FoldedMAD;
+  if (MulLHS.getOpcode() == ISD::Constant ||
+      MulRHS.getOpcode() == ISD::Constant) {
+    if (MulRHS.getOpcode() == ISD::SRL)
+      std::swap(MulLHS, MulRHS);
+
+    if (SDValue FoldedMAD = tryFoldMADwithSRL(DAG, SL, MulLHS, MulRHS, AddRHS))
+      return FoldedMAD;
+  }
 
   // Always check whether operands are small unsigned values, since that
   // knowledge is useful in more cases. Check for small signed values only if

>From 2966e120cba0962bbcdb5ca7f5e1cccca0fc8c44 Mon Sep 17 00:00:00 2001
From: vikhegde <vikram.hegde at amd.com>
Date: Tue, 17 Dec 2024 12:24:13 +0530
Subject: [PATCH 3/3] [AMDGPU] Address review comments (use dyn_cast/isa, getZeroExtendInReg)

---
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 693fdfe934d8c3..287e101ac199fd 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -13862,32 +13862,40 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
 //     res = add (mul i64 y, Const), x   where "Const" is a 64-bit constant
 //     with Const.hi == -1
 // To
-//     res = mad_u64_u32 y.lo ,Const.lo, x.lo
+//     res = mad_u64_u32 y.lo, Const.lo, x.lo
 static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
                                  SDValue MulLHS, SDValue MulRHS,
                                  SDValue AddRHS) {
-
   if (MulLHS.getValueType() != MVT::i64 || MulLHS.getOpcode() != ISD::SRL)
     return SDValue();
 
-  if (MulLHS.getOperand(1).getOpcode() != ISD::Constant ||
-      MulLHS.getOperand(0) != AddRHS)
+  ConstantSDNode *ShiftVal = dyn_cast<ConstantSDNode>(MulLHS.getOperand(1));
+  if (!ShiftVal || MulLHS.getOperand(0) != AddRHS)
     return SDValue();
 
-  if (cast<ConstantSDNode>(MulLHS->getOperand(1))->getAsZExtVal() != 32)
+  if (ShiftVal->getAsZExtVal() != 32)
     return SDValue();
 
-  APInt Const = cast<ConstantSDNode>(MulRHS.getNode())->getAPIntValue();
+  // Don't rely on the caller's operand canonicalization: check that the
+  // multiplier really is a constant before reading its value.
+  auto *ConstNode = dyn_cast<ConstantSDNode>(MulRHS.getNode());
+  if (!ConstNode)
+    return SDValue();
+
+  // For a 64-bit Const, "negative and representable in 33 signed bits" is
+  // exactly Const.hi == -1, i.e. Const is in [-2^32, -1].
+  const APInt &Const = ConstNode->getAPIntValue();
   if (!Const.isNegative() || !Const.isSignedIntN(33))
     return SDValue();
 
+  // Const == Const.lo - 2^32 and x == (y << 32) + x.lo (mod 2^64), so
+  //   y * Const + x == y * Const.lo + x.lo,
+  // and y fits in 32 bits, which is exactly a 32x32->64 unsigned mad.
   SDValue ConstMul =
       DAG.getConstant(Const.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
-  AddRHS = DAG.getNode(ISD::AND, SL, MVT::i64, AddRHS,
-                       DAG.getConstant(0x00000000FFFFFFFF, SL, MVT::i64));
   return getMad64_32(DAG, SL, MVT::i64,
                      DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, MulLHS), ConstMul,
-                     AddRHS, false);
+                     DAG.getZeroExtendInReg(AddRHS, SL, MVT::i32), false);
 }
 
 // Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z) plus high
@@ -13948,8 +13956,7 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
   SDValue MulRHS = LHS.getOperand(1);
   SDValue AddRHS = RHS;
 
-  if (MulLHS.getOpcode() == ISD::Constant ||
-      MulRHS.getOpcode() == ISD::Constant) {
+  if (isa<ConstantSDNode>(MulLHS) || isa<ConstantSDNode>(MulRHS)) {
     if (MulRHS.getOpcode() == ISD::SRL)
       std::swap(MulLHS, MulRHS);
 



More information about the llvm-commits mailing list