[llvm] [GISel][AMDGPU] Fold ShuffleVec into ExtractSubvec, and custom lower ExtractSubvec (PR #124527)
Alan Li via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 27 02:32:37 PST 2025
https://github.com/lialan created https://github.com/llvm/llvm-project/pull/124527
None
>From 89cccb419fdb5bc6da70d84da90efdbce040d64a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 27 Jan 2025 18:26:19 +0800
Subject: [PATCH] [GISel][AMDGPU] Fold G_SHUFFLE_VECTOR into G_EXTRACT_SUBVECTOR and custom-lower G_EXTRACT_SUBVECTOR
---
.../llvm/CodeGen/GlobalISel/CombinerHelper.h | 4 ++
.../include/llvm/Target/GlobalISel/Combine.td | 9 ++-
.../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 41 +++++++++++++
.../lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp | 59 +++++++++++++++++++
llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h | 2 +
5 files changed, 114 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 9b78342c8fc393..c1c303fd18e6b5 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -264,6 +264,15 @@ class CombinerHelper {
void applyCombineShuffleConcat(MachineInstr &MI,
SmallVector<Register> &Ops) const;
+  /// Match a G_SHUFFLE_VECTOR that is equivalent to extracting a subvector
+  /// from one of its two source operands. On success, \p Idx is set to the
+  /// index (in the concatenated Src1:Src2 element numbering) of the first
+  /// extracted element.
+  bool matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx) const;
+  /// Replace \p MI with a G_EXTRACT_SUBVECTOR of the source operand that
+  /// holds element \p Idx.
+  void applyCombineShuffleExtract(MachineInstr &MI, int64_t Idx) const;
+
/// Try to combine G_SHUFFLE_VECTOR into G_CONCAT_VECTORS.
/// Returns true if MI changed.
///
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 3590ab221ad441..30316305d9e4f7 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1560,6 +1560,14 @@ def combine_shuffle_concat : GICombineRule<
[{ return Helper.matchCombineShuffleConcat(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyCombineShuffleConcat(*${root}, ${matchinfo}); }])>;
+// Fold a G_SHUFFLE_VECTOR that reads only an in-order run of elements from
+// one source operand into a G_EXTRACT_SUBVECTOR of that operand.
+def combine_shuffle_extract : GICombineRule<
+  (defs root:$root, int64_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_SHUFFLE_VECTOR):$root,
+         [{ return Helper.matchCombineShuffleExtract(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyCombineShuffleExtract(*${root}, ${matchinfo}); }])>;
+
def insert_vector_element_idx_undef : GICombineRule<
(defs root:$root),
(match (G_IMPLICIT_DEF $idx),
@@ -2026,7 +2034,7 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
sub_add_reg, select_to_minmax,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
- simplify_neg_minmax, combine_concat_vector,
+    simplify_neg_minmax, combine_concat_vector, combine_shuffle_extract,
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,
combine_use_vector_truncate, merge_combines, overflow_combines]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index b193d8bb0aa18a..fca7a81dd5fbdb 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -384,6 +384,85 @@ void CombinerHelper::applyCombineConcatVectors(
MI.eraseFromParent();
}
+/// Match a G_SHUFFLE_VECTOR that is equivalent to a narrowing
+/// G_EXTRACT_SUBVECTOR of one of its source operands: every defined mask
+/// element must select the next consecutive element of a single source,
+/// starting at an index that is a multiple of the result length. Undefined
+/// (negative) mask elements may stand for any element. On success \p Idx is
+/// the start index in the concatenated Src1:Src2 element numbering.
+bool CombinerHelper::matchCombineShuffleExtract(MachineInstr &MI,
+                                                int64_t &Idx) const {
+  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
+         "Invalid instruction");
+  auto &Shuffle = cast<GShuffleVector>(MI);
+  LLT DstTy = MRI.getType(Shuffle.getReg(0));
+  LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg());
+  // Scalar operands cannot be expressed as a subvector extract.
+  if (!DstTy.isVector() || !SrcTy.isVector())
+    return false;
+
+  const int64_t Width = SrcTy.getNumElements();
+  const int64_t NumElts = DstTy.getNumElements();
+  // Only strictly narrowing shuffles are subvector extracts; a same-width
+  // result would be a plain copy and a wider one cannot be extracted at all.
+  if (NumElts >= Width)
+    return false;
+
+  // Anchor the run on the first defined mask element.
+  ArrayRef<int> Mask = Shuffle.getMask();
+  const int *FirstDef = llvm::find_if(Mask, [](int M) { return M >= 0; });
+  if (FirstDef == Mask.end())
+    return false;
+  const int64_t Start = *FirstDef - (FirstDef - Mask.begin());
+  if (Start < 0)
+    return false;
+
+  // The run must lie entirely inside one source operand, so it may neither
+  // straddle the Src1/Src2 boundary nor run off the end of Src2.
+  const int64_t Offset = Start < Width ? Start : Start - Width;
+  if (Offset + NumElts > Width)
+    return false;
+
+  // Every defined lane must pick the consecutive element; a mask that is
+  // merely sorted, such as <0, 2>, is not a subvector.
+  for (int64_t I = 0, E = Mask.size(); I != E; ++I)
+    if (Mask[I] >= 0 && Mask[I] != Start + I)
+      return false;
+
+  // G_EXTRACT_SUBVECTOR requires the index to be a multiple of the result's
+  // element count.
+  if (Offset % NumElts != 0)
+    return false;
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_EXTRACT_SUBVECTOR, {DstTy, SrcTy}}))
+    return false;
+
+  Idx = Start;
+  return true;
+}
+
+/// Replace the G_SHUFFLE_VECTOR \p MI with a G_EXTRACT_SUBVECTOR of the
+/// source operand that holds element \p Idx (concatenated Src1:Src2
+/// numbering).
+void CombinerHelper::applyCombineShuffleExtract(MachineInstr &MI,
+                                                int64_t Idx) const {
+  auto &Shuffle = cast<GShuffleVector>(MI);
+  Register Src1 = Shuffle.getSrc1Reg();
+  const int64_t Width = MRI.getType(Src1).getNumElements();
+
+  // Rebase the index onto the operand that holds the run; indices at or
+  // above Width refer to Src2.
+  const bool FromSrc1 = Idx < Width;
+  Register SrcVec = FromSrc1 ? Src1 : Shuffle.getSrc2Reg();
+  const int64_t Offset = FromSrc1 ? Idx : Idx - Width;
+
+  Builder.setInstrAndDebugLoc(MI);
+  Builder.buildExtractSubvector(Shuffle.getReg(0), SrcVec,
+                                static_cast<unsigned>(Offset));
+  MI.eraseFromParent();
+}
+
bool CombinerHelper::matchCombineShuffleConcat(
MachineInstr &MI, SmallVector<Register> &Ops) const {
ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index e9e47eaadd557f..68b0a8b5aecbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -29,6 +29,7 @@
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -1832,6 +1833,12 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
.lower();
}
+  // <8 x s16> and <4 x s16> results are custom-legalized (see
+  // legalizeExtractSubvector); every other type uses the generic lowering.
+  getActionDefinitionsBuilder(G_EXTRACT_SUBVECTOR)
+      .customFor({V8S16, V4S16})
+      .lower();
+
getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
.unsupportedIf([=](const LegalityQuery &Query) {
const LLT &EltTy = Query.Types[1].getElementType();
@@ -2127,6 +2133,8 @@ bool AMDGPULegalizerInfo::legalizeCustom(
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
return legalizeMinNumMaxNum(Helper, MI);
+ case TargetOpcode::G_EXTRACT_SUBVECTOR:
+ return legalizeExtractSubvector(MI, MRI, B);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(MI, MRI, B);
case TargetOpcode::G_INSERT_VECTOR_ELT:
@@ -2716,6 +2724,68 @@ bool AMDGPULegalizerInfo::legalizeMinNumMaxNum(LegalizerHelper &Helper,
return Helper.lowerFMinNumMaxNum(MI) == LegalizerHelper::Legalized;
}
+/// Write into \p Dst (of type \p DstTy) the elements of \p Src starting at
+/// element \p Start, extracting them one at a time and reassembling them
+/// with G_BUILD_VECTOR. Works for any element type.
+static void buildExtractSubvectorByElements(MachineIRBuilder &B, Register Dst,
+                                            LLT DstTy, Register Src,
+                                            unsigned Start) {
+  LLT EltTy = DstTy.getElementType();
+  SmallVector<Register, 8> Elts;
+  for (unsigned I = Start, E = Start + DstTy.getNumElements(); I != E; ++I)
+    Elts.push_back(
+        B.buildExtractVectorElementConstant(EltTy, Src, I).getReg(0));
+  B.buildBuildVector(Dst, Elts);
+}
+
+/// Custom legalization of G_EXTRACT_SUBVECTOR. Always replaces \p MI and
+/// returns true, choosing the cheapest expansion that applies.
+bool AMDGPULegalizerInfo::legalizeExtractSubvector(
+    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B) const {
+  const auto &Extract = cast<GExtractSubvector>(MI);
+  Register Dst = Extract.getReg(0);
+  Register Src = Extract.getSrcVec();
+  const unsigned Start = Extract.getIndexImm();
+
+  LLT DstTy = MRI.getType(Dst);
+  LLT SrcTy = MRI.getType(Src);
+  LLT EltTy = SrcTy.getElementType();
+  assert(EltTy == DstTy.getElementType() && "element types must match");
+  const unsigned NumDstElts = DstTy.getNumElements();
+  const unsigned NumSrcElts = SrcTy.getNumElements();
+
+  if (DstTy == SrcTy) {
+    // A full-width extract (Start must be 0) is just a copy.
+    B.buildCopy(Dst, Src);
+  } else if (EltTy.getSizeInBits() == 16 && Start % 2 == 0 &&
+             NumDstElts % 2 == 0 && NumSrcElts % 2 == 0 &&
+             NumSrcElts > NumDstElts) {
+    // Aligned pairs of 16-bit elements are whole 32-bit lanes: bitcast the
+    // source to 32-bit elements (at least two of them, given the checks
+    // above), extract the lanes, and bitcast the result back.
+    const unsigned NumDstLanes = NumDstElts / 2;
+    Register Src32 =
+        B.buildBitcast(LLT::fixed_vector(NumSrcElts / 2, S32), Src).getReg(0);
+    SmallVector<Register, 8> Lanes;
+    for (unsigned I = Start / 2, E = I + NumDstLanes; I != E; ++I)
+      Lanes.push_back(
+          B.buildExtractVectorElementConstant(S32, Src32, I).getReg(0));
+    // LLT has no one-element vector type, so a single lane is bitcast as is.
+    Register Packed =
+        NumDstLanes == 1
+            ? Lanes.front()
+            : B.buildBuildVector(LLT::fixed_vector(NumDstLanes, S32), Lanes)
+                  .getReg(0);
+    B.buildBitcast(Dst, Packed);
+  } else {
+    // Odd start index or other element sizes: go element by element.
+    buildExtractSubvectorByElements(B, Dst, DstTy, Src, Start);
+  }
+
+  MI.eraseFromParent();
+  return true;
+}
+
bool AMDGPULegalizerInfo::legalizeExtractVectorElt(
MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B) const {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
index 86c15197805d23..7b55492afb9821 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
@@ -56,6 +56,9 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
bool legalizeFPTOI(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, bool Signed) const;
bool legalizeMinNumMaxNum(LegalizerHelper &Helper, MachineInstr &MI) const;
+  /// Custom legalization entry point for G_EXTRACT_SUBVECTOR.
+  bool legalizeExtractSubvector(MachineInstr &MI, MachineRegisterInfo &MRI,
+                                MachineIRBuilder &B) const;
bool legalizeExtractVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B) const;
bool legalizeInsertVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
More information about the llvm-commits
mailing list