[llvm] [llvm][RISCV] Support mulh for P extension codegen (PR #171581)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 10 00:55:22 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

<details>
<summary>Changes</summary>

For mulh patterns whose operands are both signed or both unsigned, the
combine is performed automatically. However, for mulh with one signed
operand and one unsigned operand, we need to combine them manually, the
same way we already do for PASUB*.

Note: This is the first patch for mulh and only handles basic high-part multiplication; follow-up patches will handle the rest of the mulh-related instructions.


---
Full diff: https://github.com/llvm/llvm-project/pull/171581.diff


4 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+44-30) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+17-6) 
- (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll (+60) 
- (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll (+119) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f28772a74d433..30c000814b59a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15210,18 +15210,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     break;
   }
   case RISCVISD::PASUB:
-  case RISCVISD::PASUBU: {
+  case RISCVISD::PASUBU:
+  case RISCVISD::PMULHSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
-    assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+    unsigned Opcode = N->getOpcode();
+    // PMULHSU doesn't support i8 variants
+    assert(VT == MVT::v2i16 ||
+           (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
     SDValue Undef = DAG.getUNDEF(VT);
     Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
     Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
-    Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+    Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
     return;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
@@ -16331,9 +16335,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
 }
 
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16346,7 +16350,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
       VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
     return SDValue();
 
-  // Check if shift amount is 1
+  // Check if shift amount is a splat constant
   SDValue ShAmt = N0.getOperand(1);
   if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
@@ -16360,44 +16364,54 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
   if (!C)
     return SDValue();
-  if (C->getZExtValue() != 1)
-    return SDValue();
 
-  // Check for SUB operation
-  SDValue Sub = N0.getOperand(0);
-  if (Sub.getOpcode() != ISD::SUB)
-    return SDValue();
+  SDValue Op = N0.getOperand(0);
+  unsigned ShAmtVal = C->getZExtValue();
 
-  SDValue LHS = Sub.getOperand(0);
-  SDValue RHS = Sub.getOperand(1);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
 
-  // Check if both operands are sign/zero extends from the target
-  // type
-  bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
-                   RHS.getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
-                   RHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
 
-  if (!IsSignExt && !IsZeroExt)
+  if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
     return SDValue();
 
   SDValue A = LHS.getOperand(0);
   SDValue B = RHS.getOperand(0);
 
-  // Check if the extends are from our target vector type
   if (A.getValueType() != VT || B.getValueType() != VT)
     return SDValue();
 
-  // Determine the instruction based on type and signedness
   unsigned Opc;
-  if (IsSignExt)
-    Opc = RISCVISD::PASUB;
-  else if (IsZeroExt)
-    Opc = RISCVISD::PASUBU;
-  else
+  switch (Op.getOpcode()) {
+  default:
     return SDValue();
+  case ISD::SUB:
+    // PASUB/PASUBU: shift amount must be 1
+    if (ShAmtVal != 1)
+      return SDValue();
+    if (LHSIsSExt && RHSIsSExt)
+      Opc = RISCVISD::PASUB;
+    else if (LHSIsZExt && RHSIsZExt)
+      Opc = RISCVISD::PASUBU;
+    else
+      return SDValue();
+    break;
+  case ISD::MUL:
+    // PMULHSU: shift amount must be element size, only for i16/i32
+    unsigned EltBits = VecVT.getScalarSizeInBits();
+    if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+      return SDValue();
+    if (LHSIsSExt && RHSIsZExt)
+      Opc = RISCVISD::PMULHSU;
+    else
+      return SDValue();
+    break;
+  }
 
-  // Create the machine node directly
   return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index bba9f961b9639..485f1d984d96f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1463,12 +1463,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
 
 def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
 
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
-                                          SDTCisInt<0>,
-                                          SDTCisSameAs<0, 1>,
-                                          SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+                                           SDTCisInt<0>,
+                                           SDTCisSameAs<0, 1>,
+                                           SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
 
 let Predicates = [HasStdExtP] in {
   def : PatGpr<abs, ABS>;
@@ -1513,6 +1514,11 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
   def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
 
+  // 16-bit multiply high patterns
+  def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit logical shift left patterns
   def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1580,6 +1586,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit multiply high patterns
+  def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+
   // 32-bit logical shift left
   def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
            (PSLL_WS GPR:$rs1, GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index cd59aa03597e2..87673d35a058a 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -746,3 +746,63 @@ define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <4 x i8> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed for v2i16
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high unsigned for v2i16
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high signed-unsigned for v2i16
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index c7fb891cdd996..9415fb0488bb7 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -841,3 +841,122 @@ define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high unsigned
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high signed-unsigned
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/171581


More information about the llvm-commits mailing list