[llvm] [LoongArch] Custom lower vecreduce_add. (PR #154304)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 19 02:45:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-loongarch

Author: None (tangaac)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/154304.diff


6 Files Affected:

- (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+48) 
- (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.h (+4) 
- (modified) llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td (+14) 
- (modified) llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td (+16) 
- (modified) llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp (+9) 
- (modified) llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h (+2) 


``````````diff
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 39315f05b8388..aead459633ff8 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -309,6 +309,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
     }
     for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -376,6 +377,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
     }
     for (MVT VT : {MVT::v32i8, MVT::v16i16, MVT::v8i32})
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -521,10 +523,55 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
     return lowerFP_TO_BF16(Op, DAG);
   case ISD::BF16_TO_FP:
     return lowerBF16_TO_FP(Op, DAG);
+  case ISD::VECREDUCE_ADD:
+    return lowerVECREDUCE_ADD(Op, DAG);
   }
   return SDValue();
 }
 
+// Custom lowering of ISD::VECREDUCE_ADD for LSX (128-bit) and LASX (256-bit)
+// integer vectors.
+//
+// The reduction is a chain of LoongArchISD::VHADDW nodes, selected as
+// [X]VHADDW.{H.B,W.H,D.W,Q.D}. Each step widens the partial sums to twice the
+// element width. The chain is entered at the step matching the input element
+// type and falls through to the Q.D step. Afterwards, element 0 of the
+// v2i64/v4i64 result holds the low 64 bits of the (lower) 128-bit lane's sum.
+//
+// NOTE: each VHADDW node is tagged with the vector type describing the data
+// it consumes (the H.B step is v16i8/v32i8, the W.H step v8i16/v16i16, ...),
+// because the VHADDW selection patterns choose the instruction from that type.
+//
+// For 256-bit inputs the two 128-bit lanes are reduced independently, so the
+// upper lane's sum is folded into element 0 before extracting.
+SDValue LoongArchTargetLowering::lowerVECREDUCE_ADD(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT OpVT = Op.getSimpleValueType();
+  SDValue Val = Op.getOperand(0);
+  MVT VecVT = Val.getSimpleValueType();
+  unsigned EC = VecVT.getVectorNumElements();
+
+  // Only a 256-bit input needs the cross-lane fold, so decide from the input
+  // type. LSX (128-bit) vectors are custom-lowered here as well, including on
+  // subtargets that also have LASX. Treating their v2i64 result as v4i64
+  // would add an undefined upper half into the sum.
+  bool NeedsLaneFold = Subtarget.hasExtLASX() && VecVT.is256BitVector();
+
+  switch (VecVT.getScalarType().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected value type!");
+  case MVT::i8:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i8, EC),
+                      Val, Val);
+    EC /= 2;
+    [[fallthrough]];
+  case MVT::i16:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i16, EC),
+                      Val, Val);
+    EC /= 2;
+    [[fallthrough]];
+  case MVT::i32:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i32, EC),
+                      Val, Val);
+    EC /= 2;
+    [[fallthrough]];
+  case MVT::i64:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i64, EC),
+                      Val, Val);
+    break;
+  }
+
+  if (NeedsLaneFold) {
+    // XVPERMI with immediate 2 places 64-bit element 2 (the low half of the
+    // upper lane's sum) into element 0; adding it completes the reduction.
+    SDValue Tmp = DAG.getNode(LoongArchISD::XVPERMI, DL, MVT::v4i64, Val,
+                              DAG.getConstant(2, DL, MVT::i64));
+    Val = DAG.getNode(ISD::ADD, DL, MVT::v4i64, Tmp, Val);
+  }
+
+  SDValue Idx = DAG.getConstant(0, DL, Subtarget.getGRLenVT());
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Val, Idx);
+}
+
 SDValue LoongArchTargetLowering::lowerPREFETCH(SDValue Op,
                                                SelectionDAG &DAG) const {
   unsigned IsData = Op.getConstantOperandVal(4);
@@ -6651,6 +6698,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(XVMSKGEZ)
     NODE_NAME_CASE(XVMSKEQZ)
     NODE_NAME_CASE(XVMSKNEZ)
+    NODE_NAME_CASE(VHADDW)
   }
 #undef NODE_NAME_CASE
   return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index f79ba7450cc36..40e237b1c69e4 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -177,6 +177,9 @@ enum NodeType : unsigned {
   XVMSKEQZ,
   XVMSKNEZ,
 
+  // Vector Horizontal Addition with Widening
+  VHADDW
+
   // Intrinsic operations end =============================================
 };
 } // end namespace LoongArchISD
@@ -386,6 +389,7 @@ class LoongArchTargetLowering : public TargetLowering {
   SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerVECREDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
 
   bool isFPImmLegal(const APFloat &Imm, EVT VT,
                     bool ForCodeSize) const override;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 0696b11d62ac9..2d3d77f8cbc31 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1186,6 +1186,17 @@ multiclass PatXrXrXr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LASX256:$xd, LASX256:$xj, LASX256:$xk)>;
 }
 
