[llvm] [LoongArch] Lower build_vector to broadcast load if possible (PR #135896)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 15 20:58:50 PDT 2025


https://github.com/tangaac updated https://github.com/llvm/llvm-project/pull/135896

>From 4e18d62eb0497400925017a9178a52fba08c9133 Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 16 Apr 2025 10:01:17 +0800
Subject: [PATCH 1/3] Lower build_vector to broadcast load if possible

---
 .../LoongArch/LoongArchISelLowering.cpp       | 45 +++++++++++++++++++
 .../Target/LoongArch/LoongArchISelLowering.h  |  5 ++-
 .../Target/LoongArch/LoongArchInstrInfo.td    |  9 ++--
 .../LoongArch/LoongArchLASXInstrInfo.td       |  6 +++
 .../Target/LoongArch/LoongArchLSXInstrInfo.td | 19 ++++++++
 5 files changed, 80 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 002d88cbeeba3..8c5e095dea039 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -1721,6 +1721,47 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
   return false;
 }
 
+// Lower BUILD_VECTOR to a broadcast load (VLDREPL) if possible, else SDValue().
+// For example:
+//   %a = load i8, ptr %ptr
+//   %b = build_vector %a, %a, ..., %a    ; v16i8: all 16 operands are %a
+// is lowered to:
+//   (VLDREPL_B $a0, 0)
+static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
+                                                const SDLoc &DL,
+                                                SelectionDAG &DAG) {
+  MVT VT = BVOp->getSimpleValueType(0);
+  int NumOps = BVOp->getNumOperands();
+
+  assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
+         "Unsupported vector type for broadcast.");
+
+  SDValue IdentitySrc;
+  bool IsIdeneity = true;
+
+  for (int i = 0; i != NumOps; i++) {
+    SDValue Op = BVOp->getOperand(i);
+    if (Op.getOpcode() != ISD::LOAD || (IdentitySrc && Op != IdentitySrc)) {
+      IsIdeneity = false;
+      break;
+    }
+    IdentitySrc = BVOp->getOperand(0);
+  }
+
+  if (IsIdeneity) {
+    auto *LN = cast<LoadSDNode>(IdentitySrc);
+    SDVTList Tys =
+        LN->isIndexed()
+            ? DAG.getVTList(VT, LN->getBasePtr().getValueType(), MVT::Other)
+            : DAG.getVTList(VT, MVT::Other);
+    SDValue Ops[] = {LN->getChain(), LN->getBasePtr(), LN->getOffset()};
+    SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
+    DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
+    return BCast;
+  }
+  return SDValue();
+}
+
 SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
                                                    SelectionDAG &DAG) const {
   BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
@@ -1736,6 +1777,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
       (!Subtarget.hasExtLASX() || !Is256Vec))
     return SDValue();
 
+  if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
+    return Result;
+
   if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
                             /*MinSplatBits=*/8) &&
       SplatBitSize <= 64) {
@@ -5171,6 +5215,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(VSRLI)
     NODE_NAME_CASE(VBSLL)
     NODE_NAME_CASE(VBSRL)
+    NODE_NAME_CASE(VLDREPL)
   }
 #undef NODE_NAME_CASE
   return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 52d88b9b24a6b..71243a4f0d708 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -155,7 +155,10 @@ enum NodeType : unsigned {
 
   // Vector byte logicial left / right shift
   VBSLL,
-  VBSRL
+  VBSRL,
+
+  // Load one scalar from memory and replicate it into every vector element
+  VLDREPL
 
   // Intrinsic operations end =============================================
 };
diff --git a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
index e4feaa600c57d..775d9289af7c4 100644
--- a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
@@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
 }
 }
 
-def simm9_lsl3 : Operand<GRLenVT> {
+def simm9_lsl3 : Operand<GRLenVT>, // ImmLeaf: multiple of 8 in [-2048, 2040]
+                 ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
   let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
   let EncoderMethod = "getImmOpValueAsr<3>";
   let DecoderMethod = "decodeSImmOperand<9, 3>";
@@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
   let ParserMatchClass = SImmAsmOperand<10>;
 }
 
-def simm10_lsl2 : Operand<GRLenVT> {
+def simm10_lsl2 : Operand<GRLenVT>, // ImmLeaf: multiple of 4 in [-2048, 2044]
+                  ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
   let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
   let EncoderMethod = "getImmOpValueAsr<2>";
   let DecoderMethod = "decodeSImmOperand<10, 2>";
 }
 
