[llvm] [DAG] Constant Folding for U/SMUL_LOHI (PR #69437)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 18 02:57:13 PDT 2023


https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/69437

>From 55903f3580c9d337a308129332b3e548753cacaa Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Wed, 18 Oct 2023 11:05:18 +0200
Subject: [PATCH 1/4] [DAG] Constant Folding for U/SMUL_LOHI

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  8 +++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 21 ++++++++
 llvm/test/CodeGen/AMDGPU/udiv.ll              | 49 ++++++++++---------
 3 files changed, 54 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2dfdddad3cc389f..73950294200e46a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5346,6 +5346,10 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
+  // Constant fold.
+  if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
+    return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N0, N1);
+
   // canonicalize constant to RHS (vector doesn't have to splat)
   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
@@ -5384,6 +5388,10 @@ SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
+  // Constant fold.
+  if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
+    return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N0, N1);
+
   // canonicalize constant to RHS (vector doesn't have to splat)
   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 8c275bfcfbd2796..22798bbcd461cf2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9879,6 +9879,27 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
            VTList.VTs[0] == Ops[0].getValueType() &&
            VTList.VTs[0] == Ops[1].getValueType() &&
            "Binary operator types must match!");
+    // Constant fold.
+    ConstantSDNode* LHS = dyn_cast<ConstantSDNode>(Ops[0]);
+    ConstantSDNode* RHS = dyn_cast<ConstantSDNode>(Ops[1]);
+    if(LHS && RHS) {
+      unsigned Width = VTList.VTs[0].getScalarSizeInBits();
+      unsigned OutWidth = Width * 2;
+      APInt Val = LHS->getAPIntValue();
+      APInt Mul = RHS->getAPIntValue();
+      if(Opcode == ISD::SMUL_LOHI) {
+        Val = Val.sext(OutWidth);
+        Mul = Mul.sext(OutWidth);
+      } else {
+        Val = Val.zext(OutWidth);
+        Mul = Mul.zext(OutWidth);
+      }
+      Val *= Mul;
+
+      SDValue Hi = getConstant(Val.getHiBits(Width).trunc(Width), DL, VTList.VTs[0]);
+      SDValue Lo = getConstant(Val.trunc(Width), DL, VTList.VTs[0]);
+      return getNode(ISD::MERGE_VALUES, DL, VTList, {Lo, Hi}, Flags);
+    }
     break;
   }
   case ISD::FFREXP: {
diff --git a/llvm/test/CodeGen/AMDGPU/udiv.ll b/llvm/test/CodeGen/AMDGPU/udiv.ll
index e554f912ff64886..5979cce76322802 100644
--- a/llvm/test/CodeGen/AMDGPU/udiv.ll
+++ b/llvm/test/CodeGen/AMDGPU/udiv.ll
@@ -2619,37 +2619,38 @@ define i64 @v_test_udiv64_mulhi_fold(i64 %arg) {
 ; VI-LABEL: v_test_udiv64_mulhi_fold:
 ; VI:       ; %bb.0:
 ; VI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; VI-NEXT:    v_mov_b32_e32 v4, 0xa7c5
-; VI-NEXT:    v_mul_u32_u24_e32 v3, 0x500, v4
-; VI-NEXT:    v_mul_hi_u32_u24_e32 v2, 0x500, v4
-; VI-NEXT:    v_add_u32_e32 v3, vcc, 0x4237, v3
-; VI-NEXT:    v_addc_u32_e32 v5, vcc, 0, v2, vcc
-; VI-NEXT:    v_add_u32_e32 v6, vcc, 0xa9000000, v3
-; VI-NEXT:    s_mov_b32 s6, 0xfffe7960
-; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v6, s6, 0
-; VI-NEXT:    v_addc_u32_e32 v7, vcc, v5, v4, vcc
-; VI-NEXT:    v_mul_lo_u32 v4, v7, s6
+; VI-NEXT:    s_mov_b32 s4, 0x346d900
+; VI-NEXT:    s_add_u32 s4, 0x4237, s4
+; VI-NEXT:    v_mov_b32_e32 v2, 0xa9000000
+; VI-NEXT:    v_add_u32_e32 v6, vcc, s4, v2
+; VI-NEXT:    s_mov_b32 s4, 0xfffe7960
+; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v6, s4, 0
+; VI-NEXT:    s_addc_u32 s6, 0, 0
+; VI-NEXT:    s_cmp_lg_u64 vcc, 0
+; VI-NEXT:    s_addc_u32 s6, s6, 0xa7c5
+; VI-NEXT:    s_mul_i32 s4, s6, 0xfffe7960
 ; VI-NEXT:    v_sub_u32_e32 v3, vcc, v3, v6
-; VI-NEXT:    v_mul_hi_u32 v8, v6, v2
-; VI-NEXT:    v_add_u32_e32 v5, vcc, v4, v3
+; VI-NEXT:    v_add_u32_e32 v5, vcc, s4, v3
 ; VI-NEXT:    v_mad_u64_u32 v[3:4], s[4:5], v6, v5, 0
-; VI-NEXT:    v_add_u32_e32 v8, vcc, v8, v3
-; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v7, v2, 0
-; VI-NEXT:    v_addc_u32_e32 v9, vcc, 0, v4, vcc
-; VI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v7, v5, 0
-; VI-NEXT:    v_add_u32_e32 v2, vcc, v8, v2
-; VI-NEXT:    v_addc_u32_e32 v2, vcc, v9, v3, vcc
+; VI-NEXT:    v_mul_hi_u32 v7, v6, v2
+; VI-NEXT:    v_add_u32_e32 v7, vcc, v7, v3
+; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], s6, v2, 0
+; VI-NEXT:    v_addc_u32_e32 v8, vcc, 0, v4, vcc
+; VI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], s6, v5, 0
+; VI-NEXT:    v_add_u32_e32 v2, vcc, v7, v2
+; VI-NEXT:    v_addc_u32_e32 v2, vcc, v8, v3, vcc
 ; VI-NEXT:    v_addc_u32_e32 v3, vcc, 0, v5, vcc
 ; VI-NEXT:    v_add_u32_e32 v2, vcc, v2, v4
 ; VI-NEXT:    v_addc_u32_e32 v3, vcc, 0, v3, vcc
