[llvm] 8bce40b - [AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE (#92130)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 30 01:10:46 PDT 2024


Author: Him188
Date: 2024-05-30T09:10:43+01:00
New Revision: 8bce40b1eb3eb00358bbc3b7a05ea987a183265f

URL: https://github.com/llvm/llvm-project/commit/8bce40b1eb3eb00358bbc3b7a05ea987a183265f
DIFF: https://github.com/llvm/llvm-project/commit/8bce40b1eb3eb00358bbc3b7a05ea987a183265f.diff

LOG: [AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE (#92130)

This patch adds basic support for scalable vector types in load & store
instructions for AArch64 with GISel.

Only scalable vector types with a 128-bit base size are supported, e.g.
`<vscale x 4 x i32>`, `<vscale x 16 x i8>`.

This patch adapts some ideas from a similar, abandoned patch:
https://github.com/llvm/llvm-project/pull/72976.

Added: 
    llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
    llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
    llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64RegisterBanks.td
    llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
    llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
    llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index 05f1a7e57e56b..90b4fe5518c87 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -669,17 +669,17 @@ bool GIMatchTableExecutor::executeMatchTable(
       MachineMemOperand *MMO =
           *(State.MIs[InsnID]->memoperands_begin() + MMOIdx);
 
-      unsigned Size = MRI.getType(MO.getReg()).getSizeInBits();
+      const TypeSize Size = MRI.getType(MO.getReg()).getSizeInBits();
       if (MatcherOpcode == GIM_CheckMemorySizeEqualToLLT &&
-          MMO->getSizeInBits().getValue() != Size) {
+          MMO->getSizeInBits() != Size) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeLessThanLLT &&
-                 MMO->getSizeInBits().getValue() >= Size) {
+                 TypeSize::isKnownGE(MMO->getSizeInBits().getValue(), Size)) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeGreaterThanLLT &&
-                 MMO->getSizeInBits().getValue() <= Size)
+                 TypeSize::isKnownLE(MMO->getSizeInBits().getValue(), Size))
         if (handleReject() == RejectAndGiveUp)
           return false;
 

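For context (not part of the commit): the hunk above replaces the plain
integer comparisons with TypeSize::isKnownGE / TypeSize::isKnownLE because,
once scalable vectors reach the match-table executor, a memory size carries a
known-minimum value that may be multiplied by the runtime vscale, and an
ordering against the LLT size can only be asserted when it holds for every
vscale >= 1. A minimal C++ sketch of those semantics, with illustrative
values:

    // Sketch only; values are illustrative and not taken from the patch.
    #include "llvm/Support/TypeSize.h"
    using namespace llvm;

    bool demoTypeSizeOrdering() {
      TypeSize Fixed128 = TypeSize::getFixed(128);   // 128 bits
      TypeSize Scal128 = TypeSize::getScalable(128); // vscale x 128 bits
      // Holds for every vscale >= 1, so the ordering is "known".
      bool GE = TypeSize::isKnownGE(Scal128, Fixed128); // true
      // The reverse cannot hold for all vscale, so it is not "known" and
      // the matcher treats it as a mismatch.
      bool Rev = TypeSize::isKnownGE(Fixed128, Scal128); // false
      return GE && !Rev;
    }
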
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index dcc1335a4bd44..29b665c7cbcc4 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -1150,7 +1150,8 @@ bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
   LLT Ty = MRI.getType(LdSt.getReg(0));
   LLT MemTy = LdSt.getMMO().getMemoryType();
   SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
-      {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
+      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
+        AtomicOrdering::NotAtomic}});
   unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
   SmallVector<LLT> OpTys;
   if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 5289b993476db..9fc9aa354b0db 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1413,7 +1413,7 @@ bool IRTranslator::translateLoad(const User &U, MachineIRBuilder &MIRBuilder) {
 
 bool IRTranslator::translateStore(const User &U, MachineIRBuilder &MIRBuilder) {
   const StoreInst &SI = cast<StoreInst>(U);
-  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()) == 0)
+  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()).isZero())
     return true;
 
   ArrayRef<Register> Vals = getOrCreateVRegs(*SI.getValueOperand());

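For context (not part of the commit): DataLayout::getTypeStoreSize returns a
TypeSize, and comparing that result against a literal 0 goes through the
fixed-size conversion, which is not meaningful for scalable store types;
isZero() only inspects the known-minimum value, so it is safe for both fixed
and scalable sizes. A tiny sketch, reusing the TypeSize API shown above:

    // Sketch only; not taken from the patch.
    #include "llvm/Support/TypeSize.h"
    using namespace llvm;

    bool isEmptyStore(TypeSize StoreSize) {
      // Well-defined for scalable sizes too, e.g.
      // TypeSize::getScalable(128).isZero() == false.
      return StoreSize.isZero();
    }
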
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3e2a5bfbc2321..365ef68dcb19b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -145,6 +145,15 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                  cl::desc("Maximum of xors"));
 
+// When this option is enabled, we will not fall back to DAG ISel when
+// encountering scalable vector types for any instruction, even if SVE is not
+// yet fully supported for some instructions.
+// See [AArch64TargetLowering::fallBackToDAGISel] for implementation details.
+static cl::opt<bool> EnableSVEGISel(
+    "aarch64-enable-gisel-sve", cl::Hidden,
+    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+    cl::init(false));
+
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;
 
@@ -26469,16 +26478,22 @@ bool AArch64TargetLowering::shouldLocalize(
 }
 
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
-    return true;
-
-  for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+  // Fallback for scalable vectors.
+  // Note that if EnableSVEGISel is true, we allow scalable vector types for
+  // all instructions, regardless of whether they are actually supported.
+  if (!EnableSVEGISel) {
+    if (Inst.getType()->isScalableTy()) {
       return true;
+    }
 
-  if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (AI->getAllocatedType()->isScalableTy())
-      return true;
+    for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
+      if (Inst.getOperand(i)->getType()->isScalableTy())
+        return true;
+
+    if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
+      if (AI->getAllocatedType()->isScalableTy())
+        return true;
+    }
   }
 
   // Checks to allow the use of SME instructions

diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 615ce7d51d9ba..2b597b8606921 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -12,8 +12,8 @@
 /// General Purpose Registers: W, X.
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
 
-/// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+/// Floating Point, Vector, Scalable Vector Registers: B, H, S, D, Q, Z.
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
 
 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index a21be7de6f42f..07a0473888ee5 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -61,6 +61,11 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
 
+  const LLT nxv16s8 = LLT::scalable_vector(16, s8);
+  const LLT nxv8s16 = LLT::scalable_vector(8, s16);
+  const LLT nxv4s32 = LLT::scalable_vector(4, s32);
+  const LLT nxv2s64 = LLT::scalable_vector(2, s64);
+
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
                                                         v2s64, v2p0,
@@ -328,7 +333,31 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
-  getActionDefinitionsBuilder(G_LOAD)
+  auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
+  auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
+
+  if (ST.hasSVE()) {
+    LoadActions.legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 8},
+        {nxv8s16, p0, nxv8s16, 8},
+        {nxv4s32, p0, nxv4s32, 8},
+        {nxv2s64, p0, nxv2s64, 8},
+    });
+
+    // TODO: Add nxv2p0. Consider bitcastIf.
+    //       See #92130
+    //       https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
+    StoreActions.legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 8},
+        {nxv8s16, p0, nxv8s16, 8},
+        {nxv4s32, p0, nxv4s32, 8},
+        {nxv2s64, p0, nxv2s64, 8},
+    });
+  }
+
+  LoadActions
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -378,7 +407,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(IsPtrVecPred)
       .scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
 
-  getActionDefinitionsBuilder(G_STORE)
+  StoreActions
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Release;

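For context (not part of the commit): the four new legality rows implement the
commit message's "128-bit base size" restriction. Each {nxv*, p0, nxv*, 8}
entry makes loads and stores of a scalable vector type with a 128-bit
known-minimum size legal against a p0 pointer, with the trailing 8 being the
minimum alignment in bits, in the same style as the existing fixed-vector
entries. A standalone C++ sketch of those LLTs (the include path is the one
used by current trees):

    // Sketch only; not taken from the patch.
    #include "llvm/CodeGenTypes/LowLevelType.h"
    #include <cassert>
    using namespace llvm;

    void demoScalableLLTs() {
      const LLT nxv16s8 = LLT::scalable_vector(16, LLT::scalar(8));
      const LLT nxv2s64 = LLT::scalable_vector(2, LLT::scalar(64));
      // Both are scalable vectors whose known-minimum size is 128 bits; the
      // size at run time is vscale x 128 bits.
      assert(nxv16s8.isScalableVector() && nxv2s64.isScalableVector());
      assert(nxv16s8.getSizeInBits().getKnownMinValue() == 128);
      assert(nxv2s64.getSizeInBits().getKnownMinValue() == 128);
    }
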
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index 7f3e0e01ccd25..0c7be9f42c570 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -309,6 +309,8 @@ bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
   if (!Store.isSimple())
     return false;
   LLT ValTy = MRI.getType(Store.getValueReg());
+  if (ValTy.isScalableVector())
+    return false;
   if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
     return false;
   if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
@@ -708,6 +710,11 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
     // should only be in a single block.
     resetState();
     for (auto &MI : MBB) {
+      // Skip for scalable vectors
+      if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
+          LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
+        continue;
+
       if (auto *St = dyn_cast<GStore>(&MI)) {
         Register PtrBaseReg;
         APInt Offset;

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 7785e020eaaf1..4aa6999d1d3ca 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -257,6 +257,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -743,12 +744,15 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     LLT Ty = MRI.getType(MO.getReg());
     if (!Ty.isValid())
       continue;
-    OpSize[Idx] = Ty.getSizeInBits();
+    OpSize[Idx] = Ty.getSizeInBits().getKnownMinValue();
 
-    // As a top-level guess, vectors go in FPRs, scalars and pointers in GPRs.
+    // As a top-level guess, vectors (both scalable and fixed) go in FPRs,
+    // scalars and pointers in GPRs.
     // For floating-point instructions, scalars go in FPRs.
-    if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
-        Ty.getSizeInBits() > 64)
+    if (Ty.isVector())
+      OpRegBankIdx[Idx] = PMI_FirstFPR;
+    else if (isPreISelGenericFloatingPointOpcode(Opc) ||
+             Ty.getSizeInBits() > 64)
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else
       OpRegBankIdx[Idx] = PMI_FirstGPR;

diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
new file mode 100644
index 0000000000000..95a5bfa4b038f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
@@ -0,0 +1,50 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -aarch64-enable-gisel-sve=true < %s | FileCheck %s
+
+define void @scalable_v16i8(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    st1b { z0.b }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 16 x i8>, ptr %l0, align 16
+  store <vscale x 16 x i8> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v8i16(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    st1h { z0.h }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 8 x i16>, ptr %l0, align 16
+  store <vscale x 8 x i16> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v4i32(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 4 x i32>, ptr %l0, align 16
+  store <vscale x 4 x i32> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v2i64(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 2 x i64>, ptr %l0, align 16
+  store <vscale x 2 x i64> %l3, ptr %l1, align 16
+  ret void
+}