[llvm] [AMDGPU] Simplify lowerBUILD_VECTOR (PR #109094)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 18 00:35:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Piotr Sobczak (piotrAMD)
<details>
<summary>Changes</summary>
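Simplify `lowerBUILD_VECTOR` by unifying the way wide vectors are split: every vector with more than two elements is now broken into 2-element chunks by a single loop, replacing the separate half-, quarter-, and eighth-splitting cases.
Also reorder the checks so that the 2-element case is handled first, which avoids a long type-list condition inside the `if`.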
---
Full diff: https://github.com/llvm/llvm-project/pull/109094.diff
2 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+31-80)
- (modified) llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll (+1-9)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 4a861f0c03a0c5..10108866a7005a 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -7443,98 +7443,49 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
SDLoc SL(Op);
EVT VT = Op.getValueType();
- if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
- VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16) {
- EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
- VT.getVectorNumElements() / 2);
- MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
+ if (VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16) {
+ assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
- // Turn into pair of packed build_vectors.
- // TODO: Special case for constants that can be materialized with s_mov_b64.
- SmallVector<SDValue, 4> LoOps, HiOps;
- for (unsigned I = 0, E = VT.getVectorNumElements() / 2; I != E; ++I) {
- LoOps.push_back(Op.getOperand(I));
- HiOps.push_back(Op.getOperand(I + E));
- }
- SDValue Lo = DAG.getBuildVector(HalfVT, SL, LoOps);
- SDValue Hi = DAG.getBuildVector(HalfVT, SL, HiOps);
-
- SDValue CastLo = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Lo);
- SDValue CastHi = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Hi);
-
- SDValue Blend = DAG.getBuildVector(MVT::getVectorVT(HalfIntVT, 2), SL,
- { CastLo, CastHi });
- return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
- }
+ SDValue Lo = Op.getOperand(0);
+ SDValue Hi = Op.getOperand(1);
- if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16) {
- EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
- VT.getVectorNumElements() / 4);
- MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
-
- SmallVector<SDValue, 4> Parts[4];
- for (unsigned I = 0, E = VT.getVectorNumElements() / 4; I != E; ++I) {
- for (unsigned P = 0; P < 4; ++P)
- Parts[P].push_back(Op.getOperand(I + P * E));
- }
- SDValue Casts[4];
- for (unsigned P = 0; P < 4; ++P) {
- SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
- Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
+ // Avoid adding defined bits with the zero_extend.
+ if (Hi.isUndef()) {
+ Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
+ SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
+ return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
}
- SDValue Blend =
- DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 4), SL, Casts);
- return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
- }
+ Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
+ Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
- if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
- EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
- VT.getVectorNumElements() / 8);
- MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
+ SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
+ DAG.getConstant(16, SL, MVT::i32));
+ if (Lo.isUndef())
+ return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
- SmallVector<SDValue, 8> Parts[8];
- for (unsigned I = 0, E = VT.getVectorNumElements() / 8; I != E; ++I) {
- for (unsigned P = 0; P < 8; ++P)
- Parts[P].push_back(Op.getOperand(I + P * E));
- }
- SDValue Casts[8];
- for (unsigned P = 0; P < 8; ++P) {
- SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
- Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
- }
+ Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
+ Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
- SDValue Blend =
- DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 8), SL, Casts);
- return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
+ SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
+ return DAG.getNode(ISD::BITCAST, SL, VT, Or);
}
- assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16);
- assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
+ // Split into 2-element chunks.
+ const unsigned NumParts = VT.getVectorNumElements() / 2;
+ EVT PartVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(), 2);
+ MVT PartIntVT = MVT::getIntegerVT(PartVT.getSizeInBits());
- SDValue Lo = Op.getOperand(0);
- SDValue Hi = Op.getOperand(1);
-
- // Avoid adding defined bits with the zero_extend.
- if (Hi.isUndef()) {
- Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
- SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
- return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
+ SmallVector<SDValue> Casts;
+ for (unsigned P = 0; P < NumParts; ++P) {
+ SDValue Vec = DAG.getBuildVector(
+ PartVT, SL, {Op.getOperand(P * 2), Op.getOperand(P * 2 + 1)});
+ Casts.push_back(DAG.getNode(ISD::BITCAST, SL, PartIntVT, Vec));
}
- Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
- Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
-
- SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
- DAG.getConstant(16, SL, MVT::i32));
- if (Lo.isUndef())
- return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
-
- Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
- Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
-
- SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
- return DAG.getNode(ISD::BITCAST, SL, VT, Or);
+ SDValue Blend =
+ DAG.getBuildVector(MVT::getVectorVT(PartIntVT, NumParts), SL, Casts);
+ return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
}
bool
diff --git a/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll b/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
index 3135addec16183..c68138acc9b2bf 100644
--- a/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
@@ -965,11 +965,7 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
; GFX900-NEXT: v_mov_b32_e32 v5, 0x5040100
; GFX900-NEXT: s_waitcnt lgkmcnt(0)
; GFX900-NEXT: global_load_dwordx4 v[0:3], v4, s[2:3]
-; GFX900-NEXT: s_mov_b32 s2, 0xffff
; GFX900-NEXT: s_waitcnt vmcnt(0)
-; GFX900-NEXT: v_bfi_b32 v3, s2, v3, v3
-; GFX900-NEXT: v_bfi_b32 v2, s2, v2, v2
-; GFX900-NEXT: v_bfi_b32 v0, s2, v0, v0
; GFX900-NEXT: v_perm_b32 v1, s4, v1, v5
; GFX900-NEXT: global_store_dwordx4 v4, v[0:3], s[0:1]
; GFX900-NEXT: s_endpgm
@@ -980,14 +976,10 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
; GFX940-NEXT: s_load_dword s0, s[2:3], 0x10
; GFX940-NEXT: v_and_b32_e32 v0, 0x3ff, v0
; GFX940-NEXT: v_lshlrev_b32_e32 v4, 4, v0
-; GFX940-NEXT: s_mov_b32 s1, 0xffff
+; GFX940-NEXT: v_mov_b32_e32 v5, 0x5040100
; GFX940-NEXT: s_waitcnt lgkmcnt(0)
; GFX940-NEXT: global_load_dwordx4 v[0:3], v4, s[6:7]
-; GFX940-NEXT: v_mov_b32_e32 v5, 0x5040100
; GFX940-NEXT: s_waitcnt vmcnt(0)
-; GFX940-NEXT: v_bfi_b32 v3, s1, v3, v3
-; GFX940-NEXT: v_bfi_b32 v2, s1, v2, v2
-; GFX940-NEXT: v_bfi_b32 v0, s1, v0, v0
; GFX940-NEXT: v_perm_b32 v1, s0, v1, v5
; GFX940-NEXT: global_store_dwordx4 v4, v[0:3], s[4:5] sc0 sc1
; GFX940-NEXT: s_endpgm
``````````
</details>
https://github.com/llvm/llvm-project/pull/109094
More information about the llvm-commits mailing list