[llvm] [llvm][RISCV] Support mulh for P extension codegen (PR #171581)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 10 19:55:52 PST 2025


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/171581

>From 8f43e66917c5eaa84c4174cfdf27c4e688359da9 Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Wed, 10 Dec 2025 15:42:26 +0800
Subject: [PATCH 1/2] [llvm][RISCV] Support mulh for P extension codegen

If both operands of the mulh pattern are sign-extended or both are
zero-extended, the generic DAG combine already forms MULHS/MULHU, so
only instruction-selection patterns are needed. However, when one
operand is sign-extended and the other is zero-extended, we have to
combine the pattern into PMULHSU manually, the same way we already do
for PASUB*.
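
For reference, a minimal LLVM IR sketch of the mixed-signedness pattern
the new combine recognizes (it mirrors the added test_pmulhsu_h test;
the function name is illustrative):

  define void @mulhsu_sketch(ptr %ret, ptr %pa, ptr %pb) {
    %a = load <2 x i16>, ptr %pa
    %b = load <2 x i16>, ptr %pb
    %sa = sext <2 x i16> %a to <2 x i32>          ; signed operand
    %ub = zext <2 x i16> %b to <2 x i32>          ; unsigned operand
    %mul = mul <2 x i32> %sa, %ub
    %hi = lshr <2 x i32> %mul, <i32 16, i32 16>   ; shift by element width
    %res = trunc <2 x i32> %hi to <2 x i16>       ; combined to PMULHSU
    store <2 x i16> %res, ptr %ret
    ret void
  }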
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  74 +++++++-----
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td    |  23 +++-
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll     |  60 ++++++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll     | 119 ++++++++++++++++++++
 4 files changed, 240 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f28772a74d433..30c000814b59a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15210,18 +15210,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     break;
   }
   case RISCVISD::PASUB:
-  case RISCVISD::PASUBU: {
+  case RISCVISD::PASUBU:
+  case RISCVISD::PMULHSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
-    assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+    unsigned Opcode = N->getOpcode();
+    // PMULHSU doesn't support i8 variants
+    assert(VT == MVT::v2i16 ||
+           (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
     SDValue Undef = DAG.getUNDEF(VT);
     Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
     Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
-    Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+    Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
     return;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
@@ -16331,9 +16335,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
 }
 
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16346,7 +16350,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
       VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
     return SDValue();
 
-  // Check if shift amount is 1
+  // Check if shift amount is a splat constant
   SDValue ShAmt = N0.getOperand(1);
   if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
@@ -16360,44 +16364,54 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
   if (!C)
     return SDValue();
-  if (C->getZExtValue() != 1)
-    return SDValue();
 
-  // Check for SUB operation
-  SDValue Sub = N0.getOperand(0);
-  if (Sub.getOpcode() != ISD::SUB)
-    return SDValue();
+  SDValue Op = N0.getOperand(0);
+  unsigned ShAmtVal = C->getZExtValue();
 
-  SDValue LHS = Sub.getOperand(0);
-  SDValue RHS = Sub.getOperand(1);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
 
-  // Check if both operands are sign/zero extends from the target
-  // type
-  bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
-                   RHS.getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
-                   RHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
 
-  if (!IsSignExt && !IsZeroExt)
+  if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
     return SDValue();
 
   SDValue A = LHS.getOperand(0);
   SDValue B = RHS.getOperand(0);
 
-  // Check if the extends are from our target vector type
   if (A.getValueType() != VT || B.getValueType() != VT)
     return SDValue();
 
-  // Determine the instruction based on type and signedness
   unsigned Opc;
-  if (IsSignExt)
-    Opc = RISCVISD::PASUB;
-  else if (IsZeroExt)
-    Opc = RISCVISD::PASUBU;
-  else
+  switch (Op.getOpcode()) {
+  default:
     return SDValue();
+  case ISD::SUB:
+    // PASUB/PASUBU: shift amount must be 1
+    if (ShAmtVal != 1)
+      return SDValue();
+    if (LHSIsSExt && RHSIsSExt)
+      Opc = RISCVISD::PASUB;
+    else if (LHSIsZExt && RHSIsZExt)
+      Opc = RISCVISD::PASUBU;
+    else
+      return SDValue();
+    break;
+  case ISD::MUL:
+    // PMULHSU: shift amount must be element size, only for i16/i32
+    unsigned EltBits = VecVT.getScalarSizeInBits();
+    if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+      return SDValue();
+    if (LHSIsSExt && RHSIsZExt)
+      Opc = RISCVISD::PMULHSU;
+    else
+      return SDValue();
+    break;
+  }
 
-  // Create the machine node directly
   return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index bba9f961b9639..485f1d984d96f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1463,12 +1463,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
 
 def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
 
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
-                                          SDTCisInt<0>,
-                                          SDTCisSameAs<0, 1>,
-                                          SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+                                           SDTCisInt<0>,
+                                           SDTCisSameAs<0, 1>,
+                                           SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
 
 let Predicates = [HasStdExtP] in {
   def : PatGpr<abs, ABS>;
@@ -1513,6 +1514,11 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
   def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
 
+  // 16-bit multiply high patterns
+  def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit logical shift left patterns
   def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1580,6 +1586,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit multiply high patterns
+  def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+
   // 32-bit logical shift left
   def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
            (PSLL_WS GPR:$rs1, GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index cd59aa03597e2..87673d35a058a 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -746,3 +746,63 @@ define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <4 x i8> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed for v2i16
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high unsigned for v2i16
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high signed-unsigned for v2i16
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index c7fb891cdd996..9415fb0488bb7 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -841,3 +841,122 @@ define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high unsigned
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high signed-unsigned
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}

>From 36c56205398f5be8a21df27ef63446ee244c83cc Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Wed, 10 Dec 2025 19:55:40 -0800
Subject: [PATCH 2/2] fixup! commuted case
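
The combine should also match the commuted form of the pattern, where
the zero-extended operand appears first in the mul, e.g. (using the
same v2i16 shape as the existing tests):

  %ua = zext <2 x i16> %a to <2 x i32>
  %sb = sext <2 x i16> %b to <2 x i32>
  %mul = mul <2 x i32> %ua, %sb

In that case the operands are swapped before emitting PMULHSU so that
the sign-extended value ends up as the first source operand.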

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  6 +++-
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll     | 19 ++++++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll     | 39 ++++++++++++++++++++-
 3 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 30c000814b59a..a5dffdf7ce53d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16405,8 +16405,12 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
     unsigned EltBits = VecVT.getScalarSizeInBits();
     if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
       return SDValue();
-    if (LHSIsSExt && RHSIsZExt)
+    if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
       Opc = RISCVISD::PMULHSU;
+      // Commuted case: swap so the sign-extended operand comes first.
+      if (LHSIsZExt && RHSIsSExt)
+        std::swap(A, B);
+    }
     else
       return SDValue();
     break;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index 87673d35a058a..ec52642b19e7e 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -806,3 +806,22 @@ define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   store <2 x i16> %res, ptr %ret_ptr
   ret void
 }
+
+define void @test_pmulhsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a2, a1
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 9415fb0488bb7..3663d47ad8e15 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -921,7 +921,6 @@ define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   ret void
 }
 
-
 ; Test packed multiply high signed-unsigned
 define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
 ; CHECK-LABEL: test_pmulhsu_h:
@@ -942,6 +941,25 @@ define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   ret void
 }
 
+define void @test_pmulhsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a2, a1
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
 define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
 ; CHECK-LABEL: test_pmulhsu_w:
 ; CHECK:       # %bb.0:
@@ -960,3 +978,22 @@ define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+define void @test_pmulhsu_w_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.w a1, a2, a1
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}


