[llvm] [LoongArch] Custom lower vecreduce_add. (PR #154304)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 20 02:47:22 PDT 2025


https://github.com/tangaac updated https://github.com/llvm/llvm-project/pull/154304

>From a77de1c97b5b6f2e58847acfb2e25b3e4268ed9b Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Tue, 19 Aug 2025 16:36:21 +0800
Subject: [PATCH 1/3] Custom lower vecreduce_add

---
 .../LoongArch/LoongArchISelLowering.cpp       | 48 +++++++++++++++++++
 .../Target/LoongArch/LoongArchISelLowering.h  |  4 ++
 .../LoongArch/LoongArchLASXInstrInfo.td       | 14 ++++++
 .../Target/LoongArch/LoongArchLSXInstrInfo.td | 16 +++++++
 .../LoongArchTargetTransformInfo.cpp          |  9 ++++
 .../LoongArch/LoongArchTargetTransformInfo.h  |  2 +
 6 files changed, 93 insertions(+)

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 5b2d185594f44..67011c78a337f 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -310,6 +310,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
     }
     for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -377,6 +378,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
     }
     for (MVT VT : {MVT::v32i8, MVT::v16i16, MVT::v8i32})
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -522,10 +524,55 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
     return lowerFP_TO_BF16(Op, DAG);
   case ISD::BF16_TO_FP:
     return lowerBF16_TO_FP(Op, DAG);
+  case ISD::VECREDUCE_ADD:
+    return lowerVECREDUCE_ADD(Op, DAG);
   }
   return SDValue();
 }
 
+SDValue LoongArchTargetLowering::lowerVECREDUCE_ADD(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+
+  SDLoc DL(Op);
+  MVT OpVT = Op.getSimpleValueType();
+  SDValue Val = Op.getOperand(0);
+  MVT ValTy = Val.getSimpleValueType().getScalarType();
+
+  SDValue Idx = DAG.getConstant(0, DL, Subtarget.getGRLenVT());
+  unsigned EC = Val.getSimpleValueType().getVectorNumElements();
+
+  switch (ValTy.SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected value type!");
+  case MVT::i8:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i8, EC),
+                      Val, Val);
+    EC = EC / 2;
+    LLVM_FALLTHROUGH;
+  case MVT::i16:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i16, EC),
+                      Val, Val);
+    EC = EC / 2;
+    LLVM_FALLTHROUGH;
+  case MVT::i32:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i32, EC),
+                      Val, Val);
+    EC = EC / 2;
+    LLVM_FALLTHROUGH;
+  case MVT::i64:
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i64, EC),
+                      Val, Val);
+  }
+
+  if (Subtarget.hasExtLASX()) {
+    SDValue Tmp = DAG.getNode(LoongArchISD::XVPERMI, DL, MVT::v4i64, Val,
+                              DAG.getConstant(2, DL, MVT::i64));
+    Val = DAG.getNode(ISD::ADD, DL, MVT::v4i64, Tmp, Val);
+  }
+
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Val, Idx);
+}
+
 SDValue LoongArchTargetLowering::lowerPREFETCH(SDValue Op,
                                                SelectionDAG &DAG) const {
   unsigned IsData = Op.getConstantOperandVal(4);
@@ -6659,6 +6706,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(XVMSKGEZ)
     NODE_NAME_CASE(XVMSKEQZ)
     NODE_NAME_CASE(XVMSKNEZ)
+    NODE_NAME_CASE(VHADDW)
   }
 #undef NODE_NAME_CASE
   return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index f79ba7450cc36..40e237b1c69e4 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -177,6 +177,9 @@ enum NodeType : unsigned {
   XVMSKEQZ,
   XVMSKNEZ,
 
+  // Vector Horizontal Addition with Widening
+  VHADDW
+
   // Intrinsic operations end =============================================
 };
 } // end namespace LoongArchISD
