[llvm] [AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE (PR #92130)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 07:37:05 PDT 2024


https://github.com/Him188 updated https://github.com/llvm/llvm-project/pull/92130

>From a61fb1a3e989596d53c8f5c53dc1b8ca702cb7c6 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Tue, 7 May 2024 10:10:29 +0100
Subject: [PATCH 01/21] [AArch64][GISel] Support SVE with 128-bit min-size for
 G_LOAD and G_STORE

This patch adds basic support for scalable vector types in load & store instructions for AArch64 with GISel.
Only scalable vector types with a 128-bit base size are supported, e.g. <vscale x 4 x i32>, <vscale x 16 x i8>.

This patch adapts some ideas from a similar abandoned patch: https://github.com/llvm/llvm-project/pull/72976
---
 .../GlobalISel/GIMatchTableExecutorImpl.h     |  8 +-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  2 +-
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  2 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    | 14 ++-
 .../Target/AArch64/AArch64RegisterBanks.td    |  2 +-
 .../GISel/AArch64InstructionSelector.cpp      | 59 ++++++++++--
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    | 93 ++++++++++++++++++-
 .../GISel/AArch64PostLegalizerCombiner.cpp    |  6 +-
 .../AArch64/GISel/AArch64RegisterBankInfo.cpp | 10 +-
 .../AArch64/GlobalISel/sve-load-store.ll      | 50 ++++++++++
 10 files changed, 221 insertions(+), 25 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index 4d147bf20c26a..29939d4619400 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -652,17 +652,17 @@ bool GIMatchTableExecutor::executeMatchTable(
       MachineMemOperand *MMO =
           *(State.MIs[InsnID]->memoperands_begin() + MMOIdx);
 
-      unsigned Size = MRI.getType(MO.getReg()).getSizeInBits();
+      const auto Size = MRI.getType(MO.getReg()).getSizeInBits();
       if (MatcherOpcode == GIM_CheckMemorySizeEqualToLLT &&
-          MMO->getSizeInBits().getValue() != Size) {
+          MMO->getSizeInBits() != Size) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeLessThanLLT &&
-                 MMO->getSizeInBits().getValue() >= Size) {
+                 MMO->getSizeInBits().getValue() >= Size.getKnownMinValue()) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeGreaterThanLLT &&
-                 MMO->getSizeInBits().getValue() <= Size)
+                 MMO->getSizeInBits().getValue() <= Size.getKnownMinValue())
         if (handleReject() == RejectAndGiveUp)
           return false;
 
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 653e7689b5774..141c7ee15fe39 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -1080,7 +1080,7 @@ bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
   LLT Ty = MRI.getType(LdSt.getReg(0));
   LLT MemTy = LdSt.getMMO().getMemoryType();
   SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
-      {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
+      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), AtomicOrdering::NotAtomic}});
   unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
   SmallVector<LLT> OpTys;
   if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 6661127162e52..b14a004d5c4ac 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1413,7 +1413,7 @@ bool IRTranslator::translateLoad(const User &U, MachineIRBuilder &MIRBuilder) {
 
 bool IRTranslator::translateStore(const User &U, MachineIRBuilder &MIRBuilder) {
   const StoreInst &SI = cast<StoreInst>(U);
-  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()) == 0)
+  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()).isZero())
     return true;
 
   ArrayRef<Register> Vals = getOrCreateVRegs(*SI.getValueOperand());
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c1ca78af5cda8..e0be162e10a97 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26375,12 +26375,20 @@ bool AArch64TargetLowering::shouldLocalize(
   return TargetLoweringBase::shouldLocalize(MI, TTI);
 }
 
+static bool isScalableTySupported(const unsigned Op) {
+  return Op == Instruction::Load || Op == Instruction::Store;
+}
+
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
-    return true;
+  const auto ScalableTySupported = isScalableTySupported(Inst.getOpcode());
+
+  // Fallback for scalable vectors
+  if (Inst.getType()->isScalableTy() && !ScalableTySupported) {
+      return true;
+  }
 
   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+    if (Inst.getOperand(i)->getType()->isScalableTy() && !ScalableTySupported)
       return true;
 
   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 615ce7d51d9ba..9e2ed356299e2 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -13,7 +13,7 @@
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
 
 /// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
 
 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 61f5bc2464ee5..bc47443c45c8e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -901,6 +901,27 @@ static unsigned selectLoadStoreUIOp(unsigned GenericOpc, unsigned RegBankID,
   return GenericOpc;
 }
 
+/// Select the AArch64 opcode for the G_LOAD or G_STORE operation for scalable 
+/// vectors.
+/// \p ElementSize size of the element of the scalable vector
+static unsigned selectLoadStoreSVEOp(const unsigned GenericOpc,
+                                     const unsigned ElementSize) {
+  const bool isStore = GenericOpc == TargetOpcode::G_STORE;
+  
+  switch (ElementSize) {
+    case 8:
+      return isStore ? AArch64::ST1B : AArch64::LD1B;
+    case 16:
+      return isStore ? AArch64::ST1H : AArch64::LD1H;
+    case 32:
+      return isStore ? AArch64::ST1W : AArch64::LD1W;
+    case 64:
+      return isStore ? AArch64::ST1D : AArch64::LD1D;
+  }
+  
+  return GenericOpc;
+}
+
 /// Helper function for selectCopy. Inserts a subregister copy from \p SrcReg
 /// to \p *To.
 ///
@@ -2853,8 +2874,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       return false;
     }
 
-    uint64_t MemSizeInBytes = LdSt.getMemSize().getValue();
-    unsigned MemSizeInBits = LdSt.getMemSizeInBits().getValue();
+    uint64_t MemSizeInBytes = LdSt.getMemSize().getValue().getKnownMinValue();
+    unsigned MemSizeInBits = LdSt.getMemSizeInBits().getValue().getKnownMinValue();
     AtomicOrdering Order = LdSt.getMMO().getSuccessOrdering();
 
     // Need special instructions for atomics that affect ordering.
@@ -2906,9 +2927,23 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     const LLT ValTy = MRI.getType(ValReg);
     const RegisterBank &RB = *RBI.getRegBank(ValReg, MRI, TRI);
 
+#ifndef NDEBUG
+    if (ValTy.isScalableVector()) {
+        assert(STI.hasSVE() 
+             && "Load/Store register operand is scalable vector "
+                "while SVE is not supported by the target");
+        // assert(RB.getID() == AArch64::SVRRegBankID 
+        //        && "Load/Store register operand is scalable vector "
+        //           "while its register bank is not SVR");
+    }
+#endif
+    
     // The code below doesn't support truncating stores, so we need to split it
     // again.
