[llvm] [LoongArch] Custom lower vecreduce_add. (PR #154304)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 19 02:45:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-loongarch
Author: None (tangaac)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/154304.diff
6 Files Affected:
- (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+48)
- (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.h (+4)
- (modified) llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td (+14)
- (modified) llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td (+16)
- (modified) llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp (+9)
- (modified) llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h (+2)
``````````diff
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 39315f05b8388..aead459633ff8 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -309,6 +309,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
setOperationAction(ISD::ABDS, VT, Legal);
setOperationAction(ISD::ABDU, VT, Legal);
+ setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
}
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -376,6 +377,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
setOperationAction(ISD::ABDS, VT, Legal);
setOperationAction(ISD::ABDU, VT, Legal);
+ setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
}
for (MVT VT : {MVT::v32i8, MVT::v16i16, MVT::v8i32})
setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -521,10 +523,55 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
return lowerFP_TO_BF16(Op, DAG);
case ISD::BF16_TO_FP:
return lowerBF16_TO_FP(Op, DAG);
+ case ISD::VECREDUCE_ADD:
+ return lowerVECREDUCE_ADD(Op, DAG);
}
return SDValue();
}
+/// Custom-lower an integer ISD::VECREDUCE_ADD node.
+///
+/// The vector operand is reduced by a chain of LoongArchISD::VHADDW nodes
+/// (horizontal addition with widening), each fed the running value as both
+/// inputs. The chain starts at the operand's element width and runs through
+/// the i64 step. For a 256-bit operand the result is additionally added to an
+/// XVPERMI'd copy of itself. Element 0 of the final vector is returned.
+///
+/// \param Op  The VECREDUCE_ADD node; operand 0 is the vector to reduce.
+/// \param DAG The SelectionDAG in which the replacement nodes are created.
+/// \returns An EXTRACT_VECTOR_ELT of type Op's value type holding the sum.
+SDValue LoongArchTargetLowering::lowerVECREDUCE_ADD(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT OpVT = Op.getSimpleValueType();
+  SDValue Val = Op.getOperand(0);
+  MVT VecVT = Val.getSimpleValueType();
+  MVT GRLenVT = Subtarget.getGRLenVT();
+
+  // Having LASX available does not mean this particular operand is 256 bits
+  // wide. Only a 256-bit operand may take the v4i64 XVPERMI/ADD step below;
+  // a 128-bit operand is finished once the VHADDW chain is done.
+  const bool Is256BitVector = Subtarget.hasExtLASX() && VecVT.is256BitVector();
+
+  // Each step's node VT uses that step's source element type, and the element
+  // count halves from one step to the next while the vector width is kept.
+  unsigned EC = VecVT.getVectorNumElements();
+  switch (VecVT.getScalarType().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected value type!");
+  case MVT::i8:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i8, EC),
+                      Val, Val);
+    EC /= 2;
+    [[fallthrough]];
+  case MVT::i16:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i16, EC),
+                      Val, Val);
+    EC /= 2;
+    [[fallthrough]];
+  case MVT::i32:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i32, EC),
+                      Val, Val);
+    EC /= 2;
+    [[fallthrough]];
+  case MVT::i64:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i64, EC),
+                      Val, Val);
+  }
+
+  if (Is256BitVector) {
+    // NOTE(review): XVPERMI with immediate 2 presumably moves the upper
+    // 128-bit half's partial sum into element 0 so the ADD can combine both
+    // halves - confirm against the XVPERMI selection pattern.
+    // The immediate uses GRLenVT, like the extract index below.
+    SDValue Tmp = DAG.getNode(LoongArchISD::XVPERMI, DL, MVT::v4i64, Val,
+                              DAG.getConstant(2, DL, GRLenVT));
+    Val = DAG.getNode(ISD::ADD, DL, MVT::v4i64, Tmp, Val);
+  }
+
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Val,
+                     DAG.getConstant(0, DL, GRLenVT));
+}
+
+
SDValue LoongArchTargetLowering::lowerPREFETCH(SDValue Op,
SelectionDAG &DAG) const {
unsigned IsData = Op.getConstantOperandVal(4);
@@ -6651,6 +6698,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(XVMSKGEZ)
NODE_NAME_CASE(XVMSKEQZ)
NODE_NAME_CASE(XVMSKNEZ)
+ NODE_NAME_CASE(VHADDW)
}
#undef NODE_NAME_CASE
return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index f79ba7450cc36..40e237b1c69e4 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -177,6 +177,9 @@ enum NodeType : unsigned {
XVMSKEQZ,
XVMSKNEZ,
+ // Vector Horizontal Addition with Widening
+ VHADDW
+
// Intrinsic operations end =============================================
};
} // end namespace LoongArchISD
@@ -386,6 +389,7 @@ class LoongArchTargetLowering : public TargetLowering {
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerVECREDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
bool isFPImmLegal(const APFloat &Imm, EVT VT,
bool ForCodeSize) const override;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 0696b11d62ac9..2d3d77f8cbc31 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1186,6 +1186,17 @@ multiclass PatXrXrXr<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LASX256:$xd, LASX256:$xj, LASX256:$xk)>;
}
+// Selection patterns for a two-operand LASX node whose instruction name
+// carries a destination/source width suffix. The 256-bit vector type of the
+// operands picks the variant: v32i8 -> _H_B, v16i16 -> _W_H, v8i32 -> _D_W,
+// v4i64 -> _Q_D.
+multiclass PatXrXrW<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_H_B") LASX256:$xj, LASX256:$xk)>;
+  def : Pat<(OpNode (v16i16 LASX256:$xj), (v16i16 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_W_H") LASX256:$xj, LASX256:$xk)>;
+  def : Pat<(OpNode (v8i32 LASX256:$xj), (v8i32 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_D_W") LASX256:$xj, LASX256:$xk)>;
+  def : Pat<(OpNode (v4i64 LASX256:$xj), (v4i64 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_Q_D") LASX256:$xj, LASX256:$xk)>;
+}
+
+
multiclass PatShiftXrXr<SDPatternOperator OpNode, string Inst> {
def : Pat<(OpNode (v32i8 LASX256:$xj), (and vsplati8_imm_eq_7,
(v32i8 LASX256:$xk))),
@@ -1513,6 +1524,9 @@ def : Pat<(bswap (v8i32 LASX256:$xj)), (XVSHUF4I_B LASX256:$xj, 0b00011011)>;
def : Pat<(bswap (v4i64 LASX256:$xj)),
(XVSHUF4I_W (XVSHUF4I_B LASX256:$xj, 0b00011011), 0b10110001)>;
+// XVHADDW.{H.B/W.H/D.W/Q.D}
+defm : PatXrXrW<loongarch_vhaddw, "XVHADDW">;
+
// XVFADD_{S/D}
defm : PatXrXrF<fadd, "XVFADD">;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 3c9defb0366ff..6cd8553c245e7 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -71,6 +71,8 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
+def loongarch_vhaddw : SDNode<"LoongArchISD::VHADDW", SDT_LoongArchV2R>;
+
def loongarch_vldrepl
: SDNode<"LoongArchISD::VLDREPL",
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
@@ -1364,6 +1366,17 @@ multiclass PatVrVrVr<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vd, LSX128:$vj, LSX128:$vk)>;
}
+// Selection patterns for a two-operand LSX node whose instruction name
+// carries a destination/source width suffix. The 128-bit vector type of the
+// operands picks the variant: v16i8 -> _H_B, v8i16 -> _W_H, v4i32 -> _D_W,
+// v2i64 -> _Q_D.
+multiclass PatVrVrW<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_H_B") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode (v8i16 LSX128:$vj), (v8i16 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_W_H") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode (v4i32 LSX128:$vj), (v4i32 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_D_W") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode (v2i64 LSX128:$vj), (v2i64 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_Q_D") LSX128:$vj, LSX128:$vk)>;
+}
+
+
multiclass PatShiftVrVr<SDPatternOperator OpNode, string Inst> {
def : Pat<(OpNode (v16i8 LSX128:$vj), (and vsplati8_imm_eq_7,
(v16i8 LSX128:$vk))),
@@ -1709,6 +1722,9 @@ def : Pat<(bswap (v4i32 LSX128:$vj)), (VSHUF4I_B LSX128:$vj, 0b00011011)>;
def : Pat<(bswap (v2i64 LSX128:$vj)),
(VSHUF4I_W (VSHUF4I_B LSX128:$vj, 0b00011011), 0b10110001)>;
+// VHADDW.{H.B/W.H/D.W/Q.D}
+defm : PatVrVrW<loongarch_vhaddw, "VHADDW">;
+
// VFADD_{S/D}
defm : PatVrVrF<fadd, "VFADD">;
diff --git a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
index ede5477f04bda..efe898c33072e 100644
--- a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
@@ -95,4 +95,13 @@ unsigned LoongArchTTIImpl::getPrefetchDistance() const { return 200; }
bool LoongArchTTIImpl::enableWritePrefetching() const { return true; }
+/// Report whether the reduction intrinsic \p II should be expanded.
+///
+/// \returns false for Intrinsic::vector_reduce_add and true for every other
+/// intrinsic ID.
+bool LoongArchTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
+  const bool IsReduceAdd =
+      II->getIntrinsicID() == Intrinsic::vector_reduce_add;
+  return !IsReduceAdd;
+}
+
// TODO: Implement more hooks to provide TTI machinery for LoongArch.
diff --git a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
index d43d2cb0eb124..e3f16c7804994 100644
--- a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
+++ b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
@@ -53,6 +53,8 @@ class LoongArchTTIImpl : public BasicTTIImplBase<LoongArchTTIImpl> {
unsigned getPrefetchDistance() const override;
bool enableWritePrefetching() const override;
+ bool shouldExpandReduction(const IntrinsicInst *II) const override;
+
// TODO: Implement more hooks to provide TTI machinery for LoongArch.
};
``````````
</details>
https://github.com/llvm/llvm-project/pull/154304
More information about the llvm-commits
mailing list