-; VI-NEXT:    v_add_u32_e32 v4, vcc, v6, v2
-; VI-NEXT:    v_addc_u32_e32 v5, vcc, v7, v3, vcc
-; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v0, v5, 0
-; VI-NEXT:    v_mul_hi_u32 v6, v0, v4
+; VI-NEXT:    v_mov_b32_e32 v4, s6
+; VI-NEXT:    v_add_u32_e32 v5, vcc, v6, v2
+; VI-NEXT:    v_addc_u32_e32 v4, vcc, v4, v3, vcc
+; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v0, v4, 0
+; VI-NEXT:    v_mul_hi_u32 v6, v0, v5
 ; VI-NEXT:    v_add_u32_e32 v6, vcc, v6, v2
 ; VI-NEXT:    v_addc_u32_e32 v7, vcc, 0, v3, vcc
-; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v1, v4, 0
-; VI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v1, v5, 0
+; VI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v1, v5, 0
+; VI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v1, v4, 0
 ; VI-NEXT:    v_add_u32_e32 v2, vcc, v6, v2
 ; VI-NEXT:    v_addc_u32_e32 v2, vcc, v7, v3, vcc
 ; VI-NEXT:    v_addc_u32_e32 v3, vcc, 0, v5, vcc

>From 93a1f38d9d9312874a58c0ba13fba00eb4baa65d Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Wed, 18 Oct 2023 11:07:50 +0200
Subject: [PATCH 2/4] clang-format

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 22798bbcd461cf2..097bb8b968bb029 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9880,14 +9880,14 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
            VTList.VTs[0] == Ops[1].getValueType() &&
            "Binary operator types must match!");
     // Constant fold.
-    ConstantSDNode* LHS = dyn_cast<ConstantSDNode>(Ops[0]);
-    ConstantSDNode* RHS = dyn_cast<ConstantSDNode>(Ops[1]);
-    if(LHS && RHS) {
+    ConstantSDNode *LHS = dyn_cast<ConstantSDNode>(Ops[0]);
+    ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Ops[1]);
+    if (LHS && RHS) {
       unsigned Width = VTList.VTs[0].getScalarSizeInBits();
       unsigned OutWidth = Width * 2;
       APInt Val = LHS->getAPIntValue();
       APInt Mul = RHS->getAPIntValue();
-      if(Opcode == ISD::SMUL_LOHI) {
+      if (Opcode == ISD::SMUL_LOHI) {
         Val = Val.sext(OutWidth);
         Mul = Mul.sext(OutWidth);
       } else {
@@ -9896,7 +9896,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
       }
       Val *= Mul;
 
-      SDValue Hi = getConstant(Val.getHiBits(Width).trunc(Width), DL, VTList.VTs[0]);
+      SDValue Hi =
+          getConstant(Val.getHiBits(Width).trunc(Width), DL, VTList.VTs[0]);
       SDValue Lo = getConstant(Val.trunc(Width), DL, VTList.VTs[0]);
       return getNode(ISD::MERGE_VALUES, DL, VTList, {Lo, Hi}, Flags);
     }

>From 46af3d39799da27adc9f766f8b60bc7495d9e7fa Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Wed, 18 Oct 2023 11:39:03 +0200
Subject: [PATCH 3/4] Use extractBits

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 097bb8b968bb029..50679acd07cb97c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9896,8 +9896,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
       }
       Val *= Mul;
 
-      SDValue Hi =
-          getConstant(Val.getHiBits(Width).trunc(Width), DL, VTList.VTs[0]);
+      SDValue Hi = getConstant(Val.extractBits(32, 32), DL, VTList.VTs[0]);
       SDValue Lo = getConstant(Val.trunc(Width), DL, VTList.VTs[0]);
       return getNode(ISD::MERGE_VALUES, DL, VTList, {Lo, Hi}, Flags);
     }

>From 39d54ea780df9c98f738d84a7189f20add81a59d Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Wed, 18 Oct 2023 11:56:59 +0200
Subject: [PATCH 4/4] Generalize extractBits args

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 50679acd07cb97c..fcdf6a55ad2c5a1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9896,7 +9896,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
       }
       Val *= Mul;
 
-      SDValue Hi = getConstant(Val.extractBits(32, 32), DL, VTList.VTs[0]);
+      SDValue Hi = getConstant(Val.extractBits(Width, Width), DL, VTList.VTs[0]);
       SDValue Lo = getConstant(Val.trunc(Width), DL, VTList.VTs[0]);
       return getNode(ISD::MERGE_VALUES, DL, VTList, {Lo, Hi}, Flags);
     }



More information about the llvm-commits mailing list