-def simm11_lsl1 : Operand<GRLenVT> {
+def simm11_lsl1 : Operand<GRLenVT>, // ImmLeaf: multiple of 2 in [-2048, 2046]
+                  ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
   let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
   let EncoderMethod = "getImmOpValueAsr<1>";
   let DecoderMethod = "decodeSImmOperand<11, 1>";
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index d6d532cddb594..54fad8421378b 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -2161,6 +2161,7 @@ def : Pat<(int_loongarch_lasx_xvld GPR:$rj, timm:$imm),
 def : Pat<(int_loongarch_lasx_xvldx GPR:$rj, GPR:$rk),
           (XVLDX GPR:$rj, GPR:$rk)>;
 
+// xvldrepl: intrinsic patterns, then VLDREPL (broadcast load) patterns
 def : Pat<(int_loongarch_lasx_xvldrepl_b GPR:$rj, timm:$imm),
           (XVLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
 def : Pat<(int_loongarch_lasx_xvldrepl_h GPR:$rj, timm:$imm),
@@ -2170,6 +2171,11 @@ def : Pat<(int_loongarch_lasx_xvldrepl_w GPR:$rj, timm:$imm),
 def : Pat<(int_loongarch_lasx_xvldrepl_d GPR:$rj, timm:$imm),
           (XVLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
 
+defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
+defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
+defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
+
 // store
 def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
           (XVST LASX256:$xd, GPR:$rj, (to_valid_timm timm:$imm))>;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index b0d880749bf92..2b44361df29ba 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -26,6 +26,7 @@ def SDT_LoongArchV1RUimm: SDTypeProfile<1, 2, [SDTCisVec<0>,
 def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>]>;
 def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
 def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
+def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
 
 // Target nodes.
 def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -64,6 +65,10 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
 def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
 def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
 
+def loongarch_vldrepl
+    : SDNode<"LoongArchISD::VLDREPL",
+             SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
 def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
 def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
 def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1433,6 +1438,14 @@ multiclass PatCCVrVrF<CondCode CC, string Inst> {
             (!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
 }
 
+multiclass VldreplPat<ValueType vt, LAInst Inst, Operand ImmOpnd> {
+  def : Pat<(vt(loongarch_vldrepl BaseAddr:$rj)), (Inst BaseAddr:$rj, 0)>;
+  def : Pat<(vt(loongarch_vldrepl(AddrConstant GPR:$rj, ImmOpnd:$imm))),
+            (Inst GPR:$rj, ImmOpnd:$imm)>;
+  def : Pat<(vt(loongarch_vldrepl(AddLike BaseAddr:$rj, ImmOpnd:$imm))),
+            (Inst BaseAddr:$rj, ImmOpnd:$imm)>;
+}
+
 let Predicates = [HasExtLSX] in {
 
 // VADD_{B/H/W/D}
@@ -2338,6 +2351,7 @@ def : Pat<(int_loongarch_lsx_vld GPR:$rj, timm:$imm),
 def : Pat<(int_loongarch_lsx_vldx GPR:$rj, GPR:$rk),
           (VLDX GPR:$rj, GPR:$rk)>;
 
+// vldrepl: intrinsic patterns, then VLDREPL (broadcast load) patterns
 def : Pat<(int_loongarch_lsx_vldrepl_b GPR:$rj, timm:$imm),
           (VLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
 def : Pat<(int_loongarch_lsx_vldrepl_h GPR:$rj, timm:$imm),
@@ -2347,6 +2361,11 @@ def : Pat<(int_loongarch_lsx_vldrepl_w GPR:$rj, timm:$imm),
 def : Pat<(int_loongarch_lsx_vldrepl_d GPR:$rj, timm:$imm),
           (VLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
 
+defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
+defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
+defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
+
 // store
 def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
           (VST LSX128:$vd, GPR:$rj, (to_valid_timm timm:$imm))>;

>From 60977f3558b31b21de90c6061496c64029c219b8 Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 16 Apr 2025 10:46:56 +0800
Subject: [PATCH 2/3] add v4f32, v2f64, v8f32, v4f64 support

---
 llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td | 2 ++
 llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td  | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 54fad8421378b..6b2071392d5c5 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -2175,6 +2175,8 @@ defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
 defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
 defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
 defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
+defm : VldreplPat<v8f32, XVLDREPL_W, simm10_lsl2>; // same 32-bit form as v8i32
+defm : VldreplPat<v4f64, XVLDREPL_D, simm9_lsl3>; // same 64-bit form as v4i64
 
 // store
 def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 2b44361df29ba..d3910a21876d3 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -2365,6 +2365,8 @@ defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
 defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
 defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
 defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
+defm : VldreplPat<v4f32, VLDREPL_W, simm10_lsl2>; // same 32-bit form as v4i32
+defm : VldreplPat<v2f64, VLDREPL_D, simm9_lsl3>; // same 64-bit form as v2i64
 
 // store
 def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),

>From e412c722f7b3efc500f6a0c2770b23ea6b310e0b Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 16 Apr 2025 11:48:13 +0800
Subject: [PATCH 3/3] add extra checks before lowering to broadcast load

---
 llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 8c5e095dea039..46f8af1c5590d 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -1748,6 +1748,22 @@ static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
     IdentitySrc = BVOp->getOperand(0);
   }
 
+  // Bail out if operand 0 is not a load (IdentitySrc is then still null) or if
+  // the loaded value has users other than this BUILD_VECTOR. Only uses of the
+  // value result are counted: uses of the load's chain are redirected below.
+  if (!IdentitySrc ||
+      !IdentitySrc->hasNUsesOfValue(NumOps, IdentitySrc.getResNo()))
+    return SDValue();
+
+  // The VLDREPL patterns pick the B/H/W/D form from the vector element type
+  // and do no sign/zero extension, so extending loads and loads whose memory
+  // width differs from the element width cannot be lowered this way.
+  auto *Ld = cast<LoadSDNode>(IdentitySrc);
+  ISD::LoadExtType ExtTy = Ld->getExtensionType();
+  if ((ExtTy != ISD::NON_EXTLOAD && ExtTy != ISD::EXTLOAD) ||
+      Ld->getMemoryVT().getScalarSizeInBits() != VT.getScalarSizeInBits())
+    return SDValue();
+
   if (IsIdeneity) {
     auto *LN = cast<LoadSDNode>(IdentitySrc);
     SDVTList Tys =



More information about the llvm-commits mailing list