+// Maps OpNode on two LASX256 operands to the widening forms of Inst, choosing
+// the suffix from the operand element type:
+//   v32i8 -> _H_B, v16i16 -> _W_H, v8i32 -> _D_W, v4i64 -> _Q_D.
+multiclass PatXrXrW<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_H_B") LASX256:$xj, LASX256:$xk)>;
+  def : Pat<(OpNode (v16i16 LASX256:$xj), (v16i16 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_W_H") LASX256:$xj, LASX256:$xk)>;
+  def : Pat<(OpNode (v8i32 LASX256:$xj), (v8i32 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_D_W") LASX256:$xj, LASX256:$xk)>;
+  def : Pat<(OpNode (v4i64 LASX256:$xj), (v4i64 LASX256:$xk)),
+            (!cast<LAInst>(Inst#"_Q_D") LASX256:$xj, LASX256:$xk)>;
+}
+
 multiclass PatShiftXrXr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v32i8 LASX256:$xj), (and vsplati8_imm_eq_7,
                                               (v32i8 LASX256:$xk))),
@@ -1513,6 +1524,9 @@ def : Pat<(bswap (v8i32 LASX256:$xj)), (XVSHUF4I_B LASX256:$xj, 0b00011011)>;
 def : Pat<(bswap (v4i64 LASX256:$xj)),
           (XVSHUF4I_W (XVSHUF4I_B LASX256:$xj, 0b00011011), 0b10110001)>;
 
+// XVHADDW.{H.B/W.H/D.W/Q.D}: horizontal widening add, selected for
+// LoongArchISD::VHADDW nodes (emitted by the VECREDUCE_ADD lowering).
+defm : PatXrXrW<loongarch_vhaddw, "XVHADDW">;
+
 // XVFADD_{S/D}
 defm : PatXrXrF<fadd, "XVFADD">;
 
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 3c9defb0366ff..6cd8553c245e7 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -71,6 +71,8 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
 def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
 def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
 
+// Vector horizontal add with widening (LoongArchISD::VHADDW), selected to
+// [X]VHADDW.{H.B,W.H,D.W,Q.D} by operand element type.
+// NOTE(review): SDT_LoongArchV2R presumably ties result and operand types
+// together, while the instruction's real output has half as many, twice as
+// wide elements; confirm the profile matches how the lowering types nodes.
+def loongarch_vhaddw : SDNode<"LoongArchISD::VHADDW", SDT_LoongArchV2R>;
+
 def loongarch_vldrepl
     : SDNode<"LoongArchISD::VLDREPL",
              SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
@@ -1364,6 +1366,17 @@ multiclass PatVrVrVr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LSX128:$vd, LSX128:$vj, LSX128:$vk)>;
 }
 
+// Maps OpNode on two LSX128 operands to the widening forms of Inst, choosing
+// the suffix from the operand element type:
+//   v16i8 -> _H_B, v8i16 -> _W_H, v4i32 -> _D_W, v2i64 -> _Q_D.
+multiclass PatVrVrW<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_H_B") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode (v8i16 LSX128:$vj), (v8i16 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_W_H") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode (v4i32 LSX128:$vj), (v4i32 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_D_W") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode (v2i64 LSX128:$vj), (v2i64 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_Q_D") LSX128:$vj, LSX128:$vk)>;
+}
+
 multiclass PatShiftVrVr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v16i8 LSX128:$vj), (and vsplati8_imm_eq_7,
                                              (v16i8 LSX128:$vk))),
@@ -1709,6 +1722,9 @@ def : Pat<(bswap (v4i32 LSX128:$vj)), (VSHUF4I_B LSX128:$vj, 0b00011011)>;
 def : Pat<(bswap (v2i64 LSX128:$vj)),
           (VSHUF4I_W (VSHUF4I_B LSX128:$vj, 0b00011011), 0b10110001)>;
 
+// VHADDW.{H.B/W.H/D.W/Q.D}: horizontal widening add, selected for
+// LoongArchISD::VHADDW nodes (emitted by the VECREDUCE_ADD lowering).
+defm : PatVrVrW<loongarch_vhaddw, "VHADDW">;
+
 // VFADD_{S/D}
 defm : PatVrVrF<fadd, "VFADD">;
 
diff --git a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
index ede5477f04bda..efe898c33072e 100644
--- a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
@@ -95,4 +95,13 @@ unsigned LoongArchTTIImpl::getPrefetchDistance() const { return 200; }
 
 bool LoongArchTTIImpl::enableWritePrefetching() const { return true; }
 
+// Keep llvm.vector.reduce.add as an intrinsic so it reaches SelectionDAG as
+// ISD::VECREDUCE_ADD and uses the backend's custom LSX/LASX lowering. Every
+// other reduction intrinsic is reported as needing expansion.
+bool LoongArchTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
+  return II->getIntrinsicID() != Intrinsic::vector_reduce_add;
+}
+
 // TODO: Implement more hooks to provide TTI machinery for LoongArch.
diff --git a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
index d43d2cb0eb124..e3f16c7804994 100644
--- a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
+++ b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
@@ -53,6 +53,8 @@ class LoongArchTTIImpl : public BasicTTIImplBase<LoongArchTTIImpl> {
   unsigned getPrefetchDistance() const override;
   bool enableWritePrefetching() const override;
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const override;
+
   // TODO: Implement more hooks to provide TTI machinery for LoongArch.
 };
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/154304


More information about the llvm-commits mailing list