[llvm] 4612f39 - [SVE] Add flag to specify SVE register size, using this to calculate legal vector types.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 18 05:14:51 PDT 2020


Author: Paul Walker
Date: 2020-06-18T12:11:16Z
New Revision: 4612f391200d0b4e21bc040a098227d73679de53

URL: https://github.com/llvm/llvm-project/commit/4612f391200d0b4e21bc040a098227d73679de53
DIFF: https://github.com/llvm/llvm-project/commit/4612f391200d0b4e21bc040a098227d73679de53.diff

LOG: [SVE] Add flag to specify SVE register size, using this to calculate legal vector types.

Adds aarch64-sve-vector-bits-{min,max} to allow the size of SVE
data registers (in bits) to be specified. This allows the code
generator to make assumptions it normally couldn't. As a starting
point this information is used to mark fixed length vector types
that can fit within the specified size as legal.

Reviewers: rengolin, efriedma

Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80384

Added: 
    llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64Subtarget.cpp
    llvm/lib/Target/AArch64/AArch64Subtarget.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6177a80d0aae..8c82d88f5a6f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -184,6 +184,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
 
+    if (useSVEForFixedLengthVectors()) {
+      for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
+        if (useSVEForFixedLengthVectorVT(VT))
+          addRegisterClass(VT, &AArch64::ZPRRegClass);
+
+      for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
+        if (useSVEForFixedLengthVectorVT(VT))
+          addRegisterClass(VT, &AArch64::ZPRRegClass);
+    }
+
     for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) {
       setOperationAction(ISD::SADDSAT, VT, Legal);
       setOperationAction(ISD::UADDSAT, VT, Legal);
@@ -3474,6 +3484,55 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   }
 }
 
