[llvm] p ext codegen mul2 (PR #171593)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 10 02:30:33 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

<details>
<summary>Changes</summary>

- [llvm][RISCV] Support mulh for P extension codegen
- [llvm][RISCV] Support rounding mulh for P extension codegen


---

Patch is 22.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171593.diff


4 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+79-30) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+30-6) 
- (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll (+123) 
- (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll (+242) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f28772a74d433..c75ac76415714 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15210,18 +15210,26 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     break;
   }
   case RISCVISD::PASUB:
-  case RISCVISD::PASUBU: {
+  case RISCVISD::PASUBU:
+  case RISCVISD::PMULHSU:
+  case RISCVISD::PMULHR:
+  case RISCVISD::PMULHRU:
+  case RISCVISD::PMULHRSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
-    assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+    unsigned Opcode = N->getOpcode();
+    // PMULH* variants don't support i8, so they only arrive here as v2i16.
+    bool IsMulH = Opcode == RISCVISD::PMULHSU || Opcode == RISCVISD::PMULHR ||
+                  Opcode == RISCVISD::PMULHRU || Opcode == RISCVISD::PMULHRSU;
+    assert(VT == MVT::v2i16 || (!IsMulH && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
     SDValue Undef = DAG.getUNDEF(VT);
     Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
     Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
-    Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+    Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
     return;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
@@ -16331,9 +16339,10 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
 }
 
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
+// PMULHR*: (trunc (srl (add (mul ([s|z]ext a), ([s|z]ext b)), rnd), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16346,7 +16355,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
       VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
     return SDValue();
 
-  // Check if shift amount is 1
+  // Check if shift amount is a splat constant
   SDValue ShAmt = N0.getOperand(1);
   if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
@@ -16360,44 +16369,84 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
   if (!C)
     return SDValue();
-  if (C->getZExtValue() != 1)
-    return SDValue();
 
-  // Check for SUB operation
-  SDValue Sub = N0.getOperand(0);
-  if (Sub.getOpcode() != ISD::SUB)
-    return SDValue();
+  SDValue Op = N0.getOperand(0);
+  unsigned ShAmtVal = C->getZExtValue();
+  unsigned EltBits = VecVT.getScalarSizeInBits();
+
+  // Check for rounding pattern: (add (mul ...), round_const)
+  bool IsRounding = false;
+  if (Op.getOpcode() == ISD::ADD && (EltBits == 16 || EltBits == 32)) {
+    SDValue AddRHS = Op.getOperand(1);
+    if (AddRHS.getOpcode() == ISD::BUILD_VECTOR) {
+      if (auto *RndBV = dyn_cast<BuildVectorSDNode>(AddRHS.getNode())) {
+        if (auto *RndC =
+                dyn_cast_or_null<ConstantSDNode>(RndBV->getSplatValue())) {
+          uint64_t ExpectedRnd = 1ULL << (EltBits - 1);
+          if (RndC->getZExtValue() == ExpectedRnd &&
+              Op.getOperand(0).getOpcode() == ISD::MUL) {
+            Op = Op.getOperand(0);
+            IsRounding = true;
+          }
+        }
+      }
+    }
+  }
 
-  SDValue LHS = Sub.getOperand(0);
-  SDValue RHS = Sub.getOperand(1);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
 
-  // Check if both operands are sign/zero extends from the target
-  // type
-  bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
-                   RHS.getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
-                   RHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+  bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+  bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
 
-  if (!IsSignExt && !IsZeroExt)
+  if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
     return SDValue();
 
   SDValue A = LHS.getOperand(0);
   SDValue B = RHS.getOperand(0);
 
-  // Check if the extends are from our target vector type
   if (A.getValueType() != VT || B.getValueType() != VT)
     return SDValue();
 
-  // Determine the instruction based on type and signedness
   unsigned Opc;
-  if (IsSignExt)
-    Opc = RISCVISD::PASUB;
-  else if (IsZeroExt)
-    Opc = RISCVISD::PASUBU;
-  else
+  switch (Op.getOpcode()) {
+  default:
     return SDValue();
+  case ISD::SUB:
+    // PASUB/PASUBU: shift amount must be 1
+    if (ShAmtVal != 1)
+      return SDValue();
+    if (LHSIsSExt && RHSIsSExt)
+      Opc = RISCVISD::PASUB;
+    else if (LHSIsZExt && RHSIsZExt)
+      Opc = RISCVISD::PASUBU;
+    else
+      return SDValue();
+    break;
+  case ISD::MUL:
+    // PMULH*/PMULHR*: shift amount must be element size, only for i16/i32
+    if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+      return SDValue();
+    if (IsRounding) {
+      if (LHSIsSExt && RHSIsSExt)
+        Opc = RISCVISD::PMULHR;
+      else if (LHSIsZExt && RHSIsZExt)
+        Opc = RISCVISD::PMULHRU;
+      else if (LHSIsSExt && RHSIsZExt)
+        Opc = RISCVISD::PMULHRSU;
+      else
+        return SDValue();
+    } else {
+      if (LHSIsSExt && RHSIsZExt)
+        Opc = RISCVISD::PMULHSU;
+      else
+        return SDValue();
+    }
+    break;
+  }
 
-  // Create the machine node directly
   return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index bba9f961b9639..587ac8ee238f4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1463,12 +1463,16 @@ let Predicates = [HasStdExtP, IsRV32] in {
 
 def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
 
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
-                                          SDTCisInt<0>,
-                                          SDTCisSameAs<0, 1>,
-                                          SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+                                           SDTCisInt<0>,
+                                           SDTCisSameAs<0, 1>,
+                                           SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
+def riscv_pmulhr : RVSDNode<"PMULHR", SDT_RISCVPBinOp>;
+def riscv_pmulhru : RVSDNode<"PMULHRU", SDT_RISCVPBinOp>;
+def riscv_pmulhrsu : RVSDNode<"PMULHRSU", SDT_RISCVPBinOp>;
 
 let Predicates = [HasStdExtP] in {
   def : PatGpr<abs, ABS>;
@@ -1513,6 +1517,16 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
   def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
 
+  // 16-bit multiply high patterns
+  def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+
+  // 16-bit multiply high rounding patterns
+  def: Pat<(XLenVecI16VT (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit logical shift left patterns
   def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1580,6 +1594,16 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit multiply high patterns
+  def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+
+  // 32-bit multiply high rounding patterns
+  def: Pat<(v2i32 (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_W GPR:$rs1, GPR:$rs2)>;
+
   // 32-bit logical shift left
   def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
            (PSLL_WS GPR:$rs1, GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index cd59aa03597e2..33127c3d140fa 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -746,3 +746,126 @@ define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <4 x i8> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed for v2i16
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high unsigned for v2i16
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high signed-unsigned for v2i16
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed for v2i16
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhr.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding unsigned for v2i16
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhru.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed-unsigned for v2i16
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index c7fb891cdd996..8a741f5821b70 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -841,3 +841,245 @@ define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
   store <2 x i32> %res, ptr %ret_ptr
   ret void
 }
+
+; Test packed multiply high signed
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulh.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high unsigned
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+
+; Test packed multiply high signed-unsigned
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhsu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhr.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhr_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhr.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding unsigned
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhru.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  s...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/171593


More information about the llvm-commits mailing list