[llvm] 8bce40b - [AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE (#92130)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 30 01:10:46 PDT 2024
Author: Him188
Date: 2024-05-30T09:10:43+01:00
New Revision: 8bce40b1eb3eb00358bbc3b7a05ea987a183265f
URL: https://github.com/llvm/llvm-project/commit/8bce40b1eb3eb00358bbc3b7a05ea987a183265f
DIFF: https://github.com/llvm/llvm-project/commit/8bce40b1eb3eb00358bbc3b7a05ea987a183265f.diff
LOG: [AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE (#92130)
This patch adds basic support for scalable vector types in load and store
instructions (G_LOAD and G_STORE) for AArch64 GlobalISel.
Only scalable vector types with a 128-bit base size are supported, e.g.
`<vscale x 4 x i32>`, `<vscale x 16 x i8>`.
This patch adapts some ideas from a similar, abandoned patch:
https://github.com/llvm/llvm-project/pull/72976.
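As a quick illustration (not part of the commit), this is roughly how those types
appear on the GlobalISel side: each one is a scalable LLT whose known-minimum
("base") size is 128 bits. A minimal sketch, assuming the usual LLT/TypeSize APIs;
the header path may differ in older trees:

```cpp
#include "llvm/CodeGenTypes/LowLevelType.h" // LLT (location varies by LLVM version)
#include <cassert>

using namespace llvm;

int main() {
  // <vscale x 4 x i32> and <vscale x 16 x i8> expressed as LLTs.
  const LLT NxV4S32 = LLT::scalable_vector(4, LLT::scalar(32));
  const LLT NxV16S8 = LLT::scalable_vector(16, LLT::scalar(8));

  // Both report a scalable TypeSize with a known-minimum size of 128 bits;
  // the size at run time is 128 * vscale.
  assert(NxV4S32.getSizeInBits().isScalable());
  assert(NxV4S32.getSizeInBits().getKnownMinValue() == 128);
  assert(NxV16S8.getSizeInBits().getKnownMinValue() == 128);
  return 0;
}
```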
Added:
llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
Modified:
llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64RegisterBanks.td
llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index 05f1a7e57e56b..90b4fe5518c87 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -669,17 +669,17 @@ bool GIMatchTableExecutor::executeMatchTable(
MachineMemOperand *MMO =
*(State.MIs[InsnID]->memoperands_begin() + MMOIdx);
- unsigned Size = MRI.getType(MO.getReg()).getSizeInBits();
+ const TypeSize Size = MRI.getType(MO.getReg()).getSizeInBits();
if (MatcherOpcode == GIM_CheckMemorySizeEqualToLLT &&
- MMO->getSizeInBits().getValue() != Size) {
+ MMO->getSizeInBits() != Size) {
if (handleReject() == RejectAndGiveUp)
return false;
} else if (MatcherOpcode == GIM_CheckMemorySizeLessThanLLT &&
- MMO->getSizeInBits().getValue() >= Size) {
+ TypeSize::isKnownGE(MMO->getSizeInBits().getValue(), Size)) {
if (handleReject() == RejectAndGiveUp)
return false;
} else if (MatcherOpcode == GIM_CheckMemorySizeGreaterThanLLT &&
- MMO->getSizeInBits().getValue() <= Size)
+ TypeSize::isKnownLE(MMO->getSizeInBits().getValue(), Size))
if (handleReject() == RejectAndGiveUp)
return false;
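The executor change above is needed because a scalable memory size cannot be compared
against a fixed LLT size with plain `>=`/`<=`: TypeSize only defines an ordering when
it holds for every possible vscale. A small self-contained sketch (not part of the
patch) of the comparison semantics the new code relies on, using llvm/Support/TypeSize.h:

```cpp
#include "llvm/Support/TypeSize.h"
#include <cassert>

using namespace llvm;

int main() {
  const TypeSize Fixed128 = TypeSize::getFixed(128);       // 128 bits
  const TypeSize Scalable128 = TypeSize::getScalable(128); // 128 * vscale bits

  // Equality is exact: a scalable size is never equal to a fixed one.
  assert(Scalable128 != Fixed128);

  // "Known" orderings must hold for every legal vscale (vscale >= 1), so
  // 128 * vscale >= 128 is known, while 128 * vscale <= 128 is not.
  assert(TypeSize::isKnownGE(Scalable128, Fixed128));
  assert(!TypeSize::isKnownLE(Scalable128, Fixed128));
  return 0;
}
```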
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index dcc1335a4bd44..29b665c7cbcc4 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -1150,7 +1150,8 @@ bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
LLT Ty = MRI.getType(LdSt.getReg(0));
LLT MemTy = LdSt.getMMO().getMemoryType();
SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
- {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
+ {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
+ AtomicOrdering::NotAtomic}});
unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
SmallVector<LLT> OpTys;
if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 5289b993476db..9fc9aa354b0db 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1413,7 +1413,7 @@ bool IRTranslator::translateLoad(const User &U, MachineIRBuilder &MIRBuilder) {
bool IRTranslator::translateStore(const User &U, MachineIRBuilder &MIRBuilder) {
const StoreInst &SI = cast<StoreInst>(U);
- if (DL->getTypeStoreSize(SI.getValueOperand()->getType()) == 0)
+ if (DL->getTypeStoreSize(SI.getValueOperand()->getType()).isZero())
return true;
ArrayRef<Register> Vals = getOrCreateVRegs(*SI.getValueOperand());
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3e2a5bfbc2321..365ef68dcb19b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -145,6 +145,15 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
cl::desc("Maximum of xors"));
+// By turning this on, we will not fall back to DAG ISel when encountering
+// scalable vector types for any instruction, even if SVE is not yet supported
+// for some instructions.
+// See [AArch64TargetLowering::fallBackToDAGISel] for implementation details.
+static cl::opt<bool> EnableSVEGISel(
+ "aarch64-enable-gisel-sve", cl::Hidden,
+ cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+ cl::init(false));
+
/// Value type used for condition codes.
static const MVT MVT_CC = MVT::i32;
@@ -26469,16 +26478,22 @@ bool AArch64TargetLowering::shouldLocalize(
}
bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
- if (Inst.getType()->isScalableTy())
- return true;
-
- for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
- if (Inst.getOperand(i)->getType()->isScalableTy())
+ // Fallback for scalable vectors.
+ // Note that if EnableSVEGISel is true, we allow scalable vector types for
+ // all instructions, regardless of whether they are actually supported.
+ if (!EnableSVEGISel) {
+ if (Inst.getType()->isScalableTy()) {
return true;
+ }
- if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
- if (AI->getAllocatedType()->isScalableTy())
- return true;
+ for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
+ if (Inst.getOperand(i)->getType()->isScalableTy())
+ return true;
+
+ if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
+ if (AI->getAllocatedType()->isScalableTy())
+ return true;
+ }
}
// Checks to allow the use of SME instructions
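With the new flag left at its default of false, the fallback behaviour above is
unchanged and scalable vector types still go to SelectionDAG; the test added below
opts in explicitly with -aarch64-enable-gisel-sve=true.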
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 615ce7d51d9ba..2b597b8606921 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -12,8 +12,8 @@
/// General Purpose Registers: W, X.
def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
-/// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+/// Floating Point, Vector, Scalable Vector Registers: B, H, S, D, Q, Z.
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
/// Conditional register: NZCV.
def CCRegBank : RegisterBank<"CC", [CCR]>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index a21be7de6f42f..07a0473888ee5 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -61,6 +61,11 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
const LLT v2s64 = LLT::fixed_vector(2, 64);
const LLT v2p0 = LLT::fixed_vector(2, p0);
+ const LLT nxv16s8 = LLT::scalable_vector(16, s8);
+ const LLT nxv8s16 = LLT::scalable_vector(8, s16);
+ const LLT nxv4s32 = LLT::scalable_vector(4, s32);
+ const LLT nxv2s64 = LLT::scalable_vector(2, s64);
+
std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
v16s8, v8s16, v4s32,
v2s64, v2p0,
@@ -328,7 +333,31 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
};
- getActionDefinitionsBuilder(G_LOAD)
+ auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
+ auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
+
+ if (ST.hasSVE()) {
+ LoadActions.legalForTypesWithMemDesc({
+ // 128 bit base sizes
+ {nxv16s8, p0, nxv16s8, 8},
+ {nxv8s16, p0, nxv8s16, 8},
+ {nxv4s32, p0, nxv4s32, 8},
+ {nxv2s64, p0, nxv2s64, 8},
+ });
+
+ // TODO: Add nxv2p0. Consider bitcastIf.
+ // See #92130
+ // https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
+ StoreActions.legalForTypesWithMemDesc({
+ // 128 bit base sizes
+ {nxv16s8, p0, nxv16s8, 8},
+ {nxv8s16, p0, nxv8s16, 8},
+ {nxv4s32, p0, nxv4s32, 8},
+ {nxv2s64, p0, nxv2s64, 8},
+ });
+ }
+
+ LoadActions
.customIf([=](const LegalityQuery &Query) {
return HasRCPC3 && Query.Types[0] == s128 &&
Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -378,7 +407,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.customIf(IsPtrVecPred)
.scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
- getActionDefinitionsBuilder(G_STORE)
+ StoreActions
.customIf([=](const LegalityQuery &Query) {
return HasRCPC3 && Query.Types[0] == s128 &&
Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
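For readers unfamiliar with legalForTypesWithMemDesc, each four-element entry above
pairs the value and pointer types with a memory type and a minimum alignment in bits.
A hedged sketch (not part of the patch) of what one of the new entries encodes,
assuming the usual LegalityPredicates::TypePairAndMemDesc layout:

```cpp
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"

using namespace llvm;

// Equivalent of the {nxv16s8, p0, nxv16s8, 8} entry above: a G_LOAD/G_STORE of
// <vscale x 16 x i8> through an address-space-0 64-bit pointer, with the full
// vector as the memory type, is legal at a minimum alignment of 8 bits (1 byte).
static const LegalityPredicates::TypePairAndMemDesc NxV16S8Entry = {
    /*Type0=*/LLT::scalable_vector(16, LLT::scalar(8)),
    /*Type1=*/LLT::pointer(0, 64),
    /*MemTy=*/LLT::scalable_vector(16, LLT::scalar(8)),
    /*Align=*/8};
```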
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index 7f3e0e01ccd25..0c7be9f42c570 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -309,6 +309,8 @@ bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
if (!Store.isSimple())
return false;
LLT ValTy = MRI.getType(Store.getValueReg());
+ if (ValTy.isScalableVector())
+ return false;
if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
return false;
if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
@@ -708,6 +710,11 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
// should only be in a single block.
resetState();
for (auto &MI : MBB) {
+ // Skip for scalable vectors
+ if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
+ LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
+ continue;
+
if (auto *St = dyn_cast<GStore>(&MI)) {
Register PtrBaseReg;
APInt Offset;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 7785e020eaaf1..4aa6999d1d3ca 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -257,6 +257,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
case AArch64::QQRegClassID:
case AArch64::QQQRegClassID:
case AArch64::QQQQRegClassID:
+ case AArch64::ZPRRegClassID:
return getRegBank(AArch64::FPRRegBankID);
case AArch64::GPR32commonRegClassID:
case AArch64::GPR32RegClassID:
@@ -743,12 +744,15 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
LLT Ty = MRI.getType(MO.getReg());
if (!Ty.isValid())
continue;
- OpSize[Idx] = Ty.getSizeInBits();
+ OpSize[Idx] = Ty.getSizeInBits().getKnownMinValue();
- // As a top-level guess, vectors go in FPRs, scalars and pointers in GPRs.
+ // As a top-level guess, vectors (both scalable and non-scalable) go in
+ // FPRs, scalars and pointers in GPRs.
// For floating-point instructions, scalars go in FPRs.
- if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
- Ty.getSizeInBits() > 64)
+ if (Ty.isVector())
+ OpRegBankIdx[Idx] = PMI_FirstFPR;
+ else if (isPreISelGenericFloatingPointOpcode(Opc) ||
+ Ty.getSizeInBits() > 64)
OpRegBankIdx[Idx] = PMI_FirstFPR;
else
OpRegBankIdx[Idx] = PMI_FirstGPR;
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
new file mode 100644
index 0000000000000..95a5bfa4b038f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
@@ -0,0 +1,50 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -aarch64-enable-gisel-sve=true < %s | FileCheck %s
+
+define void @scalable_v16i8(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT: st1b { z0.b }, p0, [x1]
+; CHECK-NEXT: ret
+ %l3 = load <vscale x 16 x i8>, ptr %l0, align 16
+ store <vscale x 16 x i8> %l3, ptr %l1, align 16
+ ret void
+}
+
+define void @scalable_v8i16(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT: st1h { z0.h }, p0, [x1]
+; CHECK-NEXT: ret
+ %l3 = load <vscale x 8 x i16>, ptr %l0, align 16
+ store <vscale x 8 x i16> %l3, ptr %l1, align 16
+ ret void
+}
+
+define void @scalable_v4i32(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v4i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT: st1w { z0.s }, p0, [x1]
+; CHECK-NEXT: ret
+ %l3 = load <vscale x 4 x i32>, ptr %l0, align 16
+ store <vscale x 4 x i32> %l3, ptr %l1, align 16
+ ret void
+}
+
+define void @scalable_v2i64(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v2i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT: st1d { z0.d }, p0, [x1]
+; CHECK-NEXT: ret
+ %l3 = load <vscale x 2 x i64>, ptr %l0, align 16
+ store <vscale x 2 x i64> %l3, ptr %l1, align 16
+ ret void
+}