-    if (isa<GStore>(LdSt) && ValTy.getSizeInBits() > MemSizeInBits) {
+    // Truncate only if type is not scalable vector
+    const bool NeedTrunc = !ValTy.isScalableVector() 
+                      && ValTy.getSizeInBits().getFixedValue() > MemSizeInBits;
+    if (isa<GStore>(LdSt) && NeedTrunc) {
       unsigned SubReg;
       LLT MemTy = LdSt.getMMO().getMemoryType();
       auto *RC = getRegClassForTypeOnBank(MemTy, RB);
@@ -2921,7 +2956,7 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
                       .getReg(0);
       RBI.constrainGenericRegister(Copy, *RC, MRI);
       LdSt.getOperand(0).setReg(Copy);
-    } else if (isa<GLoad>(LdSt) && ValTy.getSizeInBits() > MemSizeInBits) {
+    } else if (isa<GLoad>(LdSt) && NeedTrunc) {
       // If this is an any-extending load from the FPR bank, split it into a regular
       // load + extend.
       if (RB.getID() == AArch64::FPRRegBankID) {
@@ -2951,10 +2986,19 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     // instruction with an updated opcode, or a new instruction.
     auto SelectLoadStoreAddressingMode = [&]() -> MachineInstr * {
       bool IsStore = isa<GStore>(I);
-      const unsigned NewOpc =
-          selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
+      unsigned NewOpc;
+      if (ValTy.isScalableVector()) {
+        NewOpc = selectLoadStoreSVEOp(I.getOpcode(), ValTy.getElementType().getSizeInBits());
+      } else {
+        NewOpc = selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
+      }
       if (NewOpc == I.getOpcode())
         return nullptr;
+
+      if (ValTy.isScalableVector()) {
+        // Add the predicate register operand
+        I.addOperand(MachineOperand::CreatePredicate(true));
+      }
       // Check if we can fold anything into the addressing mode.
       auto AddrModeFns =
           selectAddrModeIndexed(I.getOperand(1), MemSizeInBytes);
@@ -2970,6 +3014,9 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       Register CurValReg = I.getOperand(0).getReg();
       IsStore ? NewInst.addUse(CurValReg) : NewInst.addDef(CurValReg);
       NewInst.cloneMemRefs(I);
+      if (ValTy.isScalableVector()) {
+        NewInst.add(I.getOperand(1)); // Copy predicate register
+      }
       for (auto &Fn : *AddrModeFns)
         Fn(NewInst);
       I.eraseFromParent();
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index d4aac94d24f12..c4f5b75ce959f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -61,6 +61,79 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
 
+  // Scalable vector sizes range from 128 to 2048
+  // Note that subtargets may not support the full range.
+  // See [ScalableVecTypes] below.
+  const LLT nxv16s8 = LLT::scalable_vector(16, s8);
+  const LLT nxv32s8 = LLT::scalable_vector(32, s8);
+  const LLT nxv64s8 = LLT::scalable_vector(64, s8);
+  const LLT nxv128s8 = LLT::scalable_vector(128, s8);
+  const LLT nxv256s8 = LLT::scalable_vector(256, s8);
+
+  const LLT nxv8s16 = LLT::scalable_vector(8, s16);
+  const LLT nxv16s16 = LLT::scalable_vector(16, s16);
+  const LLT nxv32s16 = LLT::scalable_vector(32, s16);
+  const LLT nxv64s16 = LLT::scalable_vector(64, s16);
+  const LLT nxv128s16 = LLT::scalable_vector(128, s16);
+
+  const LLT nxv4s32 = LLT::scalable_vector(4, s32); 
+  const LLT nxv8s32 = LLT::scalable_vector(8, s32); 
+  const LLT nxv16s32 = LLT::scalable_vector(16, s32); 
+  const LLT nxv32s32 = LLT::scalable_vector(32, s32);
+  const LLT nxv64s32 = LLT::scalable_vector(64, s32);
+
+  const LLT nxv2s64 = LLT::scalable_vector(2, s64);
+  const LLT nxv4s64 = LLT::scalable_vector(4, s64);
+  const LLT nxv8s64 = LLT::scalable_vector(8, s64);
+  const LLT nxv16s64 = LLT::scalable_vector(16, s64);
+  const LLT nxv32s64 = LLT::scalable_vector(32, s64);
+
+  const LLT nxv2p0 = LLT::scalable_vector(2, p0);
+  const LLT nxv4p0 = LLT::scalable_vector(4, p0);
+  const LLT nxv8p0 = LLT::scalable_vector(8, p0);
+  const LLT nxv16p0 = LLT::scalable_vector(16, p0);
+  const LLT nxv32p0 = LLT::scalable_vector(32, p0);
+
+  const auto ScalableVec128 = {
+    nxv16s8, nxv8s16, nxv4s32, nxv2s64, nxv2p0,
+  };
+  const auto ScalableVec256 = {
+    nxv32s8, nxv16s16, nxv8s32, nxv4s64, nxv4p0,
+  };
+  const auto ScalableVec512 = {
+    nxv64s8, nxv32s16, nxv16s32, nxv8s64, nxv8p0,
+  };
+  const auto ScalableVec1024 = {
+    nxv128s8, nxv64s16, nxv32s32, nxv16s64, nxv16p0,
+  };
+  const auto ScalableVec2048 = {
+    nxv256s8, nxv128s16, nxv64s32, nxv32s64, nxv32p0,
+  };
+
+  /// Scalable vector types supported by the sub target.
+  /// Empty if SVE is not supported.
+  SmallVector<LLT> ScalableVecTypes;
+  
+  if (ST.hasSVE()) {
+    // Add scalable vector types that are supported by the subtarget
+    const auto MinSize = ST.getMinSVEVectorSizeInBits();
+    auto MaxSize = ST.getMaxSVEVectorSizeInBits();
+    if (MaxSize == 0) {
+      // Unknown max size, assume the target supports all sizes.
+      MaxSize = 2048; 
+    }
+    if (MinSize <= 128 && 128 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec128);
+    if (MinSize <= 256 && 256 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec256);
+    if (MinSize <= 512 && 512 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec512);
+    if (MinSize <= 1024 && 1024 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec1024);
+    if (MinSize <= 2048 && 2048 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec2048);
+  }
+
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
                                                         v2s64, v2p0,
@@ -329,6 +402,18 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
+  const auto IsSameScalableVecTy = [=](const LegalityQuery &Query) {
+    // Legal if loading a scalable vector type
+    // into a scalable vector register of the exactly same type
+    if (!Query.Types[0].isScalableVector() || Query.Types[1] != p0)
+      return false;
+    if (Query.MMODescrs[0].MemoryTy != Query.Types[0])
+      return false;
+    if (Query.MMODescrs[0].AlignInBits < 128)
+      return false;
+    return is_contained(ScalableVecTypes, Query.Types[0]);
+  };
+
   getActionDefinitionsBuilder(G_LOAD)
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
@@ -354,6 +439,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // These extends are also legal
       .legalForTypesWithMemDesc(
           {{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
+      .legalIf(IsSameScalableVecTy)
       .widenScalarToNextPow2(0, /* MinSize = */ 8)
       .clampMaxNumElements(0, s8, 16)
       .clampMaxNumElements(0, s16, 8)
@@ -398,7 +484,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
            {s64, p0, s64, 8},   {s64, p0, s32, 8}, // truncstorei32 from s64
            {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
            {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
-           {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
+           {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8},
+          })
+      .legalIf(IsSameScalableVecTy)
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&
@@ -440,8 +528,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
           {p0, v4s32, v4s32, 8},
           {p0, v2s64, v2s64, 8},
           {p0, v2p0, v2p0, 8},
-          {p0, s128, s128, 8},
-      })
+          {p0, s128, s128, 8}})
       .unsupported();
 
   auto IndexedLoadBasicPred = [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index d8ca5494ba50a..5830489e8ef90 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -309,7 +309,7 @@ bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
   if (!Store.isSimple())
     return false;
   LLT ValTy = MRI.getType(Store.getValueReg());
-  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
+  if (!ValTy.isVector() || ValTy.getSizeInBits().getKnownMinValue() != 128)
     return false;
   if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
     return false; // Don't split truncating stores.
@@ -657,8 +657,8 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
         Register PtrBaseReg;
         APInt Offset;
         LLT StoredValTy = MRI.getType(St->getValueReg());
-        unsigned ValSize = StoredValTy.getSizeInBits();
-        if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
+        const auto ValSize = StoredValTy.getSizeInBits();
+        if (ValSize.getKnownMinValue() < 32 || St->getMMO().getSizeInBits() != ValSize)
           continue;
 
         Register PtrReg = St->getPointerReg();
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 44ba9f0429e67..f249729b4b4ab 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -257,6 +257,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -740,11 +741,14 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     LLT Ty = MRI.getType(MO.getReg());
     if (!Ty.isValid())
       continue;
-    OpSize[Idx] = Ty.getSizeInBits();
+    OpSize[Idx] = Ty.getSizeInBits().getKnownMinValue();
 
-    // As a top-level guess, vectors go in FPRs, scalars and pointers in GPRs.
+    // As a top-level guess, scalable vectors go in SVRs, non-scalable
+    // vectors go in FPRs, scalars and pointers in GPRs.
     // For floating-point instructions, scalars go in FPRs.
-    if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
+    if (Ty.isScalableVector()) 
+      OpRegBankIdx[Idx] = PMI_FirstFPR;
+    else if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
         Ty.getSizeInBits() > 64)
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
new file mode 100644
index 0000000000000..7a794387eb011
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
@@ -0,0 +1,50 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel < %s | FileCheck %s
+
+define void @scalable_v16i8(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+; CHECK-LABEL: scalable_v16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    st1b { z0.b }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 16 x i8>, ptr %l0, align 16
+  store <vscale x 16 x i8> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v8i16(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+; CHECK-LABEL: scalable_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    st1h { z0.h }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 8 x i16>, ptr %l0, align 16
+  store <vscale x 8 x i16> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v4i32(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+; CHECK-LABEL: scalable_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 4 x i32>, ptr %l0, align 16
+  store <vscale x 4 x i32> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v2i64(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+; CHECK-LABEL: scalable_v2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 2 x i64>, ptr %l0, align 16
+  store <vscale x 2 x i64> %l3, ptr %l1, align 16
+  ret void
+}

>From 135c91fd128ad8b60d56edacfe4fb0bb748de307 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Tue, 14 May 2024 16:28:32 +0100
Subject: [PATCH 02/21] Remove unnecessary attributes in tests

---
 llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
index 7a794387eb011..4c3ffb99e5667 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel < %s | FileCheck %s
 
-define void @scalable_v16i8(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+define void @scalable_v16i8(ptr %l0, ptr %l1) {
 ; CHECK-LABEL: scalable_v16i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.b
@@ -13,7 +13,7 @@ define void @scalable_v16i8(ptr noalias nocapture noundef %l0, ptr noalias nocap
   ret void
 }
 
-define void @scalable_v8i16(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+define void @scalable_v8i16(ptr %l0, ptr %l1) {
 ; CHECK-LABEL: scalable_v8i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.h
@@ -25,7 +25,7 @@ define void @scalable_v8i16(ptr noalias nocapture noundef %l0, ptr noalias nocap
   ret void
 }
 
-define void @scalable_v4i32(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+define void @scalable_v4i32(ptr %l0, ptr %l1) {
 ; CHECK-LABEL: scalable_v4i32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.s
@@ -37,7 +37,7 @@ define void @scalable_v4i32(ptr noalias nocapture noundef %l0, ptr noalias nocap
   ret void
 }
 
-define void @scalable_v2i64(ptr noalias nocapture noundef %l0, ptr noalias nocapture noundef %l1) {
+define void @scalable_v2i64(ptr %l0, ptr %l1) {
 ; CHECK-LABEL: scalable_v2i64:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.d

>From e2ec9514876744627e99fecb7179accdc6969e4e Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Tue, 14 May 2024 16:30:37 +0100
Subject: [PATCH 03/21] Remove unnecessary `#ifndef` macro around assertions

---
 llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index bc47443c45c8e..1da6fce1aa283 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2927,16 +2927,11 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     const LLT ValTy = MRI.getType(ValReg);
     const RegisterBank &RB = *RBI.getRegBank(ValReg, MRI, TRI);
 
-#ifndef NDEBUG
     if (ValTy.isScalableVector()) {
         assert(STI.hasSVE() 
              && "Load/Store register operand is scalable vector "
                 "while SVE is not supported by the target");
-        // assert(RB.getID() == AArch64::SVRRegBankID 
-        //        && "Load/Store register operand is scalable vector "
-        //           "while its register bank is not SVR");
     }
-#endif
     
     // The code below doesn't support truncating stores, so we need to split it
     // again.

>From 70b3f217a7dafa10dd9f3662c868d7a6a994f68e Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 15 May 2024 11:21:16 +0100
Subject: [PATCH 04/21] Legal only for size of multiple of 128

---
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    | 93 +++----------------
 1 file changed, 13 insertions(+), 80 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index c4f5b75ce959f..b93ed0e50e8dd 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -61,78 +61,11 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
 
-  // Scalable vector sizes range from 128 to 2048
-  // Note that subtargets may not support the full range.
-  // See [ScalableVecTypes] below.
   const LLT nxv16s8 = LLT::scalable_vector(16, s8);
-  const LLT nxv32s8 = LLT::scalable_vector(32, s8);
-  const LLT nxv64s8 = LLT::scalable_vector(64, s8);
-  const LLT nxv128s8 = LLT::scalable_vector(128, s8);
-  const LLT nxv256s8 = LLT::scalable_vector(256, s8);
-
   const LLT nxv8s16 = LLT::scalable_vector(8, s16);
-  const LLT nxv16s16 = LLT::scalable_vector(16, s16);
-  const LLT nxv32s16 = LLT::scalable_vector(32, s16);
-  const LLT nxv64s16 = LLT::scalable_vector(64, s16);
-  const LLT nxv128s16 = LLT::scalable_vector(128, s16);
-
   const LLT nxv4s32 = LLT::scalable_vector(4, s32); 
-  const LLT nxv8s32 = LLT::scalable_vector(8, s32); 
-  const LLT nxv16s32 = LLT::scalable_vector(16, s32); 
-  const LLT nxv32s32 = LLT::scalable_vector(32, s32);
-  const LLT nxv64s32 = LLT::scalable_vector(64, s32);
-
   const LLT nxv2s64 = LLT::scalable_vector(2, s64);
-  const LLT nxv4s64 = LLT::scalable_vector(4, s64);
-  const LLT nxv8s64 = LLT::scalable_vector(8, s64);
-  const LLT nxv16s64 = LLT::scalable_vector(16, s64);
-  const LLT nxv32s64 = LLT::scalable_vector(32, s64);
-
   const LLT nxv2p0 = LLT::scalable_vector(2, p0);
-  const LLT nxv4p0 = LLT::scalable_vector(4, p0);
-  const LLT nxv8p0 = LLT::scalable_vector(8, p0);
-  const LLT nxv16p0 = LLT::scalable_vector(16, p0);
-  const LLT nxv32p0 = LLT::scalable_vector(32, p0);
-
-  const auto ScalableVec128 = {
-    nxv16s8, nxv8s16, nxv4s32, nxv2s64, nxv2p0,
-  };
-  const auto ScalableVec256 = {
-    nxv32s8, nxv16s16, nxv8s32, nxv4s64, nxv4p0,
-  };
-  const auto ScalableVec512 = {
-    nxv64s8, nxv32s16, nxv16s32, nxv8s64, nxv8p0,
-  };
-  const auto ScalableVec1024 = {
-    nxv128s8, nxv64s16, nxv32s32, nxv16s64, nxv16p0,
-  };
-  const auto ScalableVec2048 = {
-    nxv256s8, nxv128s16, nxv64s32, nxv32s64, nxv32p0,
-  };
-
-  /// Scalable vector types supported by the sub target.
-  /// Empty if SVE is not supported.
-  SmallVector<LLT> ScalableVecTypes;
-  
-  if (ST.hasSVE()) {
-    // Add scalable vector types that are supported by the subtarget
-    const auto MinSize = ST.getMinSVEVectorSizeInBits();
-    auto MaxSize = ST.getMaxSVEVectorSizeInBits();
-    if (MaxSize == 0) {
-      // Unknown max size, assume the target supports all sizes.
-      MaxSize = 2048; 
-    }
-    if (MinSize <= 128 && 128 <= MaxSize)
-      ScalableVecTypes.append(ScalableVec128);
-    if (MinSize <= 256 && 256 <= MaxSize)
-      ScalableVecTypes.append(ScalableVec256);
-    if (MinSize <= 512 && 512 <= MaxSize)
-      ScalableVecTypes.append(ScalableVec512);
-    if (MinSize <= 1024 && 1024 <= MaxSize)
-      ScalableVecTypes.append(ScalableVec1024);
-    if (MinSize <= 2048 && 2048 <= MaxSize)
-      ScalableVecTypes.append(ScalableVec2048);
-  }
 
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
@@ -402,17 +335,19 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
-  const auto IsSameScalableVecTy = [=](const LegalityQuery &Query) {
-    // Legal if loading a scalable vector type
-    // into a scalable vector register of the exactly same type
-    if (!Query.Types[0].isScalableVector() || Query.Types[1] != p0)
-      return false;
-    if (Query.MMODescrs[0].MemoryTy != Query.Types[0])
-      return false;
-    if (Query.MMODescrs[0].AlignInBits < 128)
-      return false;
-    return is_contained(ScalableVecTypes, Query.Types[0]);
-  };
+  if (ST.hasSVE()) {
+    for (const auto OpCode : {G_LOAD, G_STORE}) {
+      getActionDefinitionsBuilder(OpCode)
+      .legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 128},
+        {nxv8s16, p0, nxv8s16, 128},
+        {nxv4s32, p0, nxv4s32, 128},
+        {nxv2s64, p0, nxv2s64, 128},
+        {nxv2p0, p0, nxv2p0, 128},
+      });
+    }
+  }
 
   getActionDefinitionsBuilder(G_LOAD)
       .customIf([=](const LegalityQuery &Query) {
@@ -439,7 +374,6 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // These extends are also legal
       .legalForTypesWithMemDesc(
           {{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
-      .legalIf(IsSameScalableVecTy)
       .widenScalarToNextPow2(0, /* MinSize = */ 8)
       .clampMaxNumElements(0, s8, 16)
       .clampMaxNumElements(0, s16, 8)
@@ -486,7 +420,6 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
            {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
            {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8},
           })
-      .legalIf(IsSameScalableVecTy)
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&

>From 73a618b6fc68f9b1c61760ee771f67ffca6657a0 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 15 May 2024 16:06:45 +0100
Subject: [PATCH 05/21] Update comments on FPRRegBank in
 AArch64RegisterBanks.td

---
 llvm/lib/Target/AArch64/AArch64RegisterBanks.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 9e2ed356299e2..2b597b8606921 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -12,7 +12,7 @@
 /// General Purpose Registers: W, X.
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
 
-/// Floating Point/Vector Registers: B, H, S, D, Q.
+/// Floating Point, Vector, Scalable Vector Registers: B, H, S, D, Q, Z.
 def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
 
 /// Conditional register: NZCV.

>From b19d1b47c8f0ead007f45888364cbe0000e94723 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Thu, 16 May 2024 11:52:41 +0100
Subject: [PATCH 06/21] Add option `aarch64-enable-sve-gisel` to allow SVE in
 GISel, disabled by default

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 39 +++++++++++--------
 .../AArch64/GlobalISel/sve-load-store.ll      |  2 +-
 2 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e0be162e10a97..40ce9152d4748 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -145,6 +145,15 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                  cl::desc("Maximum of xors"));
 
+// By turning this on, we will not fall back to DAG ISel when encountering
+// scalable vector types for any instruction, even if SVE is not yet supported
+// with some instructions.
+// See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
+static cl::opt<bool> EnableSVEGISel(
+    "aarch64-enable-sve-gisel", cl::Hidden,
+    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+    cl::init(false));
+
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;
 
@@ -26375,26 +26384,24 @@ bool AArch64TargetLowering::shouldLocalize(
   return TargetLoweringBase::shouldLocalize(MI, TTI);
 }
 
-static bool isScalableTySupported(const unsigned Op) {
-  return Op == Instruction::Load || Op == Instruction::Store;
-}
-
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  const auto ScalableTySupported = isScalableTySupported(Inst.getOpcode());
-
-  // Fallback for scalable vectors
-  if (Inst.getType()->isScalableTy() && !ScalableTySupported) {
+  // Fallback for scalable vectors.
+  // Note that if EnableSVEGISel is true, we allow scalable vector types for
+  // all instructions, regardless of whether they are actually supported.
+  if (!EnableSVEGISel) {
+    if (Inst.getType()->isScalableTy()) {
       return true;
-  }
+    }
 
-  for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy() && !ScalableTySupported)
-      return true;
+    for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
+      if (Inst.getOperand(i)->getType()->isScalableTy())
+        return true;
 
-  if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (AI->getAllocatedType()->isScalableTy())
-      return true;
-  }
+    if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
+      if (AI->getAllocatedType()->isScalableTy())
+        return true;
+    }
+  } 
 
   // Checks to allow the use of SME instructions
   if (auto *Base = dyn_cast<CallBase>(&Inst)) {
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
index 4c3ffb99e5667..5f41bd2b129df 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
-; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -aarch64-enable-sve-gisel=true < %s | FileCheck %s
 
 define void @scalable_v16i8(ptr %l0, ptr %l1) {
 ; CHECK-LABEL: scalable_v16i8:

>From fb48ea14e6e6a95b117bd86d17b147f42bc5dc44 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Mon, 20 May 2024 13:51:36 +0100
Subject: [PATCH 07/21] Explicitly assign scalable and non-scalable vectors
 into FPR

---
 llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index f249729b4b4ab..c44cc45e8b871 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -743,12 +743,12 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       continue;
     OpSize[Idx] = Ty.getSizeInBits().getKnownMinValue();
 
-    // As a top-level guess, scalable vectors go in SVRs, non-scalable
-    // vectors go in FPRs, scalars and pointers in GPRs.
+    // As a top-level guess, vectors including both scalable and non-scalable
+    // ones go in FPRs, scalars and pointers in GPRs.
     // For floating-point instructions, scalars go in FPRs.
-    if (Ty.isScalableVector()) 
+    if (Ty.isVector())
       OpRegBankIdx[Idx] = PMI_FirstFPR;
-    else if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
+    else if (isPreISelGenericFloatingPointOpcode(Opc) ||
         Ty.getSizeInBits() > 64)
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else

>From 9dedf0e096c6c8c156d1d263c517c8228b62dc8d Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Tue, 21 May 2024 13:59:33 +0100
Subject: [PATCH 08/21] Use getActionDefinitionsBuilder only once

---
 .../Target/AArch64/GISel/AArch64LegalizerInfo.cpp    | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index b93ed0e50e8dd..b7df5a9291373 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -335,10 +335,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
+  auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
+  auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
+
   if (ST.hasSVE()) {
-    for (const auto OpCode : {G_LOAD, G_STORE}) {
-      getActionDefinitionsBuilder(OpCode)
-      .legalForTypesWithMemDesc({
+    for (auto *Actions : {&LoadActions, &StoreActions}) {
+      Actions->legalForTypesWithMemDesc({
         // 128 bit base sizes
         {nxv16s8, p0, nxv16s8, 128},
         {nxv8s16, p0, nxv8s16, 128},
@@ -349,7 +351,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     }
   }
 
-  getActionDefinitionsBuilder(G_LOAD)
+  LoadActions
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -399,7 +401,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(IsPtrVecPred)
       .scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
 
-  getActionDefinitionsBuilder(G_STORE)
+  StoreActions
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Release;

>From ad331f96019728bf442d0a785572f975acd62d51 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 09:37:56 +0100
Subject: [PATCH 09/21] Replace TypeSize usages

---
 .../llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h      | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index 29939d4619400..554bef8406989 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -652,17 +652,17 @@ bool GIMatchTableExecutor::executeMatchTable(
       MachineMemOperand *MMO =
           *(State.MIs[InsnID]->memoperands_begin() + MMOIdx);
 
-      const auto Size = MRI.getType(MO.getReg()).getSizeInBits();
+      const TypeSize Size = MRI.getType(MO.getReg()).getSizeInBits();
       if (MatcherOpcode == GIM_CheckMemorySizeEqualToLLT &&
           MMO->getSizeInBits() != Size) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeLessThanLLT &&
-                 MMO->getSizeInBits().getValue() >= Size.getKnownMinValue()) {
+                 TypeSize::isKnownGE(MMO->getSizeInBits().getValue(), Size)) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeGreaterThanLLT &&
-                 MMO->getSizeInBits().getValue() <= Size.getKnownMinValue())
+                 TypeSize::isKnownLE(MMO->getSizeInBits().getValue(), Size))
         if (handleReject() == RejectAndGiveUp)
           return false;
 

>From 0c3cce327b0477badd9455fc050e850fcd9d12dc Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 09:48:07 +0100
Subject: [PATCH 10/21] Simplify assertion

---
 .../Target/AArch64/GISel/AArch64InstructionSelector.cpp   | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 1da6fce1aa283..9a54306239a11 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2927,11 +2927,9 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     const LLT ValTy = MRI.getType(ValReg);
     const RegisterBank &RB = *RBI.getRegBank(ValReg, MRI, TRI);
 
-    if (ValTy.isScalableVector()) {
-        assert(STI.hasSVE() 
-             && "Load/Store register operand is scalable vector "
-                "while SVE is not supported by the target");
-    }
+    assert((!ValTy.isScalableVector() || STI.hasSVE()) &&
+      "Load/Store register operand is scalable vector "
+      "while SVE is not supported by the target");
     
     // The code below doesn't support truncating stores, so we need to split it
     // again.

>From 45030749e348dfc62f85fee5a03eecda5a787aa9 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 10:09:18 +0100
Subject: [PATCH 11/21] Reformat code

---
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  3 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    |  2 +-
 .../GISel/AArch64InstructionSelector.cpp      | 39 +++++-----
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    | 71 +++++++++----------
 .../GISel/AArch64PostLegalizerCombiner.cpp    |  3 +-
 .../AArch64/GISel/AArch64RegisterBankInfo.cpp |  2 +-
 6 files changed, 62 insertions(+), 58 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 141c7ee15fe39..0c886a052d059 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -1080,7 +1080,8 @@ bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
   LLT Ty = MRI.getType(LdSt.getReg(0));
   LLT MemTy = LdSt.getMMO().getMemoryType();
   SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
-      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), AtomicOrdering::NotAtomic}});
+      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
+        AtomicOrdering::NotAtomic}});
   unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
   SmallVector<LLT> OpTys;
   if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 40ce9152d4748..cbd1b6a8e4792 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26401,7 +26401,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
       if (AI->getAllocatedType()->isScalableTy())
         return true;
     }
-  } 
+  }
 
   // Checks to allow the use of SME instructions
   if (auto *Base = dyn_cast<CallBase>(&Inst)) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 9a54306239a11..364cf72d198ef 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -901,24 +901,24 @@ static unsigned selectLoadStoreUIOp(unsigned GenericOpc, unsigned RegBankID,
   return GenericOpc;
 }
 
-/// Select the AArch64 opcode for the G_LOAD or G_STORE operation for scalable 
+/// Select the AArch64 opcode for the G_LOAD or G_STORE operation for scalable
 /// vectors.
 /// \p ElementSize size of the element of the scalable vector
 static unsigned selectLoadStoreSVEOp(const unsigned GenericOpc,
                                      const unsigned ElementSize) {
   const bool isStore = GenericOpc == TargetOpcode::G_STORE;
-  
+
   switch (ElementSize) {
-    case 8:
-      return isStore ? AArch64::ST1B : AArch64::LD1B;
-    case 16:
-      return isStore ? AArch64::ST1H : AArch64::LD1H;
-    case 32:
-      return isStore ? AArch64::ST1W : AArch64::LD1W;
-    case 64:
-      return isStore ? AArch64::ST1D : AArch64::LD1D;
+  case 8:
+    return isStore ? AArch64::ST1B : AArch64::LD1B;
+  case 16:
+    return isStore ? AArch64::ST1H : AArch64::LD1H;
+  case 32:
+    return isStore ? AArch64::ST1W : AArch64::LD1W;
+  case 64:
+    return isStore ? AArch64::ST1D : AArch64::LD1D;
   }
-  
+
   return GenericOpc;
 }
 
@@ -2875,7 +2875,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     }
 
     uint64_t MemSizeInBytes = LdSt.getMemSize().getValue().getKnownMinValue();
-    unsigned MemSizeInBits = LdSt.getMemSizeInBits().getValue().getKnownMinValue();
+    unsigned MemSizeInBits =
+        LdSt.getMemSizeInBits().getValue().getKnownMinValue();
     AtomicOrdering Order = LdSt.getMMO().getSuccessOrdering();
 
     // Need special instructions for atomics that affect ordering.
@@ -2928,14 +2929,15 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     const RegisterBank &RB = *RBI.getRegBank(ValReg, MRI, TRI);
 
     assert((!ValTy.isScalableVector() || STI.hasSVE()) &&
-      "Load/Store register operand is scalable vector "
-      "while SVE is not supported by the target");
-    
+           "Load/Store register operand is scalable vector "
+           "while SVE is not supported by the target");
+
     // The code below doesn't support truncating stores, so we need to split it
     // again.
     // Truncate only if type is not scalable vector
-    const bool NeedTrunc = !ValTy.isScalableVector() 
-                      && ValTy.getSizeInBits().getFixedValue() > MemSizeInBits;
+    const bool NeedTrunc =
+        !ValTy.isScalableVector() &&
+        ValTy.getSizeInBits().getFixedValue() > MemSizeInBits;
     if (isa<GStore>(LdSt) && NeedTrunc) {
       unsigned SubReg;
       LLT MemTy = LdSt.getMMO().getMemoryType();
@@ -2981,7 +2983,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       bool IsStore = isa<GStore>(I);
       unsigned NewOpc;
       if (ValTy.isScalableVector()) {
-        NewOpc = selectLoadStoreSVEOp(I.getOpcode(), ValTy.getElementType().getSizeInBits());
+        NewOpc = selectLoadStoreSVEOp(I.getOpcode(),
+                                      ValTy.getElementType().getSizeInBits());
       } else {
         NewOpc = selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
       }
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index b7df5a9291373..84da936ea3ea1 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -63,7 +63,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
 
   const LLT nxv16s8 = LLT::scalable_vector(16, s8);
   const LLT nxv8s16 = LLT::scalable_vector(8, s16);
-  const LLT nxv4s32 = LLT::scalable_vector(4, s32); 
+  const LLT nxv4s32 = LLT::scalable_vector(4, s32);
   const LLT nxv2s64 = LLT::scalable_vector(2, s64);
   const LLT nxv2p0 = LLT::scalable_vector(2, p0);
 
@@ -341,12 +341,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   if (ST.hasSVE()) {
     for (auto *Actions : {&LoadActions, &StoreActions}) {
       Actions->legalForTypesWithMemDesc({
-        // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 128},
-        {nxv8s16, p0, nxv8s16, 128},
-        {nxv4s32, p0, nxv4s32, 128},
-        {nxv2s64, p0, nxv2s64, 128},
-        {nxv2p0, p0, nxv2p0, 128},
+          // 128 bit base sizes
+          {nxv16s8, p0, nxv16s8, 128},
+          {nxv8s16, p0, nxv8s16, 128},
+          {nxv4s32, p0, nxv4s32, 128},
+          {nxv2s64, p0, nxv2s64, 128},
+          {nxv2p0, p0, nxv2p0, 128},
       });
     }
   }
@@ -410,18 +410,18 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         return Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering != AtomicOrdering::NotAtomic;
       })
-      .legalForTypesWithMemDesc(
-          {{s8, p0, s8, 8},     {s16, p0, s8, 8},  // truncstorei8 from s16
-           {s32, p0, s8, 8},                       // truncstorei8 from s32
-           {s64, p0, s8, 8},                       // truncstorei8 from s64
-           {s16, p0, s16, 8},   {s32, p0, s16, 8}, // truncstorei16 from s32
-           {s64, p0, s16, 8},                      // truncstorei16 from s64
-           {s32, p0, s8, 8},    {s32, p0, s16, 8},    {s32, p0, s32, 8},
-           {s64, p0, s64, 8},   {s64, p0, s32, 8}, // truncstorei32 from s64
-           {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
-           {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
-           {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8},
-          })
+      .legalForTypesWithMemDesc({
+          {s8, p0, s8, 8},     {s16, p0, s8, 8},  // truncstorei8 from s16
+          {s32, p0, s8, 8},                       // truncstorei8 from s32
+          {s64, p0, s8, 8},                       // truncstorei8 from s64
+          {s16, p0, s16, 8},   {s32, p0, s16, 8}, // truncstorei16 from s32
+          {s64, p0, s16, 8},                      // truncstorei16 from s64
+          {s32, p0, s8, 8},    {s32, p0, s16, 8},    {s32, p0, s32, 8},
+          {s64, p0, s64, 8},   {s64, p0, s32, 8}, // truncstorei32 from s64
+          {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
+          {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
+          {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8},
+      })
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&
@@ -447,23 +447,22 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // Idx 0 == Ptr, Idx 1 == Val
       // TODO: we can implement legalizations but as of now these are
       // generated in a very specific way.
-      .legalForTypesWithMemDesc({
-          {p0, s8, s8, 8},
-          {p0, s16, s16, 8},
-          {p0, s32, s8, 8},
-          {p0, s32, s16, 8},
-          {p0, s32, s32, 8},
-          {p0, s64, s64, 8},
-          {p0, p0, p0, 8},
-          {p0, v8s8, v8s8, 8},
-          {p0, v16s8, v16s8, 8},
-          {p0, v4s16, v4s16, 8},
-          {p0, v8s16, v8s16, 8},
-          {p0, v2s32, v2s32, 8},
-          {p0, v4s32, v4s32, 8},
-          {p0, v2s64, v2s64, 8},
-          {p0, v2p0, v2p0, 8},
-          {p0, s128, s128, 8}})
+      .legalForTypesWithMemDesc({{p0, s8, s8, 8},
+                                 {p0, s16, s16, 8},
+                                 {p0, s32, s8, 8},
+                                 {p0, s32, s16, 8},
+                                 {p0, s32, s32, 8},
+                                 {p0, s64, s64, 8},
+                                 {p0, p0, p0, 8},
+                                 {p0, v8s8, v8s8, 8},
+                                 {p0, v16s8, v16s8, 8},
+                                 {p0, v4s16, v4s16, 8},
+                                 {p0, v8s16, v8s16, 8},
+                                 {p0, v2s32, v2s32, 8},
+                                 {p0, v4s32, v4s32, 8},
+                                 {p0, v2s64, v2s64, 8},
+                                 {p0, v2p0, v2p0, 8},
+                                 {p0, s128, s128, 8}})
       .unsupported();
 
   auto IndexedLoadBasicPred = [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index 5830489e8ef90..83dbf2077365c 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -658,7 +658,8 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
         APInt Offset;
         LLT StoredValTy = MRI.getType(St->getValueReg());
         const auto ValSize = StoredValTy.getSizeInBits();
-        if (ValSize.getKnownMinValue() < 32 || St->getMMO().getSizeInBits() != ValSize)
+        if (ValSize.getKnownMinValue() < 32 ||
+            St->getMMO().getSizeInBits() != ValSize)
           continue;
 
         Register PtrReg = St->getPointerReg();
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index c44cc45e8b871..4d2a7fd412135 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -749,7 +749,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     if (Ty.isVector())
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else if (isPreISelGenericFloatingPointOpcode(Opc) ||
-        Ty.getSizeInBits() > 64)
+             Ty.getSizeInBits() > 64)
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else
       OpRegBankIdx[Idx] = PMI_FirstGPR;

>From 22de2adb74f2b248b543673969eeb0207b60129f Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 10:13:55 +0100
Subject: [PATCH 12/21] Remove brackets from single statements

---
 .../lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 364cf72d198ef..5a990374c0ee7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2982,12 +2982,12 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     auto SelectLoadStoreAddressingMode = [&]() -> MachineInstr * {
       bool IsStore = isa<GStore>(I);
       unsigned NewOpc;
-      if (ValTy.isScalableVector()) {
+      if (ValTy.isScalableVector())
         NewOpc = selectLoadStoreSVEOp(I.getOpcode(),
                                       ValTy.getElementType().getSizeInBits());
-      } else {
+      else
         NewOpc = selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
-      }
+
       if (NewOpc == I.getOpcode())
         return nullptr;
 

>From 3be060cf9f298554b6f50b676a42ca4ba4f09307 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 11:49:22 +0100
Subject: [PATCH 13/21] Revert formatting change

---
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    | 57 ++++++++++---------
 1 file changed, 29 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 84da936ea3ea1..867588a0f5143 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -410,18 +410,17 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         return Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering != AtomicOrdering::NotAtomic;
       })
-      .legalForTypesWithMemDesc({
-          {s8, p0, s8, 8},     {s16, p0, s8, 8},  // truncstorei8 from s16
-          {s32, p0, s8, 8},                       // truncstorei8 from s32
-          {s64, p0, s8, 8},                       // truncstorei8 from s64
-          {s16, p0, s16, 8},   {s32, p0, s16, 8}, // truncstorei16 from s32
-          {s64, p0, s16, 8},                      // truncstorei16 from s64
-          {s32, p0, s8, 8},    {s32, p0, s16, 8},    {s32, p0, s32, 8},
-          {s64, p0, s64, 8},   {s64, p0, s32, 8}, // truncstorei32 from s64
-          {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
-          {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
-          {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8},
-      })
+      .legalForTypesWithMemDesc(
+          {{s8, p0, s8, 8},     {s16, p0, s8, 8},  // truncstorei8 from s16
+           {s32, p0, s8, 8},                       // truncstorei8 from s32
+           {s64, p0, s8, 8},                       // truncstorei8 from s64
+           {s16, p0, s16, 8},   {s32, p0, s16, 8}, // truncstorei16 from s32
+           {s64, p0, s16, 8},                      // truncstorei16 from s64
+           {s32, p0, s8, 8},    {s32, p0, s16, 8},    {s32, p0, s32, 8},
+           {s64, p0, s64, 8},   {s64, p0, s32, 8}, // truncstorei32 from s64
+           {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
+           {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
+           {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&
@@ -447,22 +446,24 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // Idx 0 == Ptr, Idx 1 == Val
       // TODO: we can implement legalizations but as of now these are
       // generated in a very specific way.
-      .legalForTypesWithMemDesc({{p0, s8, s8, 8},
-                                 {p0, s16, s16, 8},
-                                 {p0, s32, s8, 8},
-                                 {p0, s32, s16, 8},
-                                 {p0, s32, s32, 8},
-                                 {p0, s64, s64, 8},
-                                 {p0, p0, p0, 8},
-                                 {p0, v8s8, v8s8, 8},
-                                 {p0, v16s8, v16s8, 8},
-                                 {p0, v4s16, v4s16, 8},
-                                 {p0, v8s16, v8s16, 8},
-                                 {p0, v2s32, v2s32, 8},
-                                 {p0, v4s32, v4s32, 8},
-                                 {p0, v2s64, v2s64, 8},
-                                 {p0, v2p0, v2p0, 8},
-                                 {p0, s128, s128, 8}})
+      .legalForTypesWithMemDesc({
+          {p0, s8, s8, 8},
+          {p0, s16, s16, 8},
+          {p0, s32, s8, 8},
+          {p0, s32, s16, 8},
+          {p0, s32, s32, 8},
+          {p0, s64, s64, 8},
+          {p0, p0, p0, 8},
+          {p0, v8s8, v8s8, 8},
+          {p0, v16s8, v16s8, 8},
+          {p0, v4s16, v4s16, 8},
+          {p0, v8s16, v8s16, 8},
+          {p0, v2s32, v2s32, 8},
+          {p0, v4s32, v4s32, 8},
+          {p0, v2s64, v2s64, 8},
+          {p0, v2p0, v2p0, 8},
+          {p0, s128, s128, 8},
+      })
       .unsupported();
 
   auto IndexedLoadBasicPred = [=](const LegalityQuery &Query) {

>From b5c72d79d37d0c9ccd71256ca5b320e202b4860b Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 11:49:44 +0100
Subject: [PATCH 14/21] Skip SplitStoreZero128 for scalable vectors

---
 .../lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index 83dbf2077365c..40f6ec36467c1 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -309,7 +309,9 @@ bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
   if (!Store.isSimple())
     return false;
   LLT ValTy = MRI.getType(Store.getValueReg());
-  if (!ValTy.isVector() || ValTy.getSizeInBits().getKnownMinValue() != 128)
+  if (ValTy.isScalableVector())
+    return false;
+  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
     return false;
   if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
     return false; // Don't split truncating stores.

>From 2518b20419982e48658b3180ffc356bea5272f2c Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 11:50:13 +0100
Subject: [PATCH 15/21] Skip optimizeConsecutiveMemOpAddressing for scalable
 vectors

---
 .../AArch64/GISel/AArch64PostLegalizerCombiner.cpp     | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index 40f6ec36467c1..fe84d0e27189f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -655,13 +655,17 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
     // should only be in a single block.
     resetState();
     for (auto &MI : MBB) {
+      // Skip for scalable vectors
+      if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
+          LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
+        continue;
+
       if (auto *St = dyn_cast<GStore>(&MI)) {
         Register PtrBaseReg;
         APInt Offset;
         LLT StoredValTy = MRI.getType(St->getValueReg());
-        const auto ValSize = StoredValTy.getSizeInBits();
-        if (ValSize.getKnownMinValue() < 32 ||
-            St->getMMO().getSizeInBits() != ValSize)
+        unsigned ValSize = StoredValTy.getSizeInBits();
+        if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
           continue;
 
         Register PtrReg = St->getPointerReg();

>From 5a288d896a431db36ff0196067fd81df4d76ceae Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 16:10:23 +0100
Subject: [PATCH 16/21] Rename option to aarch64-enable-gisel-sve

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp        | 2 +-
 llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cbd1b6a8e4792..93057ef87503c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -150,7 +150,7 @@ static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
 // with some instructions.
 // See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
 static cl::opt<bool> EnableSVEGISel(
-    "aarch64-enable-sve-gisel", cl::Hidden,
+    "aarch64-enable-gisel-sve", cl::Hidden,
     cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
     cl::init(false));
 
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
index 5f41bd2b129df..95a5bfa4b038f 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/sve-load-store.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
-; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -aarch64-enable-sve-gisel=true < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -aarch64-enable-gisel-sve=true < %s | FileCheck %s
 
 define void @scalable_v16i8(ptr %l0, ptr %l1) {
 ; CHECK-LABEL: scalable_v16i8:

>From f1a4d7bd34489ce9136826637d6fc4775c3990ed Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Wed, 22 May 2024 16:45:11 +0100
Subject: [PATCH 17/21] Unfold `hasSVE` loop for legalizer

---
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    | 27 ++++++++++++-------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 867588a0f5143..16e144dd83e6a 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -339,16 +339,23 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
 
   if (ST.hasSVE()) {
-    for (auto *Actions : {&LoadActions, &StoreActions}) {
-      Actions->legalForTypesWithMemDesc({
-          // 128 bit base sizes
-          {nxv16s8, p0, nxv16s8, 128},
-          {nxv8s16, p0, nxv8s16, 128},
-          {nxv4s32, p0, nxv4s32, 128},
-          {nxv2s64, p0, nxv2s64, 128},
-          {nxv2p0, p0, nxv2p0, 128},
-      });
-    }
+    LoadActions.legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 128},
+        {nxv8s16, p0, nxv8s16, 128},
+        {nxv4s32, p0, nxv4s32, 128},
+        {nxv2s64, p0, nxv2s64, 128},
+        {nxv2p0, p0, nxv2p0, 128},
+    });
+
+    StoreActions.legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 128},
+        {nxv8s16, p0, nxv8s16, 128},
+        {nxv4s32, p0, nxv4s32, 128},
+        {nxv2s64, p0, nxv2s64, 128},
+        {nxv2p0, p0, nxv2p0, 128},
+    });
   }
 
   LoadActions

>From 09570db32ee4d1c0d2dde13a38c9212e0a91745f Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Fri, 24 May 2024 12:07:20 +0100
Subject: [PATCH 18/21] Revert instruction selector changes as they are already
 covered by tablegen

---
 .../GISel/AArch64InstructionSelector.cpp      | 55 ++-----------------
 1 file changed, 6 insertions(+), 49 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 5a990374c0ee7..61f5bc2464ee5 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -901,27 +901,6 @@ static unsigned selectLoadStoreUIOp(unsigned GenericOpc, unsigned RegBankID,
   return GenericOpc;
 }
 
-/// Select the AArch64 opcode for the G_LOAD or G_STORE operation for scalable
-/// vectors.
-/// \p ElementSize size of the element of the scalable vector
-static unsigned selectLoadStoreSVEOp(const unsigned GenericOpc,
-                                     const unsigned ElementSize) {
-  const bool isStore = GenericOpc == TargetOpcode::G_STORE;
-
-  switch (ElementSize) {
-  case 8:
-    return isStore ? AArch64::ST1B : AArch64::LD1B;
-  case 16:
-    return isStore ? AArch64::ST1H : AArch64::LD1H;
-  case 32:
-    return isStore ? AArch64::ST1W : AArch64::LD1W;
-  case 64:
-    return isStore ? AArch64::ST1D : AArch64::LD1D;
-  }
-
-  return GenericOpc;
-}
-
 /// Helper function for selectCopy. Inserts a subregister copy from \p SrcReg
 /// to \p *To.
 ///
@@ -2874,9 +2853,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       return false;
     }
 
-    uint64_t MemSizeInBytes = LdSt.getMemSize().getValue().getKnownMinValue();
-    unsigned MemSizeInBits =
-        LdSt.getMemSizeInBits().getValue().getKnownMinValue();
+    uint64_t MemSizeInBytes = LdSt.getMemSize().getValue();
+    unsigned MemSizeInBits = LdSt.getMemSizeInBits().getValue();
     AtomicOrdering Order = LdSt.getMMO().getSuccessOrdering();
 
     // Need special instructions for atomics that affect ordering.
@@ -2928,17 +2906,9 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     const LLT ValTy = MRI.getType(ValReg);
     const RegisterBank &RB = *RBI.getRegBank(ValReg, MRI, TRI);
 
-    assert((!ValTy.isScalableVector() || STI.hasSVE()) &&
-           "Load/Store register operand is scalable vector "
-           "while SVE is not supported by the target");
-
     // The code below doesn't support truncating stores, so we need to split it
     // again.
-    // Truncate only if type is not scalable vector
-    const bool NeedTrunc =
-        !ValTy.isScalableVector() &&
-        ValTy.getSizeInBits().getFixedValue() > MemSizeInBits;
-    if (isa<GStore>(LdSt) && NeedTrunc) {
+    if (isa<GStore>(LdSt) && ValTy.getSizeInBits() > MemSizeInBits) {
       unsigned SubReg;
       LLT MemTy = LdSt.getMMO().getMemoryType();
       auto *RC = getRegClassForTypeOnBank(MemTy, RB);
@@ -2951,7 +2921,7 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
                       .getReg(0);
       RBI.constrainGenericRegister(Copy, *RC, MRI);
       LdSt.getOperand(0).setReg(Copy);
-    } else if (isa<GLoad>(LdSt) && NeedTrunc) {
+    } else if (isa<GLoad>(LdSt) && ValTy.getSizeInBits() > MemSizeInBits) {
       // If this is an any-extending load from the FPR bank, split it into a regular
       // load + extend.
       if (RB.getID() == AArch64::FPRRegBankID) {
@@ -2981,20 +2951,10 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     // instruction with an updated opcode, or a new instruction.
     auto SelectLoadStoreAddressingMode = [&]() -> MachineInstr * {
       bool IsStore = isa<GStore>(I);
-      unsigned NewOpc;
-      if (ValTy.isScalableVector())
-        NewOpc = selectLoadStoreSVEOp(I.getOpcode(),
-                                      ValTy.getElementType().getSizeInBits());
-      else
-        NewOpc = selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
-
+      const unsigned NewOpc =
+          selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
       if (NewOpc == I.getOpcode())
         return nullptr;
-
-      if (ValTy.isScalableVector()) {
-        // Add the predicate register operand
-        I.addOperand(MachineOperand::CreatePredicate(true));
-      }
       // Check if we can fold anything into the addressing mode.
       auto AddrModeFns =
           selectAddrModeIndexed(I.getOperand(1), MemSizeInBytes);
@@ -3010,9 +2970,6 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       Register CurValReg = I.getOperand(0).getReg();
       IsStore ? NewInst.addUse(CurValReg) : NewInst.addDef(CurValReg);
       NewInst.cloneMemRefs(I);
-      if (ValTy.isScalableVector()) {
-        NewInst.add(I.getOperand(1)); // Copy predicate register
-      }
       for (auto &Fn : *AddrModeFns)
         Fn(NewInst);
       I.eraseFromParent();

>From 551fec4a77021c7d95ca7bde26496348f6bb5973 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Fri, 24 May 2024 14:22:04 +0100
Subject: [PATCH 19/21] Allow alignment of 8

---
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 16e144dd83e6a..9ce9ca411dfea 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -341,20 +341,20 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   if (ST.hasSVE()) {
     LoadActions.legalForTypesWithMemDesc({
         // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 128},
-        {nxv8s16, p0, nxv8s16, 128},
-        {nxv4s32, p0, nxv4s32, 128},
-        {nxv2s64, p0, nxv2s64, 128},
-        {nxv2p0, p0, nxv2p0, 128},
+        {nxv16s8, p0, nxv16s8, 8},
+        {nxv8s16, p0, nxv8s16, 8},
+        {nxv4s32, p0, nxv4s32, 8},
+        {nxv2s64, p0, nxv2s64, 8},
+        {nxv2p0, p0, nxv2p0, 8},
     });
 
     StoreActions.legalForTypesWithMemDesc({
         // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 128},
-        {nxv8s16, p0, nxv8s16, 128},
-        {nxv4s32, p0, nxv4s32, 128},
-        {nxv2s64, p0, nxv2s64, 128},
-        {nxv2p0, p0, nxv2p0, 128},
+        {nxv16s8, p0, nxv16s8, 8},
+        {nxv8s16, p0, nxv8s16, 8},
+        {nxv4s32, p0, nxv4s32, 8},
+        {nxv2s64, p0, nxv2s64, 8},
+        {nxv2p0, p0, nxv2p0, 8},
     });
   }
 

>From 5995c15754cdbd65054caef87c4a3cdd90e45266 Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Tue, 28 May 2024 10:08:36 +0100
Subject: [PATCH 20/21] Remove legal rules for nxv2p0

---
 llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 9ce9ca411dfea..1df78e13cd716 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -65,7 +65,6 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT nxv8s16 = LLT::scalable_vector(8, s16);
   const LLT nxv4s32 = LLT::scalable_vector(4, s32);
   const LLT nxv2s64 = LLT::scalable_vector(2, s64);
-  const LLT nxv2p0 = LLT::scalable_vector(2, p0);
 
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
@@ -345,7 +344,6 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         {nxv8s16, p0, nxv8s16, 8},
         {nxv4s32, p0, nxv4s32, 8},
         {nxv2s64, p0, nxv2s64, 8},
-        {nxv2p0, p0, nxv2p0, 8},
     });
 
     StoreActions.legalForTypesWithMemDesc({
@@ -354,7 +352,6 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         {nxv8s16, p0, nxv8s16, 8},
         {nxv4s32, p0, nxv4s32, 8},
         {nxv2s64, p0, nxv2s64, 8},
-        {nxv2p0, p0, nxv2p0, 8},
     });
   }
 

>From e53b252fd13d41255b8ab09094c3ad8bc607927c Mon Sep 17 00:00:00 2001
From: Tianyi Guan <tguan at nvidia.com>
Date: Tue, 28 May 2024 14:20:44 +0100
Subject: [PATCH 21/21] Add TODO for nxv2p0

---
 llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 1df78e13cd716..c472fd06ba373 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -346,6 +346,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         {nxv2s64, p0, nxv2s64, 8},
     });
 
+    // TODO: Add nxv2p0. Consider bitcastIf.
+    //       See #92130
+    //       https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
     StoreActions.legalForTypesWithMemDesc({
         // 128 bit base sizes
         {nxv16s8, p0, nxv16s8, 8},



More information about the llvm-commits mailing list