[llvm] p ext codegen mul2 (PR #171593)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 10 02:30:33 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
Changes:
- [llvm][RISCV] Support mulh for P extension codegen
- [llvm][RISCV] Support rounding mulh for P extension codegen
---
Patch is 22.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171593.diff
4 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+79-30)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+30-6)
- (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll (+123)
- (modified) llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll (+242)
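
As a reference for the patterns in the diff below: a minimal scalar sketch of the multiply-high semantics being matched (my own illustrative model, not code from the patch; function names are made up). The widened product is shifted right by the element width, and the rounding variants add `1 << (EltBits - 1)` before the shift, which is why the tests add 32768 for the i16 forms and 2147483648 for the i32 forms.

```c
#include <stdint.h>

/* Illustrative scalar models of the packed multiply-high forms
   (element width 16 shown; the i32 forms are analogous with a
   64-bit intermediate and a shift of 32). */

/* pmulh.h: signed x signed, upper 16 bits of the 32-bit product.
   The unsigned cast plus logical shift mirrors the lshr+trunc in
   the matched IR; after truncation to 16 bits the result is the
   same as with an arithmetic shift. */
static inline int16_t mulh_h(int16_t a, int16_t b) {
  return (int16_t)((uint32_t)((int32_t)a * (int32_t)b) >> 16);
}

/* pmulhsu.h: signed x unsigned (sign-extended LHS, zero-extended RHS). */
static inline int16_t mulhsu_h(int16_t a, uint16_t b) {
  return (int16_t)((uint32_t)((int32_t)a * (int32_t)b) >> 16);
}

/* pmulhr.h: rounding variant; add 1 << (EltBits - 1) = 32768 before
   shifting, matching the splat constant the combine looks for. */
static inline int16_t mulhr_h(int16_t a, int16_t b) {
  uint32_t sum = (uint32_t)((int32_t)a * (int32_t)b) + (1u << 15);
  return (int16_t)(sum >> 16);
}
```
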
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f28772a74d433..c75ac76415714 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15210,18 +15210,26 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
break;
}
case RISCVISD::PASUB:
- case RISCVISD::PASUBU: {
+ case RISCVISD::PASUBU:
+ case RISCVISD::PMULHSU:
+ case RISCVISD::PMULHR:
+ case RISCVISD::PMULHRU:
+ case RISCVISD::PMULHRSU: {
MVT VT = N->getSimpleValueType(0);
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
- assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+ unsigned Opcode = N->getOpcode();
+ // PMULH* variants don't support i8
+ bool IsMulH = Opcode == RISCVISD::PMULHSU || Opcode == RISCVISD::PMULHR ||
+ Opcode == RISCVISD::PMULHRU || Opcode == RISCVISD::PMULHRSU;
+ assert(VT == MVT::v2i16 || (!IsMulH && VT == MVT::v4i8));
MVT NewVT = MVT::v4i16;
if (VT == MVT::v4i8)
NewVT = MVT::v8i8;
SDValue Undef = DAG.getUNDEF(VT);
Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
- Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+ Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
return;
}
case ISD::EXTRACT_VECTOR_ELT: {
@@ -16331,9 +16339,10 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
}
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
+// PMULHR*: (trunc (srl (add (mul (sext a), (zext b)), round_const), EltBits))
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
@@ -16346,7 +16355,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
return SDValue();
- // Check if shift amount is 1
+ // Check if shift amount is a splat constant
SDValue ShAmt = N0.getOperand(1);
if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
return SDValue();
@@ -16360,44 +16369,84 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
if (!C)
return SDValue();
- if (C->getZExtValue() != 1)
- return SDValue();
- // Check for SUB operation
- SDValue Sub = N0.getOperand(0);
- if (Sub.getOpcode() != ISD::SUB)
- return SDValue();
+ SDValue Op = N0.getOperand(0);
+ unsigned ShAmtVal = C->getZExtValue();
+ unsigned EltBits = VecVT.getScalarSizeInBits();
+
+ // Check for rounding pattern: (add (mul ...), round_const)
+ bool IsRounding = false;
+ if (Op.getOpcode() == ISD::ADD && (EltBits == 16 || EltBits == 32)) {
+ SDValue AddRHS = Op.getOperand(1);
+ if (AddRHS.getOpcode() == ISD::BUILD_VECTOR) {
+ if (auto *RndBV = dyn_cast<BuildVectorSDNode>(AddRHS.getNode())) {
+ if (auto *RndC =
+ dyn_cast_or_null<ConstantSDNode>(RndBV->getSplatValue())) {
+ uint64_t ExpectedRnd = 1ULL << (EltBits - 1);
+ if (RndC->getZExtValue() == ExpectedRnd &&
+ Op.getOperand(0).getOpcode() == ISD::MUL) {
+ Op = Op.getOperand(0);
+ IsRounding = true;
+ }
+ }
+ }
+ }
+ }
- SDValue LHS = Sub.getOperand(0);
- SDValue RHS = Sub.getOperand(1);
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
- // Check if both operands are sign/zero extends from the target
- // type
- bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
- RHS.getOpcode() == ISD::SIGN_EXTEND;
- bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
- RHS.getOpcode() == ISD::ZERO_EXTEND;
+ bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+ bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+ bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+ bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
- if (!IsSignExt && !IsZeroExt)
+ if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
return SDValue();
SDValue A = LHS.getOperand(0);
SDValue B = RHS.getOperand(0);
- // Check if the extends are from our target vector type
if (A.getValueType() != VT || B.getValueType() != VT)
return SDValue();
- // Determine the instruction based on type and signedness
unsigned Opc;
- if (IsSignExt)
- Opc = RISCVISD::PASUB;
- else if (IsZeroExt)
- Opc = RISCVISD::PASUBU;
- else
+ switch (Op.getOpcode()) {
+ default:
return SDValue();
+ case ISD::SUB:
+ // PASUB/PASUBU: shift amount must be 1
+ if (ShAmtVal != 1)
+ return SDValue();
+ if (LHSIsSExt && RHSIsSExt)
+ Opc = RISCVISD::PASUB;
+ else if (LHSIsZExt && RHSIsZExt)
+ Opc = RISCVISD::PASUBU;
+ else
+ return SDValue();
+ break;
+ case ISD::MUL:
+ // PMULH*/PMULHR*: shift amount must be element size, only for i16/i32
+ if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+ return SDValue();
+ if (IsRounding) {
+ if (LHSIsSExt && RHSIsSExt)
+ Opc = RISCVISD::PMULHR;
+ else if (LHSIsZExt && RHSIsZExt)
+ Opc = RISCVISD::PMULHRU;
+ else if (LHSIsSExt && RHSIsZExt)
+ Opc = RISCVISD::PMULHRSU;
+ else
+ return SDValue();
+ } else {
+ if (LHSIsSExt && RHSIsZExt)
+ Opc = RISCVISD::PMULHSU;
+ else
+ return SDValue();
+ }
+ break;
+ }
- // Create the machine node directly
return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index bba9f961b9639..587ac8ee238f4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1463,12 +1463,16 @@ let Predicates = [HasStdExtP, IsRV32] in {
def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
- SDTCisInt<0>,
- SDTCisSameAs<0, 1>,
- SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+ SDTCisInt<0>,
+ SDTCisSameAs<0, 1>,
+ SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
+def riscv_pmulhr : RVSDNode<"PMULHR", SDT_RISCVPBinOp>;
+def riscv_pmulhru : RVSDNode<"PMULHRU", SDT_RISCVPBinOp>;
+def riscv_pmulhrsu : RVSDNode<"PMULHRSU", SDT_RISCVPBinOp>;
let Predicates = [HasStdExtP] in {
def : PatGpr<abs, ABS>;
@@ -1513,6 +1517,16 @@ let Predicates = [HasStdExtP] in {
def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
+ // 16-bit multiply high patterns
+ def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+
+ // 16-bit multiply high rounding patterns
+ def: Pat<(XLenVecI16VT (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
+
// 8-bit logical shift left patterns
def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1580,6 +1594,16 @@ let Predicates = [HasStdExtP, IsRV64] in {
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
+ // 32-bit multiply high patterns
+ def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+
+ // 32-bit multiply high rounding patterns
+ def: Pat<(v2i32 (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_W GPR:$rs1, GPR:$rs2)>;
+
// 32-bit logical shift left
def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
(PSLL_WS GPR:$rs1, GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index cd59aa03597e2..33127c3d140fa 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -746,3 +746,126 @@ define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
store <4 x i8> %res, ptr %ret_ptr
ret void
}
+
+; Test packed multiply high signed for v2i16
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulh.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = sext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high unsigned for v2i16
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = zext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high signed-unsigned for v2i16
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhsu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed for v2i16
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhr.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = sext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding unsigned for v2i16
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhru.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = zext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed-unsigned for v2i16
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index c7fb891cdd996..8a741f5821b70 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -841,3 +841,245 @@ define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
store <2 x i32> %res, ptr %ret_ptr
ret void
}
+
+; Test packed multiply high signed
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulh.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = sext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulh.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = sext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+
+; Test packed multiply high unsigned
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhu.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = zext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhu.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = zext <2 x i32> %a to <2 x i64>
+ %b_ext = zext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+
+; Test packed multiply high signed-unsigned
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhsu.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhsu.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = zext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhr.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = sext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhr_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhr.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = sext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+ %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding unsigned
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhru.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = zext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ s...
[truncated]
``````````
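
For contrast with the pre-existing PASUB/PASUBU path that the combine above generalizes: the averaging-subtract pattern is selected when the shift amount is 1, while the multiply-high patterns require a shift equal to the element width. A scalar sketch of the subtraction forms under the same conventions (illustrative only, not code from the patch):

```c
#include <stdint.h>

/* pasub.h: averaging subtraction, (sext(a) - sext(b)) >> 1, then
   truncate. The shift amount of 1 is what routes the combine to the
   ISD::SUB case instead of the ISD::MUL (multiply-high) case. */
static inline int16_t pasub_h(int16_t a, int16_t b) {
  return (int16_t)((uint32_t)((int32_t)a - (int32_t)b) >> 1);
}

/* pasubu.h: the unsigned flavor, with zero-extended operands. */
static inline uint16_t pasubu_h(uint16_t a, uint16_t b) {
  return (uint16_t)(((uint32_t)a - (uint32_t)b) >> 1);
}
```
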
https://github.com/llvm/llvm-project/pull/171593
More information about the llvm-commits mailing list