[llvm] [LoongArch] Lower build_vector to broadcast load if possible (PR #135896)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 16 02:59:07 PDT 2025
https://github.com/tangaac updated https://github.com/llvm/llvm-project/pull/135896
>From 4e18d62eb0497400925017a9178a52fba08c9133 Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 16 Apr 2025 10:01:17 +0800
Subject: [PATCH 1/3] Lower build_vector to broadcast load if possible
---
.../LoongArch/LoongArchISelLowering.cpp | 45 +++++++++++++++++++
.../Target/LoongArch/LoongArchISelLowering.h | 5 ++-
.../Target/LoongArch/LoongArchInstrInfo.td | 9 ++--
.../LoongArch/LoongArchLASXInstrInfo.td | 6 +++
.../Target/LoongArch/LoongArchLSXInstrInfo.td | 19 ++++++++
5 files changed, 80 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 002d88cbeeba3..8c5e095dea039 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -1721,6 +1721,47 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
return false;
}
+// Lower BUILD_VECTOR as a broadcast load (if possible).
+// For example:
+//   %a = load i8, ptr %ptr
+//   %b = build_vector %a, %a, ..., %a   ; v16i8, all 16 operands are %a
+// is lowered to:
+//   (VLDREPL_B $a0, 0)
+//
+// Only an unindexed LOAD whose memory type equals the vector element type is
+// accepted. The instruction selected for the vector type (e.g. VLDREPL_W for
+// v4i32) replicates one full element read from memory, so an extending load
+// from a narrower type (such as a sextload from i8 feeding a v4i32
+// BUILD_VECTOR) must not become a broadcast: it would read the wrong number
+// of bytes.
+//
+// Returns the new VLDREPL node on success, or an empty SDValue otherwise.
+static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
+                                                const SDLoc &DL,
+                                                SelectionDAG &DAG) {
+  MVT VT = BVOp->getSimpleValueType(0);
+  MVT EltVT = VT.getVectorElementType();
+  int NumOps = BVOp->getNumOperands();
+
+  assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
+         "Unsupported vector type for broadcast.");
+
+  SDValue IdentitySrc;
+  bool IsIdeneity = true;
+
+  for (int i = 0; i != NumOps; i++) {
+    SDValue Op = BVOp->getOperand(i);
+    auto *Ld = dyn_cast<LoadSDNode>(Op);
+    // BUILD_VECTOR implicitly truncates integer operands wider than the
+    // element type (e.g. type-promoted i8 loads), so the operand's value type
+    // is not compared; what must match is the number of bytes read.
+    if (!Ld || !Ld->isUnindexed() || Ld->getMemoryVT() != EltVT ||
+        (IdentitySrc && Op != IdentitySrc)) {
+      IsIdeneity = false;
+      break;
+    }
+    IdentitySrc = BVOp->getOperand(0);
+  }
+
+  if (IsIdeneity) {
+    auto *LN = cast<LoadSDNode>(IdentitySrc);
+    SDVTList Tys =
+        DAG.getVTList(VT, MVT::Other);
+    // Indexed loads were rejected above, so VLDREPL takes only the chain and
+    // the base pointer; LN->getOffset() would be a spurious extra operand.
+    SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
+    // NOTE(review): loongarch_vldrepl is declared with SDNPMemOperand, but a
+    // node built with getNode() is not a MemSDNode and so carries no memory
+    // operand from LN (alignment, volatility, alias info). Consider
+    // DAG.getMemIntrinsicNode() -- TODO confirm VLDREPL is accepted there.
+    SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
+    // Value 1 of an unindexed load is its output chain; move its users onto
+    // the broadcast's chain.
+    DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
+    return BCast;
+  }
+  return SDValue();
+}
+
SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
@@ -1736,6 +1777,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
(!Subtarget.hasExtLASX() || !Is256Vec))
return SDValue();
+ if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
+ return Result;
+
if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
/*MinSplatBits=*/8) &&
SplatBitSize <= 64) {
@@ -5171,6 +5215,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VSRLI)
NODE_NAME_CASE(VBSLL)
NODE_NAME_CASE(VBSRL)
+ NODE_NAME_CASE(VLDREPL)
}
#undef NODE_NAME_CASE
return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 52d88b9b24a6b..71243a4f0d708 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -155,7 +155,10 @@ enum NodeType : unsigned {
// Vector byte logicial left / right shift
VBSLL,
- VBSRL
+ VBSRL,
+
+ // Scalar load broadcast to vector
+ VLDREPL
// Intrinsic operations end =============================================
};
diff --git a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
index e4feaa600c57d..775d9289af7c4 100644
--- a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
@@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
}
}
-def simm9_lsl3 : Operand<GRLenVT> {
+def simm9_lsl3 : Operand<GRLenVT>,
+ ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
let EncoderMethod = "getImmOpValueAsr<3>";
let DecoderMethod = "decodeSImmOperand<9, 3>";
@@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
let ParserMatchClass = SImmAsmOperand<10>;
}
-def simm10_lsl2 : Operand<GRLenVT> {
+// Signed 10-bit immediate scaled by 4: a multiple of 4 in [-2048, 2044].
+// The ImmLeaf mixin lets this operand also act as an immediate predicate in
+// selection patterns (VldreplPat uses it for the .w broadcast forms).
+def simm10_lsl2 : Operand<GRLenVT>,
+ ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
let EncoderMethod = "getImmOpValueAsr<2>";
let DecoderMethod = "decodeSImmOperand<10, 2>";
}
-def simm11_lsl1 : Operand<GRLenVT> {
+def simm11_lsl1 : Operand<GRLenVT>,
+ ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
let EncoderMethod = "getImmOpValueAsr<1>";
let DecoderMethod = "decodeSImmOperand<11, 1>";
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index d6d532cddb594..54fad8421378b 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -2161,6 +2161,7 @@ def : Pat<(int_loongarch_lasx_xvld GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lasx_xvldx GPR:$rj, GPR:$rk),
(XVLDX GPR:$rj, GPR:$rk)>;
+// xvldrepl
def : Pat<(int_loongarch_lasx_xvldrepl_b GPR:$rj, timm:$imm),
(XVLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
def : Pat<(int_loongarch_lasx_xvldrepl_h GPR:$rj, timm:$imm),
@@ -2170,6 +2171,11 @@ def : Pat<(int_loongarch_lasx_xvldrepl_w GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lasx_xvldrepl_d GPR:$rj, timm:$imm),
(XVLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
+// Select LoongArchISD::VLDREPL for 256-bit integer vectors. The offset
+// operand class for each form matches its element width (x1/x2/x4/x8).
+defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
+defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
+defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
+
// store
def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
(XVST LASX256:$xd, GPR:$rj, (to_valid_timm timm:$imm))>;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index b0d880749bf92..2b44361df29ba 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -26,6 +26,7 @@ def SDT_LoongArchV1RUimm: SDTypeProfile<1, 2, [SDTCisVec<0>,
def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>]>;
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
+// Type profile for VLDREPL: one vector result and one pointer operand.
+def SDT_LoongArchVLDREPL
+    : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
// Target nodes.
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -64,6 +65,10 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
+// Scalar load broadcast to every element of a vector. The node carries a
+// chain and may read memory.
+def loongarch_vldrepl : SDNode<"LoongArchISD::VLDREPL", SDT_LoongArchVLDREPL,
+                               [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1433,6 +1438,14 @@ multiclass PatCCVrVrF<CondCode CC, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
}
+// Selection patterns that emit `Inst` for a VLDREPL node of vector type `vt`.
+// Three address shapes are matched: a bare BaseAddr (offset 0), an
+// AddrConstant(GPR, imm) address, and an AddLike(BaseAddr, imm) address.
+// The immediate is folded only when it satisfies `ImmOpnd`, the offset
+// operand class for the instruction's element width.
+multiclass VldreplPat<ValueType vt, LAInst Inst, Operand ImmOpnd> {
+  // Base register only.
+  def : Pat<(vt (loongarch_vldrepl BaseAddr:$rj)),
+            (Inst BaseAddr:$rj, 0)>;
+  // AddrConstant: register plus an immediate that fits ImmOpnd.
+  def : Pat<(vt (loongarch_vldrepl (AddrConstant GPR:$rj, ImmOpnd:$imm))),
+            (Inst GPR:$rj, ImmOpnd:$imm)>;
+  // AddLike: base address plus an immediate that fits ImmOpnd.
+  def : Pat<(vt (loongarch_vldrepl (AddLike BaseAddr:$rj, ImmOpnd:$imm))),
+            (Inst BaseAddr:$rj, ImmOpnd:$imm)>;
+}
+
let Predicates = [HasExtLSX] in {
// VADD_{B/H/W/D}
@@ -2338,6 +2351,7 @@ def : Pat<(int_loongarch_lsx_vld GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lsx_vldx GPR:$rj, GPR:$rk),
(VLDX GPR:$rj, GPR:$rk)>;
+// vldrepl
def : Pat<(int_loongarch_lsx_vldrepl_b GPR:$rj, timm:$imm),
(VLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
def : Pat<(int_loongarch_lsx_vldrepl_h GPR:$rj, timm:$imm),
@@ -2347,6 +2361,11 @@ def : Pat<(int_loongarch_lsx_vldrepl_w GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lsx_vldrepl_d GPR:$rj, timm:$imm),
(VLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
+// Select LoongArchISD::VLDREPL for 128-bit integer vectors. The offset
+// operand class for each form matches its element width (x1/x2/x4/x8).
+defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
+defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
+defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
+
// store
def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
(VST LSX128:$vd, GPR:$rj, (to_valid_timm timm:$imm))>;
>From 60977f3558b31b21de90c6061496c64029c219b8 Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 16 Apr 2025 10:46:56 +0800
Subject: [PATCH 2/3] add v4f32, v2f64, v8f32, v4f64 support
---
llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td | 2 ++
llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td | 2 ++
2 files changed, 4 insertions(+)
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 54fad8421378b..6b2071392d5c5 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -2175,6 +2175,8 @@ defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
+// 256-bit FP vectors reuse the same-width .w/.d broadcast forms.
+defm : VldreplPat<v8f32, XVLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v4f64, XVLDREPL_D, simm9_lsl3>;
// store
def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 2b44361df29ba..d3910a21876d3 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -2365,6 +2365,8 @@ defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
+// 128-bit FP vectors reuse the same-width .w/.d broadcast forms.
+defm : VldreplPat<v4f32, VLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v2f64, VLDREPL_D, simm9_lsl3>;
// store
def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
>From e7f09af8d7abec88346fdc968bb40a5018d03d58 Mon Sep 17 00:00:00 2001
From: tangaac <tangyan01 at loongson.cn>
Date: Wed, 16 Apr 2025 11:48:13 +0800
Subject: [PATCH 3/3] Check that the broadcast source load exists and has a single user
---
llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 8c5e095dea039..9761769696b94 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -1748,6 +1748,10 @@ static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
IdentitySrc = BVOp->getOperand(0);
}
+ // make sure that this load is valid and only has one user.
+ if (!IdentitySrc || !BVOp->isOnlyUserOf(IdentitySrc.getNode()))
+ return SDValue();
+
if (IsIdeneity) {
auto *LN = cast<LoadSDNode>(IdentitySrc);
SDVTList Tys =
More information about the llvm-commits
mailing list