[llvm] [AMDGPU] Simplify lowerBUILD_VECTOR (PR #109094)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 18 00:35:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Piotr Sobczak (piotrAMD)

<details>
<summary>Changes</summary>

Simplify `lowerBUILD_VECTOR` by unifying the way the vectors are split: every vector wider than two elements is now split into 2-element chunks.
Also reorder the checks to avoid a long condition inside the `if`.

---
Full diff: https://github.com/llvm/llvm-project/pull/109094.diff


2 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+31-80) 
- (modified) llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll (+1-9) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 4a861f0c03a0c5..10108866a7005a 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -7443,98 +7443,49 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
   SDLoc SL(Op);
   EVT VT = Op.getValueType();
 
-  if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
-      VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16) {
-    EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
-                                  VT.getVectorNumElements() / 2);
-    MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
+  if (VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16) {
+    assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
 
-    // Turn into pair of packed build_vectors.
-    // TODO: Special case for constants that can be materialized with s_mov_b64.
-    SmallVector<SDValue, 4> LoOps, HiOps;
-    for (unsigned I = 0, E = VT.getVectorNumElements() / 2; I != E; ++I) {
-      LoOps.push_back(Op.getOperand(I));
-      HiOps.push_back(Op.getOperand(I + E));
-    }
-    SDValue Lo = DAG.getBuildVector(HalfVT, SL, LoOps);
-    SDValue Hi = DAG.getBuildVector(HalfVT, SL, HiOps);
-
-    SDValue CastLo = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Lo);
-    SDValue CastHi = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Hi);
-
-    SDValue Blend = DAG.getBuildVector(MVT::getVectorVT(HalfIntVT, 2), SL,
-                                       { CastLo, CastHi });
-    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
-  }
+    SDValue Lo = Op.getOperand(0);
+    SDValue Hi = Op.getOperand(1);
 
-  if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16) {
-    EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
-                                     VT.getVectorNumElements() / 4);
-    MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
-
-    SmallVector<SDValue, 4> Parts[4];
-    for (unsigned I = 0, E = VT.getVectorNumElements() / 4; I != E; ++I) {
-      for (unsigned P = 0; P < 4; ++P)
-        Parts[P].push_back(Op.getOperand(I + P * E));
-    }
-    SDValue Casts[4];
-    for (unsigned P = 0; P < 4; ++P) {
-      SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
-      Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
+    // Avoid adding defined bits with the zero_extend.
+    if (Hi.isUndef()) {
+      Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
+      SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
+      return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
     }
 
-    SDValue Blend =
-        DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 4), SL, Casts);
-    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
-  }
+    Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
+    Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
 
-  if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
-    EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
-                                     VT.getVectorNumElements() / 8);
-    MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
+    SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
+                                DAG.getConstant(16, SL, MVT::i32));
+    if (Lo.isUndef())
+      return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
 
-    SmallVector<SDValue, 8> Parts[8];
-    for (unsigned I = 0, E = VT.getVectorNumElements() / 8; I != E; ++I) {
-      for (unsigned P = 0; P < 8; ++P)
-        Parts[P].push_back(Op.getOperand(I + P * E));
-    }
-    SDValue Casts[8];
-    for (unsigned P = 0; P < 8; ++P) {
-      SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
-      Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
-    }
+    Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
+    Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
 
-    SDValue Blend =
-        DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 8), SL, Casts);
-    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
+    SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
+    return DAG.getNode(ISD::BITCAST, SL, VT, Or);
   }
 
-  assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16);
-  assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
+  // Split into 2-element chunks.
+  const unsigned NumParts = VT.getVectorNumElements() / 2;
+  EVT PartVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(), 2);
+  MVT PartIntVT = MVT::getIntegerVT(PartVT.getSizeInBits());
 
-  SDValue Lo = Op.getOperand(0);
-  SDValue Hi = Op.getOperand(1);
-
-  // Avoid adding defined bits with the zero_extend.
-  if (Hi.isUndef()) {
-    Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
-    SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
-    return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
+  SmallVector<SDValue> Casts;
+  for (unsigned P = 0; P < NumParts; ++P) {
+    SDValue Vec = DAG.getBuildVector(
+        PartVT, SL, {Op.getOperand(P * 2), Op.getOperand(P * 2 + 1)});
+    Casts.push_back(DAG.getNode(ISD::BITCAST, SL, PartIntVT, Vec));
   }
 
-  Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
-  Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
-
-  SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
-                              DAG.getConstant(16, SL, MVT::i32));
-  if (Lo.isUndef())
-    return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
-
-  Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
-  Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
-
-  SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
-  return DAG.getNode(ISD::BITCAST, SL, VT, Or);
+  SDValue Blend =
+      DAG.getBuildVector(MVT::getVectorVT(PartIntVT, NumParts), SL, Casts);
+  return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
 }
 
 bool
diff --git a/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll b/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
index 3135addec16183..c68138acc9b2bf 100644
--- a/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
@@ -965,11 +965,7 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
 ; GFX900-NEXT:    v_mov_b32_e32 v5, 0x5040100
 ; GFX900-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX900-NEXT:    global_load_dwordx4 v[0:3], v4, s[2:3]
-; GFX900-NEXT:    s_mov_b32 s2, 0xffff
 ; GFX900-NEXT:    s_waitcnt vmcnt(0)
-; GFX900-NEXT:    v_bfi_b32 v3, s2, v3, v3
-; GFX900-NEXT:    v_bfi_b32 v2, s2, v2, v2
-; GFX900-NEXT:    v_bfi_b32 v0, s2, v0, v0
 ; GFX900-NEXT:    v_perm_b32 v1, s4, v1, v5
 ; GFX900-NEXT:    global_store_dwordx4 v4, v[0:3], s[0:1]
 ; GFX900-NEXT:    s_endpgm
@@ -980,14 +976,10 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
 ; GFX940-NEXT:    s_load_dword s0, s[2:3], 0x10
 ; GFX940-NEXT:    v_and_b32_e32 v0, 0x3ff, v0
 ; GFX940-NEXT:    v_lshlrev_b32_e32 v4, 4, v0
-; GFX940-NEXT:    s_mov_b32 s1, 0xffff
+; GFX940-NEXT:    v_mov_b32_e32 v5, 0x5040100
 ; GFX940-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX940-NEXT:    global_load_dwordx4 v[0:3], v4, s[6:7]
-; GFX940-NEXT:    v_mov_b32_e32 v5, 0x5040100
 ; GFX940-NEXT:    s_waitcnt vmcnt(0)
-; GFX940-NEXT:    v_bfi_b32 v3, s1, v3, v3
-; GFX940-NEXT:    v_bfi_b32 v2, s1, v2, v2
-; GFX940-NEXT:    v_bfi_b32 v0, s1, v0, v0
 ; GFX940-NEXT:    v_perm_b32 v1, s0, v1, v5
 ; GFX940-NEXT:    global_store_dwordx4 v4, v[0:3], s[4:5] sc0 sc1
 ; GFX940-NEXT:    s_endpgm

``````````

</details>


https://github.com/llvm/llvm-project/pull/109094


More information about the llvm-commits mailing list