+// Return true if fixed length vector operations may be lowered using SVE.
+bool AArch64TargetLowering::useSVEForFixedLengthVectors() const {
+  // Prefer NEON unless SVE registers are known to be at least 256 bits.
+  return Subtarget->hasSVE() && Subtarget->getMinSVEVectorSizeInBits() >= 256;
+}
+
+// Return true if the fixed length vector type VT should be lowered using SVE:
+// it needs a supported element type, more than 128 bits, a power-of-2 element
+// count, and must fit within the minimum known SVE register size.
+bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(MVT VT) const {
+  assert(VT.isFixedLengthVector() && "Expected a fixed length vector type!");
+  if (!useSVEForFixedLengthVectors())
+    return false;
+
+  // Fixed length predicates should be promoted to i8.
+  // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
+  if (VT.getVectorElementType() == MVT::i1)
+    return false;
+
+  // Don't use SVE for vectors we cannot scalarize if required.
+  switch (VT.getVectorElementType().SimpleTy) {
+  default:
+    return false;
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
+  case MVT::i64:
+  case MVT::f16:
+  case MVT::f32:
+  case MVT::f64:
+    break;
+  }
+
+  // Ensure NEON MVTs only belong to a single register class.
+  if (VT.getSizeInBits() <= 128)
+    return false;
+
+  // Don't use SVE for types that may not fit in an SVE register.
+  if (VT.getSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
+    return false;
+
+  // TODO: Perhaps an artificial restriction, but worth having whilst getting
+  // the base fixed length SVE support in place.
+  if (!VT.isPow2VectorType())
+    return false;
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6d47fc8e7483..0c0be2a3ebd2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -907,6 +907,11 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
+
+  /// Return true if fixed length vectors may be lowered using SVE.
+  bool useSVEForFixedLengthVectors() const;
+  /// Return true if the fixed length vector type VT should use SVE registers.
+  bool useSVEForFixedLengthVectorVT(MVT VT) const;
 };
 
 namespace AArch64 {

diff  --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 2eed3448558b..9bb533f8efde 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -47,6 +47,18 @@ static cl::opt<bool>
                    cl::desc("Call nonlazybind functions via direct GOT load"),
                    cl::init(false), cl::Hidden);
 
+static cl::opt<unsigned> SVEVectorBitsMax(
+    "aarch64-sve-vector-bits-max",
+    cl::desc("Assume SVE vector registers are at most this big, "
+             "with zero meaning no maximum size is assumed."),
+    cl::init(0), cl::Hidden);
+
+static cl::opt<unsigned> SVEVectorBitsMin(
+    "aarch64-sve-vector-bits-min",
+    cl::desc("Assume SVE vector registers are at least this big, "
+             "with zero meaning no minimum size is assumed."),
+    cl::init(0), cl::Hidden);
+
 AArch64Subtarget &
 AArch64Subtarget::initializeSubtargetDependencies(StringRef FS,
                                                   StringRef CPUString) {
@@ -329,3 +341,33 @@ void AArch64Subtarget::mirFileLoaded(MachineFunction &MF) const {
   if (!MFI.isMaxCallFrameSizeComputed())
     MFI.computeMaxCallFrameSize(MF);
 }
+
+// Return the assumed upper bound, in bits, of SVE data registers as given by
+// -aarch64-sve-vector-bits-max, rounded down to a multiple of 128. A result
+// of 0 means no upper bound is known.
+unsigned AArch64Subtarget::getMaxSVEVectorSizeInBits() const {
+  assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+  assert(SVEVectorBitsMax % 128 == 0 &&
+         "SVE requires vector length in multiples of 128!");
+  assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+         "Minimum SVE vector size should not be larger than its maximum!");
+  if (SVEVectorBitsMax == 0)
+    return 0;
+  return (std::max(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
+
+// Return the assumed lower bound, in bits, of SVE data registers as given by
+// -aarch64-sve-vector-bits-min, rounded down to a multiple of 128. A result
+// of 0 means no lower bound is known. When both flags are set, pairing
+// std::min here with std::max in getMaxSVEVectorSizeInBits() keeps the two
+// bounds ordered even if the asserts are compiled out.
+unsigned AArch64Subtarget::getMinSVEVectorSizeInBits() const {
+  assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+  assert(SVEVectorBitsMin % 128 == 0 &&
+         "SVE requires vector length in multiples of 128!");
+  assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+         "Minimum SVE vector size should not be larger than its maximum!");
+  if (SVEVectorBitsMax == 0)
+    return (SVEVectorBitsMin / 128) * 128;
+  return (std::min(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}

diff  --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index ba0660f72477..221c103cff8f 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -534,6 +534,12 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
   }
 
   void mirFileLoaded(MachineFunction &MF) const override;
+
+  // Return the known range for the bit length of SVE data registers. A value
+  // of 0 means nothing is known about that particular limit beyond what's
+  // implied by the architecture.
+  unsigned getMaxSVEVectorSizeInBits() const;
+  unsigned getMinSVEVectorSizeInBits() const;
 };
 } // End llvm namespace
 

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index f7233d364f71..8e85f9277678 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -98,6 +98,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getRegisterBitWidth(bool Vector) const {
     if (Vector) {
+      // With SVE, report the minimum known SVE register size, but never less
+      // than the 128 bits reported for NEON.
+      if (ST->hasSVE())
+        return std::max(ST->getMinSVEVectorSizeInBits(), 128u);
       if (ST->hasNEON())
         return 128;
       return 0;

diff  --git a/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
new file mode 100644
index 000000000000..ba7db1c4d904
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
@@ -0,0 +1,60 @@
+; RUN: opt < %s -cost-model -analyze | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=128 | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=256 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=384 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=512 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=640 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=768 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=896 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1024 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1152 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1280 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1408 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1536 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1664 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1792 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1920 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=2048 | FileCheck %s -D#VBITS=2048
+
+; VBITS represents the useful bit size of a vector register from the code
+; generator's point of view. It is clamped to power-of-2 values because
+; only power-of-2 vector lengths are considered legal, regardless of the
+; user-specified vector length.
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Ensure the cost of legalisation is removed as the vector length grows.
+; NOTE: Assumes BaseCost_add=1, BaseCost_fadd=2.
+define void @add() #0 {
+; CHECK-LABEL: Printing analysis 'Cost Model Analysis' for function 'add':
+; CHECK: cost of [[#div(127,VBITS)+1]] for instruction:   %add128 = add <4 x i32> undef, undef
+; CHECK: cost of [[#div(255,VBITS)+1]] for instruction:   %add256 = add <8 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(1023,VBITS)+1]] for instruction:   %add1024 = add <32 x i32> undef, undef
+; CHECK: cost of [[#div(2047,VBITS)+1]] for instruction:   %add2048 = add <64 x i32> undef, undef
+  %add128 = add <4 x i32> undef, undef
+  %add256 = add <8 x i32> undef, undef
+  %add512 = add <16 x i32> undef, undef
+  %add1024 = add <32 x i32> undef, undef
+  %add2048 = add <64 x i32> undef, undef
+
+; Using a single vector length, ensure all element types are recognised.
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i8 = add <64 x i8> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i16 = add <32 x i16> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i32 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i64 = add <8 x i64> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction:   %add512.f16 = fadd <32 x half> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction:   %add512.f32 = fadd <16 x float> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction:   %add512.f64 = fadd <8 x double> undef, undef
+  %add512.i8 = add <64 x i8> undef, undef
+  %add512.i16 = add <32 x i16> undef, undef
+  %add512.i32 = add <16 x i32> undef, undef
+  %add512.i64 = add <8 x i64> undef, undef
+  %add512.f16 = fadd <32 x half> undef, undef
+  %add512.f32 = fadd <16 x float> undef, undef
+  %add512.f64 = fadd <8 x double> undef, undef
+
+  ret void
+}
+
+attributes #0 = { "target-features"="+sve" }


        


More information about the llvm-commits mailing list