[llvm] [GISel][AMDGPU] Fold ShuffleVec into ExtractSubvec, and custom lower ExtractSubvec (PR #124527)
Alan Li via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 27 06:15:35 PST 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/124527
>From 89cccb419fdb5bc6da70d84da90efdbce040d64a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 27 Jan 2025 18:26:19 +0800
Subject: [PATCH 1/3] First commit
---
.../llvm/CodeGen/GlobalISel/CombinerHelper.h | 4 ++
.../include/llvm/Target/GlobalISel/Combine.td | 9 ++-
.../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 41 +++++++++++++
.../lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp | 59 +++++++++++++++++++
llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h | 2 +
5 files changed, 114 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 9b78342c8fc393..c1c303fd18e6b5 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -264,6 +264,14 @@ class CombinerHelper {
void applyCombineShuffleConcat(MachineInstr &MI,
SmallVector<Register> &Ops) const;
+ /// Match a G_SHUFFLE_VECTOR that only copies a run of consecutive, in-order
+ /// lanes from one of its two sources, so it can become a G_EXTRACT_SUBVECTOR.
+ /// On success \p Idx holds the first mask index; a value >= the source lane
+ /// count refers to the second source.
+ bool matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx) const;
+ /// Replace \p MI with a G_EXTRACT_SUBVECTOR of the source selected by \p Idx.
+ void applyCombineShuffleExtract(MachineInstr &MI, int64_t Idx) const;
+
/// Try to combine G_SHUFFLE_VECTOR into G_CONCAT_VECTORS.
/// Returns true if MI changed.
///
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 3590ab221ad441..30316305d9e4f7 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1560,6 +1560,13 @@ def combine_shuffle_concat : GICombineRule<
[{ return Helper.matchCombineShuffleConcat(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyCombineShuffleConcat(*${root}, ${matchinfo}); }])>;
+// Fold a G_SHUFFLE_VECTOR that selects consecutive, in-order lanes of one source into a G_EXTRACT_SUBVECTOR.
+def combine_shuffle_vector : GICombineRule<
+ (defs root:$root, int64_matchinfo:$matchinfo),
+ (match (wip_match_opcode G_SHUFFLE_VECTOR):$root,
+ [{ return Helper.matchCombineShuffleExtract(*${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyCombineShuffleExtract(*${root}, ${matchinfo}); }])>;
+
def insert_vector_element_idx_undef : GICombineRule<
(defs root:$root),
(match (G_IMPLICIT_DEF $idx),
@@ -2026,7 +2033,7 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
sub_add_reg, select_to_minmax,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
- simplify_neg_minmax, combine_concat_vector,
+ simplify_neg_minmax, combine_concat_vector, combine_shuffle_vector,
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,
combine_use_vector_truncate, merge_combines, overflow_combines]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index b193d8bb0aa18a..fca7a81dd5fbdb 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -384,6 +384,47 @@ void CombinerHelper::applyCombineConcatVectors(
MI.eraseFromParent();
}
+bool CombinerHelper::matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx) const {
+ assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
+ "Invalid instruction");
+ auto &Shuffle = cast<GShuffleVector>(MI);
+ const auto &TLI = getTargetLowering();
+
+ auto SrcVec1 = Shuffle.getSrc1Reg();
+ auto SrcVec2 = Shuffle.getSrc2Reg();
+ auto Mask = Shuffle.getMask();
+
+ int Width = MRI.getType(SrcVec1).getNumElements();
+
+ // Check if all elements are extracted from the same vector, or within single
+ // vector.
+ auto MaxValue = *std::max_element(Mask.begin(), Mask.end()); // NOTE(review): dereference assumes a non-empty mask.
+ auto MinValue = *std::min_element(Mask.begin(), Mask.end()); // Undef (-1) lanes also land below Width here.
+ if (MaxValue >= Width && MinValue < Width) {
+ return false;
+ }
+ // Check if the extractee's order is kept:
+ if (!std::is_sorted(Mask.begin(), Mask.end())) {
+ return false;
+ }
+
+ Idx = Mask.front();
+ return true; // Idx is a mask index: values >= Width select the second source.
+}
+
+void CombinerHelper::applyCombineShuffleExtract(MachineInstr &MI, int64_t Idx) const {
+ auto &Shuffle = cast<GShuffleVector>(MI);
+ // Idx is a shuffle-mask index; values >= Width address the second source.
+ Register SrcVec1 = Shuffle.getSrc1Reg();
+ int64_t Width = MRI.getType(SrcVec1).getNumElements();
+ bool FromSecond = Idx >= Width;
+ Register SrcVec = FromSecond ? Shuffle.getSrc2Reg() : SrcVec1;
+ // G_EXTRACT_SUBVECTOR's index is relative to the chosen source operand.
+ int64_t SrcIdx = FromSecond ? Idx - Width : Idx;
+ Builder.buildExtractSubvector(Shuffle.getReg(0), SrcVec, SrcIdx);
+ MI.eraseFromParent();
+}
+
bool CombinerHelper::matchCombineShuffleConcat(
MachineInstr &MI, SmallVector<Register> &Ops) const {
ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index e9e47eaadd557f..68b0a8b5aecbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -29,6 +29,7 @@
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -1832,6 +1833,12 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
.lower();
}
+ // v4s16 / v8s16 results use the custom lowering in legalizeExtractSubvector;
+ // every other type falls through to the generic lowering.
+ getActionDefinitionsBuilder(G_EXTRACT_SUBVECTOR)
+ .customFor({V8S16, V4S16})
+ .lower();
+
getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
.unsupportedIf([=](const LegalityQuery &Query) {
const LLT &EltTy = Query.Types[1].getElementType();
@@ -2127,6 +2133,8 @@ bool AMDGPULegalizerInfo::legalizeCustom(
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
return legalizeMinNumMaxNum(Helper, MI);
+ case TargetOpcode::G_EXTRACT_SUBVECTOR:
+ return legalizeExtractSubvector(MI, MRI, B);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(MI, MRI, B);
case TargetOpcode::G_INSERT_VECTOR_ELT:
@@ -2716,6 +2724,68 @@ bool AMDGPULegalizerInfo::legalizeMinNumMaxNum(LegalizerHelper &Helper,
return Helper.lowerFMinNumMaxNum(MI) == LegalizerHelper::Legalized;
}
+/// Build a \p DstTy vector from lanes [Start, Start + #lanes(DstTy)) of \p Src
+/// with one G_EXTRACT_VECTOR_ELT per lane, joined by a G_BUILD_VECTOR.
+/// \p DstTy must be a vector type.
+static auto buildExtractSubvector(MachineIRBuilder &B, SrcOp Src,
+ LLT DstTy, unsigned Start) {
+ SmallVector<Register, 8> Subvectors;
+ for (unsigned i = Start, e = Start + DstTy.getNumElements(); i != e; ++i) {
+ Subvectors.push_back(
+ B.buildExtractVectorElementConstant(DstTy.getElementType(), Src, i)
+ .getReg(0));
+ }
+ return B.buildBuildVector(DstTy, Subvectors);
+}
+
+/// Custom lowering for G_EXTRACT_SUBVECTOR. When 16-bit lanes can be moved in
+/// whole pairs they are handled as 32-bit values; any other shape is
+/// extracted lane by lane. Always legalizes \p MI (returns true).
+bool AMDGPULegalizerInfo::legalizeExtractSubvector(
+ MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &B) const {
+ const auto &Instr = llvm::cast<GExtractSubvector>(MI);
+ Register Src = Instr.getSrcVec();
+ Register Dst = MI.getOperand(0).getReg();
+ auto Start = Instr.getIndexImm();
+
+ LLT SrcTy = MRI.getType(Src);
+ LLT DstTy = MRI.getType(Dst);
+
+ LLT EltTy = SrcTy.getElementType();
+ assert(EltTy == DstTy.getElementType());
+ auto Count = DstTy.getNumElements();
+
+ // Pairing needs the start lane and both lane counts to be even. Two-lane
+ // results take the per-lane path: their 32-bit view would be a one-element
+ // vector, which LLT::fixed_vector cannot build.
+ if (EltTy.getScalarSizeInBits() == 16 && Start % 2 == 0 && Count % 2 == 0 &&
+ Count > 2 && SrcTy.getNumElements() % 2 == 0) {
+ bool UseScalar = Count == 2;
+ // Extract 32-bit registers at a time.
+ LLT NewSrcTy =
+ UseScalar ? S32 : LLT::fixed_vector(SrcTy.getNumElements() / 2, S32);
+ auto Bitcasted = B.buildBitcast(NewSrcTy, Src).getReg(0);
+ LLT NewDstTy = LLT::fixed_vector(DstTy.getNumElements() / 2, S32);
+
+ SmallVector<Register, 8> Subvectors;
+ for (unsigned i = Start / 2, e = (Start + Count) / 2; i != e; ++i) {
+ auto Subvec = B.buildExtractVectorElementConstant(S32, Bitcasted, i);
+ Subvectors.push_back(Subvec.getReg(0));
+ }
+
+ auto BuildVec = B.buildBuildVector(NewDstTy, Subvectors);
+ B.buildBitcast(Dst, BuildVec.getReg(0));
+ MI.eraseFromParent();
+ return true;
+ }
+
+ // Fallback: one G_EXTRACT_VECTOR_ELT per lane, reassembled into Dst.
+ B.buildCopy(Dst, buildExtractSubvector(B, Src, DstTy, Start));
+ MI.eraseFromParent();
+ return true;
+}
+
bool AMDGPULegalizerInfo::legalizeExtractVectorElt(
MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B) const {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
index 86c15197805d23..7b55492afb9821 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
@@ -56,6 +56,8 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
bool legalizeFPTOI(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, bool Signed) const;
bool legalizeMinNumMaxNum(LegalizerHelper &Helper, MachineInstr &MI) const;
+ bool legalizeExtractSubvector(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &B) const;
bool legalizeExtractVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B) const;
bool legalizeInsertVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
>From 83f2bb075342c78d8445d753deac0947f075de0b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 27 Jan 2025 19:17:12 +0800
Subject: [PATCH 2/3] update
---
llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp | 14 ++++++++++++--
llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp | 12 ++++--------
2 files changed, 16 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index fca7a81dd5fbdb..e515fde1b5cc10 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -388,13 +388,20 @@ bool CombinerHelper::matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx)
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
"Invalid instruction");
auto &Shuffle = cast<GShuffleVector>(MI);
- const auto &TLI = getTargetLowering();
auto SrcVec1 = Shuffle.getSrc1Reg();
- auto SrcVec2 = Shuffle.getSrc2Reg();
auto Mask = Shuffle.getMask();
- int Width = MRI.getType(SrcVec1).getNumElements();
+ // Scalar sources act as one-lane vectors; there is no subvector to extract.
+ LLT SrcTy = MRI.getType(SrcVec1);
+ if (!SrcTy.isVector())
+ return false;
+
+ // G_SHUFFLE_VECTOR requires both operands to have the same type, so this
+ // one check covers the second source as well.
+ int Width = SrcTy.getNumElements();
+ if (!llvm::isPowerOf2_32(Width))
+ return false;
// Check if all elements are extracted from the same vector, or within single
// vector.
@@ -403,6 +410,7 @@ bool CombinerHelper::matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx)
if (MaxValue >= Width && MinValue < Width) {
return false;
}
+
// Check if the extractee's order is kept:
if (!std::is_sorted(Mask.begin(), Mask.end())) {
return false;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 68b0a8b5aecbf6..75017c23bb502d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -2758,16 +2758,23 @@ bool AMDGPULegalizerInfo::legalizeExtractSubvector(
// Extract 32-bit registers at a time.
- LLT NewSrcTy =
- UseScalar ? S32 : LLT::fixed_vector(SrcTy.getNumElements() / 2, S32);
- auto Bitcasted = B.buildBitcast(NewSrcTy, Src).getReg(0);
- LLT NewDstTy = LLT::fixed_vector(DstTy.getNumElements() / 2, S32);
-
- SmallVector<Register, 8> Subvectors;
- for (unsigned i = Start / 2, e = (Start + Count) / 2; i != e; ++i) {
- auto Subvec = B.buildExtractVectorElementConstant(S32, Bitcasted, i);
- Subvectors.push_back(Subvec.getReg(0));
- }
-
- auto BuildVec = B.buildBuildVector(NewDstTy, Subvectors);
- B.buildBitcast(Dst, BuildVec.getReg(0));
+ // View the source as 32-bit pairs. The pair count follows the source, not
+ // UseScalar (which describes the result); a single pair is a plain s32.
+ LLT NewSrcTy = LLT::scalarOrVector(
+ ElementCount::getFixed(SrcTy.getNumElements() / 2), S32);
+ auto Bitcasted = B.buildBitcast(NewSrcTy, Src);
+
+ // Collect the pairs starting at Start / 2. A two-lane result is a single
+ // s32 (LLT has no one-element vectors): take that element directly, or the
+ // whole bitcast when the source is itself a single pair.
+ Register Packed;
+ if (!UseScalar)
+ Packed = buildExtractSubvector(B, Bitcasted,
+ LLT::fixed_vector(Count / 2, S32), Start / 2)
+ .getReg(0);
+ else if (NewSrcTy == S32)
+ Packed = Bitcasted.getReg(0);
+ else
+ Packed = B.buildExtractVectorElementConstant(S32, Bitcasted, Start / 2)
+ .getReg(0);
+ B.buildBitcast(Dst, Packed);
MI.eraseFromParent();
return true;
>From dc8448a4c2ed2b0e732b73ac2aeeee4f51346c72 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 27 Jan 2025 22:15:15 +0800
Subject: [PATCH 3/3] small fix
---
llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index e515fde1b5cc10..bb686c4c1c622d 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -413,9 +413,27 @@ bool CombinerHelper::matchCombineShuffleExtract(MachineInstr &MI, int64_t &Idx)
return false;
}
- // Check if the extractee's order is kept:
- if (!std::is_sorted(Mask.begin(), Mask.end())) {
- return false;
+ // G_EXTRACT_SUBVECTOR always produces a vector, so a single-lane shuffle
+ // (scalar result) cannot be folded. Also give up if the target cannot
+ // handle the new instruction at this point in the pipeline.
+ LLT DstTy = MRI.getType(Shuffle.getReg(0));
+ if (!DstTy.isVector() ||
+ !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_SUBVECTOR,
+ {DstTy, MRI.getType(SrcVec1)}}))
+ return false;
+
+ // The first lane must be defined, and its position inside the selected
+ // source must be a multiple of the result length, as G_EXTRACT_SUBVECTOR
+ // requires for its index.
+ if (Mask.front() < 0 ||
+ static_cast<unsigned>(Mask.front() % Width) % DstTy.getNumElements() != 0)
+ return false;
+
+ // The remaining lanes must follow the first one consecutively and in
+ // order; this also rejects undef (-1) lanes after a defined first lane.
+ for (size_t I = 1, E = Mask.size(); I != E; ++I) {
+ if (Mask[I] != Mask[I - 1] + 1)
+ return false;
}
Idx = Mask.front();
More information about the llvm-commits
mailing list