[llvm] [GISel][AMDGPU] Fold ShuffleVec into ExtractSubvec, and custom lower ExtractSubvec (PR #124527)

Alan Li via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 27 02:32:37 PST 2025


https://github.com/lialan created https://github.com/llvm/llvm-project/pull/124527

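Adds a GlobalISel combine that folds a G_SHUFFLE_VECTOR reading its lanes in order from a single source operand into a G_EXTRACT_SUBVECTOR. Also adds AMDGPU custom legalization of G_EXTRACT_SUBVECTOR for v4s16/v8s16 results; this legalization moves pairs of 16-bit elements as 32-bit lanes.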

>From 89cccb419fdb5bc6da70d84da90efdbce040d64a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 27 Jan 2025 18:26:19 +0800
Subject: [PATCH] [GISel][AMDGPU] Fold G_SHUFFLE_VECTOR into G_EXTRACT_SUBVECTOR and custom-lower G_EXTRACT_SUBVECTOR

---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |  4 ++
 .../include/llvm/Target/GlobalISel/Combine.td |  9 ++-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 41 +++++++++++++
 .../lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp | 59 +++++++++++++++++++
 llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h  |  2 +
 5 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 9b78342c8fc393..c1c303fd18e6b5 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -264,6 +264,10 @@ class CombinerHelper {
   void applyCombineShuffleConcat(MachineInstr &MI,
                                  SmallVector<Register> &Ops) const;
 
+  /// Match a G_SHUFFLE_VECTOR that can be rewritten as a G_EXTRACT_SUBVECTOR
+  /// of one of its source operands. On success \p Idx holds the first mask
+  /// element, i.e. an index into the concatenation of both sources.
+  bool matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx) const;
+  /// Replace \p MI with a G_EXTRACT_SUBVECTOR of the lanes starting at mask
+  /// index \p Idx, as produced by matchCombineShuffleExtract.
+  void applyCombineShuffleExtract(MachineInstr &MI, int64_t Idx) const;
+
   /// Try to combine G_SHUFFLE_VECTOR into G_CONCAT_VECTORS.
   /// Returns true if MI changed.
   ///
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 3590ab221ad441..30316305d9e4f7 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1560,6 +1560,13 @@ def combine_shuffle_concat : GICombineRule<
         [{ return Helper.matchCombineShuffleConcat(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyCombineShuffleConcat(*${root}, ${matchinfo}); }])>;
 
+// Fold a G_SHUFFLE_VECTOR into a G_EXTRACT_SUBVECTOR when
+// CombinerHelper::matchCombineShuffleExtract accepts it. The int64 matchinfo
+// is the first shuffle-mask index and is passed unchanged to
+// CombinerHelper::applyCombineShuffleExtract.
+// NOTE(review): the rule name reads like a general shuffle combine, but it
+// only produces extract_subvector; consider renaming (e.g.
+// combine_shuffle_extract) together with its use in all_combines.
+def combine_shuffle_vector : GICombineRule<
+  (defs root:$root, int64_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_SHUFFLE_VECTOR):$root,
+    [{ return Helper.matchCombineShuffleExtract(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyCombineShuffleExtract(*${root}, ${matchinfo}); }])>;
+
 def insert_vector_element_idx_undef : GICombineRule<
    (defs root:$root),
    (match (G_IMPLICIT_DEF $idx),
@@ -2026,7 +2033,7 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
     and_or_disjoint_mask, fma_combines, fold_binop_into_select,
     sub_add_reg, select_to_minmax,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
-    simplify_neg_minmax, combine_concat_vector,
+    simplify_neg_minmax, combine_concat_vector, combine_shuffle_vector,
     sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,
     combine_use_vector_truncate, merge_combines, overflow_combines]>;
 
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index b193d8bb0aa18a..fca7a81dd5fbdb 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -384,6 +384,47 @@ void CombinerHelper::applyCombineConcatVectors(
   MI.eraseFromParent();
 }
 
+/// Match a G_SHUFFLE_VECTOR whose mask is a run of consecutive lanes taken
+/// from exactly one source operand, so that it is equivalent to a
+/// G_EXTRACT_SUBVECTOR of that operand.
+///
+/// \param MI  the G_SHUFFLE_VECTOR to inspect.
+/// \param Idx on success, the first mask element. This is an index into the
+///            concatenation of both sources; values >= the source width
+///            refer to the second operand.
+/// \returns true if MI can be replaced by a G_EXTRACT_SUBVECTOR.
+bool CombinerHelper::matchCombineShuffleExtract(MachineInstr &MI,
+                                                int64_t &Idx) const {
+  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
+         "Invalid instruction");
+  auto &Shuffle = cast<GShuffleVector>(MI);
+
+  LLT DstTy = MRI.getType(Shuffle.getReg(0));
+  LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg());
+  // G_EXTRACT_SUBVECTOR needs vector types on both sides; shuffles of
+  // scalars, or producing a scalar, are left to other combines.
+  if (!DstTy.isVector() || !SrcTy.isVector())
+    return false;
+
+  ArrayRef<int> Mask = Shuffle.getMask();
+  const int64_t Width = SrcTy.getNumElements();
+  const int64_t NumDstElts = DstTy.getNumElements();
+  // Only strictly narrowing shuffles become extracts.
+  if (Mask.empty() || NumDstElts >= Width)
+    return false;
+
+  // The mask must be First, First+1, First+2, ... Being merely sorted is not
+  // enough: [0, 2] or [1, 1] are sorted but are not subvectors. Undefined
+  // lanes (-1) are rejected to keep the check simple.
+  const int64_t First = Mask.front();
+  if (First < 0)
+    return false;
+  for (size_t I = 0, E = Mask.size(); I != E; ++I)
+    if (Mask[I] != First + static_cast<int64_t>(I))
+      return false;
+
+  // The run must lie entirely inside one source operand.
+  const int64_t Offset = First < Width ? First : First - Width;
+  if (Offset + NumDstElts > Width)
+    return false;
+
+  // G_EXTRACT_SUBVECTOR requires the index to be a multiple of the result
+  // length.
+  if (Offset % NumDstElts != 0)
+    return false;
+
+  // The rule is part of all_combines, so it may run after legalization on
+  // targets that cannot select G_EXTRACT_SUBVECTOR.
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_EXTRACT_SUBVECTOR, {DstTy, SrcTy}}))
+    return false;
+
+  Idx = First;
+  return true;
+}
+
+/// Replace the G_SHUFFLE_VECTOR \p MI with a G_EXTRACT_SUBVECTOR of the
+/// operand that \p Idx points into.
+///
+/// \param Idx the first shuffle-mask element (an index into the
+///            concatenation of both sources), as set by
+///            matchCombineShuffleExtract.
+void CombinerHelper::applyCombineShuffleExtract(MachineInstr &MI,
+                                                int64_t Idx) const {
+  auto &Shuffle = cast<GShuffleVector>(MI);
+
+  Register SrcVec1 = Shuffle.getSrc1Reg();
+  Register SrcVec2 = Shuffle.getSrc2Reg();
+  const int64_t Width = MRI.getType(SrcVec1).getNumElements();
+
+  // Mask indices >= Width address the second operand. Rebase them so the
+  // extract index is relative to the operand that is actually read.
+  const bool FromFirst = Idx < Width;
+  Register SrcVec = FromFirst ? SrcVec1 : SrcVec2;
+  const int64_t Offset = FromFirst ? Idx : Idx - Width;
+
+  Builder.setInstrAndDebugLoc(MI);
+  Builder.buildExtractSubvector(Shuffle.getReg(0), SrcVec,
+                                static_cast<unsigned>(Offset));
+  MI.eraseFromParent();
+}
+
 bool CombinerHelper::matchCombineShuffleConcat(
     MachineInstr &MI, SmallVector<Register> &Ops) const {
   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index e9e47eaadd557f..68b0a8b5aecbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -29,6 +29,7 @@
 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -1832,6 +1833,11 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
       .lower();
   }
 
+  // G_EXTRACT_SUBVECTOR: v8s16/v4s16 results (type index 0) are handled by
+  // legalizeExtractSubvector; all other types use the generic lowering.
+  getActionDefinitionsBuilder(G_EXTRACT_SUBVECTOR)
+    .customFor({V8S16, V4S16})
+    .lower();
+
   getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
     .unsupportedIf([=](const LegalityQuery &Query) {
         const LLT &EltTy = Query.Types[1].getElementType();
@@ -2127,6 +2133,8 @@ bool AMDGPULegalizerInfo::legalizeCustom(
   case TargetOpcode::G_FMINNUM_IEEE:
   case TargetOpcode::G_FMAXNUM_IEEE:
     return legalizeMinNumMaxNum(Helper, MI);
+  case TargetOpcode::G_EXTRACT_SUBVECTOR:
+    return legalizeExtractSubvector(MI, MRI, B);
   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
     return legalizeExtractVectorElt(MI, MRI, B);
   case TargetOpcode::G_INSERT_VECTOR_ELT:
@@ -2716,6 +2724,57 @@ bool AMDGPULegalizerInfo::legalizeMinNumMaxNum(LegalizerHelper &Helper,
   return Helper.lowerFMinNumMaxNum(MI) == LegalizerHelper::Legalized;
 }
 
+/// Build a G_BUILD_VECTOR of type \p DstTy from the DstTy.getNumElements()
+/// consecutive elements of \p Src starting at lane \p Start. Each element is
+/// read with its own constant-index extract.
+static MachineInstrBuilder buildExtractSubvector(MachineIRBuilder &B,
+                                                 SrcOp Src, LLT DstTy,
+                                                 unsigned Start) {
+  const LLT EltTy = DstTy.getElementType();
+  const unsigned NumElts = DstTy.getNumElements();
+
+  SmallVector<Register, 8> Elts;
+  for (unsigned Offset = 0; Offset != NumElts; ++Offset) {
+    auto Elt = B.buildExtractVectorElementConstant(EltTy, Src, Start + Offset);
+    Elts.push_back(Elt.getReg(0));
+  }
+  return B.buildBuildVector(DstTy, Elts);
+}
+
+/// Custom lowering for G_EXTRACT_SUBVECTOR.
+///
+/// A same-type extract becomes a copy. With 16-bit elements starting on an
+/// even lane, the source is reinterpreted as 32-bit lanes so each pair of
+/// halves moves with a single extract. Anything else is lowered one element
+/// at a time. Always succeeds.
+bool AMDGPULegalizerInfo::legalizeExtractSubvector(
+  MachineInstr &MI, MachineRegisterInfo &MRI,
+  MachineIRBuilder &B) const {
+  const auto &Instr = llvm::cast<GExtractSubvector>(MI);
+  Register Src = Instr.getSrcVec();
+  Register Dst = MI.getOperand(0).getReg();
+  const uint64_t Start = Instr.getIndexImm();
+
+  LLT SrcTy = MRI.getType(Src);
+  LLT DstTy = MRI.getType(Dst);
+  LLT EltTy = SrcTy.getElementType();
+  assert(EltTy == DstTy.getElementType() && "element type mismatch");
+
+  const unsigned SrcCount = SrcTy.getNumElements();
+  const unsigned Count = DstTy.getNumElements();
+
+  // Extracting the whole source is just a copy.
+  if (DstTy == SrcTy) {
+    B.buildCopy(Dst, Src);
+    MI.eraseFromParent();
+    return true;
+  }
+
+  // Move pairs of 16-bit elements as 32-bit lanes. An even start lane and
+  // even element counts guarantee that no pair straddles the boundary.
+  if (EltTy.getSizeInBits() == 16 && Start % 2 == 0 && SrcCount % 2 == 0 &&
+      Count % 2 == 0) {
+    // Here Count < SrcCount and both are even, so SrcCount >= 4. The
+    // bitcast source therefore always has at least two 32-bit lanes.
+    Register Bitcasted =
+        B.buildBitcast(LLT::fixed_vector(SrcCount / 2, S32), Src).getReg(0);
+
+    SmallVector<Register, 8> Pieces;
+    for (unsigned I = Start / 2, E = (Start + Count) / 2; I != E; ++I)
+      Pieces.push_back(
+          B.buildExtractVectorElementConstant(S32, Bitcasted, I).getReg(0));
+
+    // A two-element result is a single s32 piece; LLT has no one-element
+    // vectors, so bitcast that scalar directly instead of building a vector.
+    Register Packed =
+        Pieces.size() == 1
+            ? Pieces.front()
+            : B.buildBuildVector(LLT::fixed_vector(Pieces.size(), S32), Pieces)
+                  .getReg(0);
+    B.buildBitcast(Dst, Packed);
+    MI.eraseFromParent();
+    return true;
+  }
+
+  // Fallback: extract every element and rebuild the result, rather than
+  // reporting a legalization failure.
+  auto Rebuilt =
+      buildExtractSubvector(B, Src, DstTy, static_cast<unsigned>(Start));
+  B.buildCopy(Dst, Rebuilt.getReg(0));
+  MI.eraseFromParent();
+  return true;
+}
+
 bool AMDGPULegalizerInfo::legalizeExtractVectorElt(
   MachineInstr &MI, MachineRegisterInfo &MRI,
   MachineIRBuilder &B) const {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
index 86c15197805d23..7b55492afb9821 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
@@ -56,6 +56,8 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
   bool legalizeFPTOI(MachineInstr &MI, MachineRegisterInfo &MRI,
                      MachineIRBuilder &B, bool Signed) const;
   bool legalizeMinNumMaxNum(LegalizerHelper &Helper, MachineInstr &MI) const;
+  bool legalizeExtractSubvector(MachineInstr &MI, MachineRegisterInfo &MRI,
+                                MachineIRBuilder &B) const;
   bool legalizeExtractVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
                                 MachineIRBuilder &B) const;
   bool legalizeInsertVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,



More information about the llvm-commits mailing list