[llvm] p ext codegen mul2 (PR #171593)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 10 02:30:00 PST 2025


https://github.com/4vtomat created https://github.com/llvm/llvm-project/pull/171593

- [llvm][RISCV] Support mulh for P extension codegen
- [llvm][RISCV] Support rounding mulh for P extension codegen


>From 8f43e66917c5eaa84c4174cfdf27c4e688359da9 Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Wed, 10 Dec 2025 15:42:26 +0800
Subject: [PATCH 1/2] [llvm][RISCV] Support mulh for P extension codegen

For mulh patterns whose operands are both signed or both unsigned, the
combine into mulhs/mulhu happens automatically. However, for mulh with
one signed and one unsigned operand we need to combine the pattern
manually, as we already do for PASUB*.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  74 +++++++-----
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td    |  23 +++-
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll     |  60 ++++++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll     | 119 ++++++++++++++++++++
 4 files changed, 240 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f28772a74d433..30c000814b59a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15210,18 +15210,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     break;
   }
   case RISCVISD::PASUB:
-  case RISCVISD::PASUBU: {
+  case RISCVISD::PASUBU:
+  case RISCVISD::PMULHSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
-    assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+    unsigned Opcode = N->getOpcode();
+    // PMULHSU doesn't support i8 variants
+    assert(VT == MVT::v2i16 ||
+           (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
     SDValue Undef = DAG.getUNDEF(VT);
     Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
     Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
-    Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+    Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
     return;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
@@ -16331,9 +16335,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
 }
 
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16346,7 +16350,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
       VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
     return SDValue();
 
-  // Check if shift amount is 1
+  // Check if shift amount is a splat constant
   SDValue ShAmt = N0.getOperand(1);
   if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
@@ -16360,44 +16364,54 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
   if (!C)
     return SDValue();
-  if (C->getZExtValue() != 1)
-    return SDValue();
 
-  // Check for SUB operation
-  SDValue Sub = N0.getOperand(0);
-  if (Sub.getOpcode() != ISD::SUB)
-    return SDValue();
+  SDValue Op = N0.getOperand(0);
+  unsigned ShAmtVal = C->getZExtValue();
 
-  SDValue LHS = Sub.getOperand(0);
-  SDValue RHS = Sub.getOperand(1);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
 
-  // Check if both operands are sign/zero extends from the target
-  // type
-  bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
-                   RHS.getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
-                   RHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
 
-  if (!IsSignExt && !IsZeroExt)
+  if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
     return SDValue();
 
   SDValue A = LHS.getOperand(0);
   SDValue B = RHS.getOperand(0);
 
-  // Check if the extends are from our target vector type
   if (A.getValueType() != VT || B.getValueType() != VT)
     return SDValue();
 
-  // Determine the instruction based on type and signedness
   unsigned Opc;
-  if (IsSignExt)
-    Opc = RISCVISD::PASUB;
-  else if (IsZeroExt)
-    Opc = RISCVISD::PASUBU;
-  else
+  switch (Op.getOpcode()) {
+  default:
     return SDValue();
+  case ISD::SUB:
+    // PASUB/PASUBU: shift amount must be 1
+    if (ShAmtVal != 1)
+      return SDValue();
+    if (LHSIsSExt && RHSIsSExt)
+      Opc = RISCVISD::PASUB;
+    else if (LHSIsZExt && RHSIsZExt)
+      Opc = RISCVISD::PASUBU;
+    else
+      return SDValue();
+    break;
+  case ISD::MUL:
+    // PMULHSU: shift amount must be element size, only for i16/i32
+    unsigned EltBits = VecVT.getScalarSizeInBits();
+    if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+      return SDValue();
+    if (LHSIsSExt && RHSIsZExt)
+      Opc = RISCVISD::PMULHSU;
+    else
+      return SDValue();
+    break;
+  }
 
-  // Create the machine node directly
   return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index bba9f961b9639..485f1d984d96f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1463,12 +1463,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
 
 def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
 
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
-                                          SDTCisInt<0>,
-                                          SDTCisSameAs<0, 1>,
-                                          SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+                                           SDTCisInt<0>,
+                                           SDTCisSameAs<0, 1>,
+                                           SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
 
 let Predicates = [HasStdExtP] in {
   def : PatGpr<abs, ABS>;
@@ -1513,6 +1514,11 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
   def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
 
+  // 16-bit multiply high patterns
+  def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit logical shift left patterns
   def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1580,6 +1586,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit multiply high patterns
+  def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+
   // 32-bit logical shift left
   def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
            (PSLL_WS GPR:$rs1, GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index cd59aa03597e2..87673d35a058a 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -746,3 +746,63 @@ define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <4 x i8> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed for v2i16
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high unsigned for v2i16
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high signed-unsigned for v2i16
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index c7fb891cdd996..9415fb0488bb7 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -841,3 +841,122 @@ define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high unsigned
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high signed-unsigned
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}

>From 5b5936aab556cb723335e64e4adcdf349cfcc0ed Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Wed, 10 Dec 2025 18:14:30 +0800
Subject: [PATCH 2/2] [llvm][RISCV] Support rounding mulh for P extension
 codegen

In the P extension spec, rounding is performed by adding 1 << (elt_bits - 1)
to the full-width product before taking its high half.

Stacked on: #171581
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  55 +++++++--
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td    |  13 +++
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll     |  63 ++++++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll     | 123 ++++++++++++++++++++
 4 files changed, 244 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 30c000814b59a..c75ac76415714 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15211,14 +15211,18 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
   case RISCVISD::PASUB:
   case RISCVISD::PASUBU:
-  case RISCVISD::PMULHSU: {
+  case RISCVISD::PMULHSU:
+  case RISCVISD::PMULHR:
+  case RISCVISD::PMULHRU:
+  case RISCVISD::PMULHRSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
     unsigned Opcode = N->getOpcode();
-    // PMULHSU doesn't support i8 variants
-    assert(VT == MVT::v2i16 ||
-           (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
+    // PMULH* variants don't support i8
+    bool IsMulH = Opcode == RISCVISD::PMULHSU || Opcode == RISCVISD::PMULHR ||
+                  Opcode == RISCVISD::PMULHRU || Opcode == RISCVISD::PMULHRSU;
+    assert(VT == MVT::v2i16 || (!IsMulH && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
@@ -16338,6 +16342,7 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
 // Handle P extension truncate patterns:
 // PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
 // PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
+// PMULHR*: (trunc (srl (add (mul ([s|z]ext a), ([s|z]ext b)), Rnd), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16367,6 +16372,32 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
 
   SDValue Op = N0.getOperand(0);
   unsigned ShAmtVal = C->getZExtValue();
+  unsigned EltBits = VecVT.getScalarSizeInBits();
+
+  // Check for rounding pattern: (add (mul ...), Rnd), where the P spec
+  // defines Rnd as a splat of 1 << (EltBits - 1).
+  bool IsRounding = false;
+  if (Op.getOpcode() == ISD::ADD && (EltBits == 16 || EltBits == 32)) {
+    SDValue AddRHS = Op.getOperand(1);
+    if (AddRHS.getOpcode() == ISD::BUILD_VECTOR) {
+      if (auto *RndBV = dyn_cast<BuildVectorSDNode>(AddRHS.getNode())) {
+        if (auto *RndC =
+                dyn_cast_or_null<ConstantSDNode>(RndBV->getSplatValue())) {
+          uint64_t ExpectedRnd = 1ULL << (EltBits - 1);
+          if (RndC->getZExtValue() == ExpectedRnd &&
+              Op.getOperand(0).getOpcode() == ISD::MUL) {
+            Op = Op.getOperand(0);
+            IsRounding = true;
+          }
+        }
+      }
+    }
+  }
+
+  // Only SUB (PASUB*) and MUL (PMULH*) are handled below. Bail out before
+  // reading operands: Op may be a node with fewer than two operands.
+  if (Op.getOpcode() != ISD::SUB && Op.getOpcode() != ISD::MUL)
+    return SDValue();
 
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
@@ -16401,14 +16432,24 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
       return SDValue();
     break;
   case ISD::MUL:
-    // PMULHSU: shift amount must be element size, only for i16/i32
-    unsigned EltBits = VecVT.getScalarSizeInBits();
+    // PMULH*/PMULHR*: shift amount must be element size, only for i16/i32
     if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
       return SDValue();
-    if (LHSIsSExt && RHSIsZExt)
-      Opc = RISCVISD::PMULHSU;
-    else
-      return SDValue();
+    if (IsRounding) {
+      if (LHSIsSExt && RHSIsSExt)
+        Opc = RISCVISD::PMULHR;
+      else if (LHSIsZExt && RHSIsZExt)
+        Opc = RISCVISD::PMULHRU;
+      else if (LHSIsSExt && RHSIsZExt)
+        Opc = RISCVISD::PMULHRSU;
+      else
+        return SDValue();
+    } else {
+      if (LHSIsSExt && RHSIsZExt)
+        Opc = RISCVISD::PMULHSU;
+      else
+        return SDValue();
+    }
     break;
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 485f1d984d96f..587ac8ee238f4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1470,6 +1470,9 @@ def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
 def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
 def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
 def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
+def riscv_pmulhr : RVSDNode<"PMULHR", SDT_RISCVPBinOp>;
+def riscv_pmulhru : RVSDNode<"PMULHRU", SDT_RISCVPBinOp>;
+def riscv_pmulhrsu : RVSDNode<"PMULHRSU", SDT_RISCVPBinOp>;
 
 let Predicates = [HasStdExtP] in {
   def : PatGpr<abs, ABS>;
@@ -1519,6 +1522,11 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
   def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
 
+  // 16-bit multiply high rounding patterns
+  def: Pat<(XLenVecI16VT (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit logical shift left patterns
   def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1591,6 +1599,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit multiply high rounding patterns
+  def: Pat<(v2i32 (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_W GPR:$rs1, GPR:$rs2)>;
+
   // 32-bit logical shift left
   def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
            (PSLL_WS GPR:$rs1, GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index 87673d35a058a..33127c3d140fa 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -806,3 +806,66 @@ define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   store <2 x i16> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high rounding signed for v2i16
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhr.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding unsigned for v2i16
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhru.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed-unsigned for v2i16
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 9415fb0488bb7..8a741f5821b70 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -960,3 +960,126 @@ define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high rounding signed
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhr.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhr_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhr.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding unsigned
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhru.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhru_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhru.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed-unsigned
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhrsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}



More information about the llvm-commits mailing list