@@ -386,6 +389,7 @@ class LoongArchTargetLowering : public TargetLowering {
   SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerVECREDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
 
   bool isFPImmLegal(const APFloat &Imm, EVT VT,
                     bool ForCodeSize) const override;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 0696b11d62ac9..2d3d77f8cbc31 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1186,6 +1186,17 @@ multiclass PatXrXrXr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LASX256:$xd, LASX256:$xj, LASX256:$xk)>;
 }
 
+multiclass PatXrXrW<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(OpNode(v32i8 LASX256:$vj), (v32i8 LASX256:$vk)),
+            (!cast<LAInst>(Inst#"_H_B") LASX256:$vj, LASX256:$vk)>;
+  def : Pat<(OpNode(v16i16 LASX256:$vj), (v16i16 LASX256:$vk)),
+            (!cast<LAInst>(Inst#"_W_H") LASX256:$vj, LASX256:$vk)>;
+  def : Pat<(OpNode(v8i32 LASX256:$vj), (v8i32 LASX256:$vk)),
+            (!cast<LAInst>(Inst#"_D_W") LASX256:$vj, LASX256:$vk)>;
+  def : Pat<(OpNode(v4i64 LASX256:$vj), (v4i64 LASX256:$vk)),
+            (!cast<LAInst>(Inst#"_Q_D") LASX256:$vj, LASX256:$vk)>;
+}
+
 multiclass PatShiftXrXr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v32i8 LASX256:$xj), (and vsplati8_imm_eq_7,
                                               (v32i8 LASX256:$xk))),
@@ -1513,6 +1524,9 @@ def : Pat<(bswap (v8i32 LASX256:$xj)), (XVSHUF4I_B LASX256:$xj, 0b00011011)>;
 def : Pat<(bswap (v4i64 LASX256:$xj)),
           (XVSHUF4I_W (XVSHUF4I_B LASX256:$xj, 0b00011011), 0b10110001)>;
 
+// XVHADDW.{H.B/W.H/D.W/Q.D}
+defm : PatXrXrW<loongarch_vhaddw, "XVHADDW">;
+
 // XVFADD_{S/D}
 defm : PatXrXrF<fadd, "XVFADD">;
 
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 3c9defb0366ff..6cd8553c245e7 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -71,6 +71,8 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
 def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
 def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
 
+def loongarch_vhaddw : SDNode<"LoongArchISD::VHADDW", SDT_LoongArchV2R>;
+
 def loongarch_vldrepl
     : SDNode<"LoongArchISD::VLDREPL",
              SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
@@ -1364,6 +1366,17 @@ multiclass PatVrVrVr<SDPatternOperator OpNode, string Inst> {
             (!cast<LAInst>(Inst#"_D") LSX128:$vd, LSX128:$vj, LSX128:$vk)>;
 }
 
+multiclass PatVrVrW<SDPatternOperator OpNode, string Inst> {
+  def : Pat<(OpNode(v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_H_B") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode(v8i16 LSX128:$vj), (v8i16 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_W_H") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode(v4i32 LSX128:$vj), (v4i32 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_D_W") LSX128:$vj, LSX128:$vk)>;
+  def : Pat<(OpNode(v2i64 LSX128:$vj), (v2i64 LSX128:$vk)),
+            (!cast<LAInst>(Inst#"_Q_D") LSX128:$vj, LSX128:$vk)>;
+}
+
 multiclass PatShiftVrVr<SDPatternOperator OpNode, string Inst> {
   def : Pat<(OpNode (v16i8 LSX128:$vj), (and vsplati8_imm_eq_7,
                                              (v16i8 LSX128:$vk))),
@@ -1709,6 +1722,9 @@ def : Pat<(bswap (v4i32 LSX128:$vj)), (VSHUF4I_B LSX128:$vj, 0b00011011)>;
 def : Pat<(bswap (v2i64 LSX128:$vj)),
           (VSHUF4I_W (VSHUF4I_B LSX128:$vj, 0b00011011), 0b10110001)>;
 
+// VHADDW.{H.B/W.H/D.W/Q.D}
+defm : PatVrVrW<loongarch_vhaddw, "VHADDW">;
+
 // VFADD_{S/D}
 defm : PatVrVrF<fadd, "VFADD">;
 
diff --git a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
index ede5477f04bda..efe898c33072e 100644
--- a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp
@@ -95,4 +95,13 @@ unsigned LoongArchTTIImpl::getPrefetchDistance() const { return 200; }
 
 bool LoongArchTTIImpl::enableWritePrefetching() const { return true; }
 
+bool LoongArchTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
+  switch (II->getIntrinsicID()) {
+  default:
+    return true;
+  case Intrinsic::vector_reduce_add:
+    return false;
+  }
+}
+
 // TODO: Implement more hooks to provide TTI machinery for LoongArch.
diff --git a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
index d43d2cb0eb124..e3f16c7804994 100644
--- a/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
+++ b/llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h
@@ -53,6 +53,8 @@ class LoongArchTTIImpl : public BasicTTIImplBase<LoongArchTTIImpl> {
   unsigned getPrefetchDistance() const override;
   bool enableWritePrefetching() const override;
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const override;
+
   // TODO: Implement more hooks to provide TTI machinery for LoongArch.
 };
 

>From 67aabf1e43bfe5520ac2f362b70be24f6959125f Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 20 Aug 2025 10:02:49 +0800
Subject: [PATCH 2/3] update tests

---
 .../LoongArch/LoongArchISelLowering.cpp       |  4 +-
 .../CodeGen/LoongArch/lasx/vec-reduce-add.ll  | 73 ++++++-------------
 .../CodeGen/LoongArch/lsx/vec-reduce-add.ll   | 39 ++++------
 3 files changed, 42 insertions(+), 74 deletions(-)

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 67011c78a337f..0edf1ea86ad12 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -536,12 +536,12 @@ SDValue LoongArchTargetLowering::lowerVECREDUCE_ADD(SDValue Op,
   SDLoc DL(Op);
   MVT OpVT = Op.getSimpleValueType();
   SDValue Val = Op.getOperand(0);
-  MVT ValTy = Val.getSimpleValueType().getScalarType();
+  MVT EleTy = Val.getSimpleValueType().getScalarType();
 
   SDValue Idx = DAG.getConstant(0, DL, Subtarget.getGRLenVT());
   unsigned EC = Val.getSimpleValueType().getVectorNumElements();
 
-  switch (ValTy.SimpleTy) {
+  switch (EleTy.SimpleTy) {
   default:
     llvm_unreachable("Unexpected value type!");
   case MVT::i8:
diff --git a/llvm/test/CodeGen/LoongArch/lasx/vec-reduce-add.ll b/llvm/test/CodeGen/LoongArch/lasx/vec-reduce-add.ll
index bf5effd7b3912..7268eb24ee51c 100644
--- a/llvm/test/CodeGen/LoongArch/lasx/vec-reduce-add.ll
+++ b/llvm/test/CodeGen/LoongArch/lasx/vec-reduce-add.ll
@@ -1,27 +1,18 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-
 ; RUN: llc --mtriple=loongarch64 --mattr=+lasx %s -o - | FileCheck %s
 
 define void @vec_reduce_add_v32i8(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v32i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 78
-; CHECK-NEXT:    xvshuf4i.b $xr1, $xr1, 228
-; CHECK-NEXT:    xvadd.b $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvbsrl.v $xr1, $xr1, 8
-; CHECK-NEXT:    xvadd.b $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvsrli.d $xr1, $xr1, 32
-; CHECK-NEXT:    xvadd.b $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvshuf4i.b $xr1, $xr1, 14
-; CHECK-NEXT:    xvadd.b $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvrepl128vei.b $xr1, $xr1, 1
-; CHECK-NEXT:    xvadd.b $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvstelm.b $xr0, $a1, 0, 0
+; CHECK-NEXT:    xvhaddw.h.b $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvhaddw.w.h $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvhaddw.d.w $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvhaddw.q.d $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 2
+; CHECK-NEXT:    xvadd.d $xr0, $xr1, $xr0
+; CHECK-NEXT:    xvpickve2gr.d $a0, $xr0, 0
+; CHECK-NEXT:    st.b $a0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load <32 x i8>, ptr %src
   %res = call i8 @llvm.vector.reduce.add.v32i8(<32 x i8> %v)
@@ -33,19 +24,13 @@ define void @vec_reduce_add_v16i16(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v16i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 78
-; CHECK-NEXT:    xvshuf4i.h $xr1, $xr1, 228
-; CHECK-NEXT:    xvadd.h $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvbsrl.v $xr1, $xr1, 8
-; CHECK-NEXT:    xvadd.h $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvshuf4i.h $xr1, $xr1, 14
-; CHECK-NEXT:    xvadd.h $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvrepl128vei.h $xr1, $xr1, 1
-; CHECK-NEXT:    xvadd.h $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvstelm.h $xr0, $a1, 0, 0
+; CHECK-NEXT:    xvhaddw.w.h $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvhaddw.d.w $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvhaddw.q.d $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 2
+; CHECK-NEXT:    xvadd.d $xr0, $xr1, $xr0
+; CHECK-NEXT:    xvpickve2gr.d $a0, $xr0, 0
+; CHECK-NEXT:    st.h $a0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load <16 x i16>, ptr %src
   %res = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %v)
@@ -57,16 +42,12 @@ define void @vec_reduce_add_v8i32(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v8i32:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 78
-; CHECK-NEXT:    xvshuf4i.w $xr1, $xr1, 228
-; CHECK-NEXT:    xvadd.w $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvshuf4i.w $xr1, $xr1, 14
-; CHECK-NEXT:    xvadd.w $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvrepl128vei.w $xr1, $xr1, 1
-; CHECK-NEXT:    xvadd.w $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvstelm.w $xr0, $a1, 0, 0
+; CHECK-NEXT:    xvhaddw.d.w $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvhaddw.q.d $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 2
+; CHECK-NEXT:    xvadd.d $xr0, $xr1, $xr0
+; CHECK-NEXT:    xvpickve2gr.d $a0, $xr0, 0
+; CHECK-NEXT:    st.w $a0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load <8 x i32>, ptr %src
   %res = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %v)
@@ -78,14 +59,9 @@ define void @vec_reduce_add_v4i64(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v4i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    pcalau12i $a0, %pc_hi20(.LCPI3_0)
-; CHECK-NEXT:    xvld $xr1, $a0, %pc_lo12(.LCPI3_0)
-; CHECK-NEXT:    xvpermi.d $xr2, $xr0, 78
-; CHECK-NEXT:    xvshuf.d $xr1, $xr0, $xr2
-; CHECK-NEXT:    xvadd.d $xr0, $xr0, $xr1
-; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 68
-; CHECK-NEXT:    xvrepl128vei.d $xr1, $xr1, 1
-; CHECK-NEXT:    xvadd.d $xr0, $xr0, $xr1
+; CHECK-NEXT:    xvhaddw.q.d $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvpermi.d $xr1, $xr0, 2
+; CHECK-NEXT:    xvadd.d $xr0, $xr1, $xr0
 ; CHECK-NEXT:    xvstelm.d $xr0, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load <4 x i64>, ptr %src
@@ -93,4 +69,3 @@ define void @vec_reduce_add_v4i64(ptr %src, ptr %dst) nounwind {
   store i64 %res, ptr %dst
   ret void
 }
-
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll b/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll
index a71bdea917cba..ea80b974f8413 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll
@@ -5,15 +5,12 @@ define void @vec_reduce_add_v16i8(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v16i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vld $vr0, $a0, 0
-; CHECK-NEXT:    vbsrl.v $vr1, $vr0, 8
-; CHECK-NEXT:    vadd.b $vr0, $vr0, $vr1
-; CHECK-NEXT:    vsrli.d $vr1, $vr0, 32
-; CHECK-NEXT:    vadd.b $vr0, $vr0, $vr1
-; CHECK-NEXT:    vshuf4i.b $vr1, $vr0, 14
-; CHECK-NEXT:    vadd.b $vr0, $vr0, $vr1
-; CHECK-NEXT:    vreplvei.b $vr1, $vr0, 1
-; CHECK-NEXT:    vadd.b $vr0, $vr0, $vr1
-; CHECK-NEXT:    vstelm.b $vr0, $a1, 0, 0
+; CHECK-NEXT:    vhaddw.h.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.w.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.d.w $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.q.d $vr0, $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.d $a0, $vr0, 0
+; CHECK-NEXT:    st.b $a0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load <16 x i8>, ptr %src
   %res = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> %v)
@@ -25,13 +22,11 @@ define void @vec_reduce_add_v8i16(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v8i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vld $vr0, $a0, 0
-; CHECK-NEXT:    vbsrl.v $vr1, $vr0, 8
-; CHECK-NEXT:    vadd.h $vr0, $vr0, $vr1
-; CHECK-NEXT:    vshuf4i.h $vr1, $vr0, 14
-; CHECK-NEXT:    vadd.h $vr0, $vr0, $vr1
-; CHECK-NEXT:    vreplvei.h $vr1, $vr0, 1
-; CHECK-NEXT:    vadd.h $vr0, $vr0, $vr1
-; CHECK-NEXT:    vstelm.h $vr0, $a1, 0, 0
+; CHECK-NEXT:    vhaddw.w.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.d.w $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.q.d $vr0, $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.d $a0, $vr0, 0
+; CHECK-NEXT:    st.h $a0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load <8 x i16>, ptr %src
   %res = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %v)
@@ -43,11 +38,10 @@ define void @vec_reduce_add_v4i32(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v4i32:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vld $vr0, $a0, 0
-; CHECK-NEXT:    vshuf4i.w $vr1, $vr0, 14
-; CHECK-NEXT:    vadd.w $vr0, $vr0, $vr1
-; CHECK-NEXT:    vreplvei.w $vr1, $vr0, 1
-; CHECK-NEXT:    vadd.w $vr0, $vr0, $vr1
-; CHECK-NEXT:    vstelm.w $vr0, $a1, 0, 0
+; CHECK-NEXT:    vhaddw.d.w $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.q.d $vr0, $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.d $a0, $vr0, 0
+; CHECK-NEXT:    st.w $a0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load <4 x i32>, ptr %src
   %res = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v)
@@ -59,8 +53,7 @@ define void @vec_reduce_add_v2i64(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v2i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vld $vr0, $a0, 0
-; CHECK-NEXT:    vreplvei.d $vr1, $vr0, 1
-; CHECK-NEXT:    vadd.d $vr0, $vr0, $vr1
+; CHECK-NEXT:    vhaddw.q.d $vr0, $vr0, $vr0
 ; CHECK-NEXT:    vstelm.d $vr0, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load <2 x i64>, ptr %src

>From 02f7c88c6bcad2ea1d8f39567b6e3edd5600d8b3 Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 20 Aug 2025 17:34:18 +0800
Subject: [PATCH 3/3] deal with illegal vector types

---
 .../LoongArch/LoongArchISelLowering.cpp       | 49 +++++-----
 .../CodeGen/LoongArch/lsx/vec-reduce-add.ll   | 91 +++++++++++++++++++
 2 files changed, 113 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 0edf1ea86ad12..47243209b2168 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -310,7 +310,6 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
-      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
     }
     for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -341,6 +340,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
          {MVT::v16i8, MVT::v8i8, MVT::v4i8, MVT::v2i8, MVT::v8i16, MVT::v4i16,
           MVT::v2i16, MVT::v4i32, MVT::v2i32, MVT::v2i64}) {
       setOperationAction(ISD::TRUNCATE, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
     }
   }
 
@@ -536,41 +536,36 @@ SDValue LoongArchTargetLowering::lowerVECREDUCE_ADD(SDValue Op,
   SDLoc DL(Op);
   MVT OpVT = Op.getSimpleValueType();
   SDValue Val = Op.getOperand(0);
-  MVT EleTy = Val.getSimpleValueType().getScalarType();
 
-  SDValue Idx = DAG.getConstant(0, DL, Subtarget.getGRLenVT());
-  unsigned EC = Val.getSimpleValueType().getVectorNumElements();
+  unsigned NumEles = Val.getSimpleValueType().getVectorNumElements();
+  unsigned EleBits = Val.getSimpleValueType().getScalarSizeInBits();
 
-  switch (EleTy.SimpleTy) {
-  default:
-    llvm_unreachable("Unexpected value type!");
-  case MVT::i8:
-    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i8, EC),
-                      Val, Val);
-    EC = EC / 2;
-    LLVM_FALLTHROUGH;
-  case MVT::i16:
-    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i16, EC),
-                      Val, Val);
-    EC = EC / 2;
-    LLVM_FALLTHROUGH;
-  case MVT::i32:
-    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i32, EC),
-                      Val, Val);
-    EC = EC / 2;
-    LLVM_FALLTHROUGH;
-  case MVT::i64:
-    Val = DAG.getNode(LoongArchISD::VHADDW, DL, MVT::getVectorVT(MVT::i64, EC),
-                      Val, Val);
+  unsigned LegalVecSize = 128;
+  bool is256Vector = Subtarget.hasExtLASX() && Val.getValueSizeInBits() == 256;
+
+  while (!isTypeLegal(Val.getSimpleValueType())) {
+    Val = DAG.WidenVector(Val, DL);
   }
 
-  if (Subtarget.hasExtLASX()) {
+  if (is256Vector) {
+    NumEles /= 2;
+    LegalVecSize = 256;
+  }
+
+  for (unsigned i = 1; i < NumEles; i *= 2, EleBits *= 2) {
+    MVT IntTy = MVT::getIntegerVT(EleBits);
+    MVT VecTy = MVT::getVectorVT(IntTy, LegalVecSize / EleBits);
+    Val = DAG.getNode(LoongArchISD::VHADDW, DL, VecTy, Val, Val);
+  }
+
+  if (is256Vector) {
     SDValue Tmp = DAG.getNode(LoongArchISD::XVPERMI, DL, MVT::v4i64, Val,
                               DAG.getConstant(2, DL, MVT::i64));
     Val = DAG.getNode(ISD::ADD, DL, MVT::v4i64, Tmp, Val);
   }
 
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Val, Idx);
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Val,
+                     DAG.getConstant(0, DL, Subtarget.getGRLenVT()));
 }
 
 SDValue LoongArchTargetLowering::lowerPREFETCH(SDValue Op,
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll b/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll
index ea80b974f8413..57fd09ed2e09b 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vec-reduce-add.ll
@@ -18,6 +18,53 @@ define void @vec_reduce_add_v16i8(ptr %src, ptr %dst) nounwind {
   ret void
 }
 
+define void @vec_reduce_add_v8i8(ptr %src, ptr %dst) nounwind {
+; CHECK-LABEL: vec_reduce_add_v8i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld.d $a0, $a0, 0
+; CHECK-NEXT:    vinsgr2vr.d $vr0, $a0, 0
+; CHECK-NEXT:    vhaddw.h.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.w.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.d.w $vr0, $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.w $a0, $vr0, 0
+; CHECK-NEXT:    st.b $a0, $a1, 0
+; CHECK-NEXT:    ret
+  %v = load <8 x i8>, ptr %src
+  %res = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %v)
+  store i8 %res, ptr %dst
+  ret void
+}
+
+define void @vec_reduce_add_v4i8(ptr %src, ptr %dst) nounwind {
+; CHECK-LABEL: vec_reduce_add_v4i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld.w $a0, $a0, 0
+; CHECK-NEXT:    vinsgr2vr.w $vr0, $a0, 0
+; CHECK-NEXT:    vhaddw.h.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.w.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.h $a0, $vr0, 0
+; CHECK-NEXT:    st.b $a0, $a1, 0
+; CHECK-NEXT:    ret
+  %v = load <4 x i8>, ptr %src
+  %res = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> %v)
+  store i8 %res, ptr %dst
+  ret void
+}
+
+define void @vec_reduce_add_v2i8(ptr %src, ptr %dst) nounwind {
+; CHECK-LABEL: vec_reduce_add_v2i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld.h $a0, $a0, 0
+; CHECK-NEXT:    vinsgr2vr.h $vr0, $a0, 0
+; CHECK-NEXT:    vhaddw.h.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vstelm.b $vr0, $a1, 0, 0
+; CHECK-NEXT:    ret
+  %v = load <2 x i8>, ptr %src
+  %res = call i8 @llvm.vector.reduce.add.v2i8(<2 x i8> %v)
+  store i8 %res, ptr %dst
+  ret void
+}
+
 define void @vec_reduce_add_v8i16(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v8i16:
 ; CHECK:       # %bb.0:
@@ -34,6 +81,36 @@ define void @vec_reduce_add_v8i16(ptr %src, ptr %dst) nounwind {
   ret void
 }
 
+define void @vec_reduce_add_v4i16(ptr %src, ptr %dst) nounwind {
+; CHECK-LABEL: vec_reduce_add_v4i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld.d $a0, $a0, 0
+; CHECK-NEXT:    vinsgr2vr.d $vr0, $a0, 0
+; CHECK-NEXT:    vhaddw.w.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vhaddw.d.w $vr0, $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.w $a0, $vr0, 0
+; CHECK-NEXT:    st.h $a0, $a1, 0
+; CHECK-NEXT:    ret
+  %v = load <4 x i16>, ptr %src
+  %res = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %v)
+  store i16 %res, ptr %dst
+  ret void
+}
+
+define void @vec_reduce_add_v2i16(ptr %src, ptr %dst) nounwind {
+; CHECK-LABEL: vec_reduce_add_v2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld.w $a0, $a0, 0
+; CHECK-NEXT:    vinsgr2vr.w $vr0, $a0, 0
+; CHECK-NEXT:    vhaddw.w.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vstelm.h $vr0, $a1, 0, 0
+; CHECK-NEXT:    ret
+  %v = load <2 x i16>, ptr %src
+  %res = call i16 @llvm.vector.reduce.add.v2i16(<2 x i16> %v)
+  store i16 %res, ptr %dst
+  ret void
+}
+
 define void @vec_reduce_add_v4i32(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v4i32:
 ; CHECK:       # %bb.0:
@@ -49,6 +126,20 @@ define void @vec_reduce_add_v4i32(ptr %src, ptr %dst) nounwind {
   ret void
 }
 
+define void @vec_reduce_add_v2i32(ptr %src, ptr %dst) nounwind {
+; CHECK-LABEL: vec_reduce_add_v2i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld.d $a0, $a0, 0
+; CHECK-NEXT:    vinsgr2vr.d $vr0, $a0, 0
+; CHECK-NEXT:    vhaddw.d.w $vr0, $vr0, $vr0
+; CHECK-NEXT:    vstelm.w $vr0, $a1, 0, 0
+; CHECK-NEXT:    ret
+  %v = load <2 x i32>, ptr %src
+  %res = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %v)
+  store i32 %res, ptr %dst
+  ret void
+}
+
 define void @vec_reduce_add_v2i64(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: vec_reduce_add_v2i64:
 ; CHECK:       # %bb.0:



More information about the llvm-commits mailing list