[llvm] 4612f39 - [SVE] Add flag to specify SVE register size, using this to calculate legal vector types.
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 18 05:14:51 PDT 2020
Author: Paul Walker
Date: 2020-06-18T12:11:16Z
New Revision: 4612f391200d0b4e21bc040a098227d73679de53
URL: https://github.com/llvm/llvm-project/commit/4612f391200d0b4e21bc040a098227d73679de53
DIFF: https://github.com/llvm/llvm-project/commit/4612f391200d0b4e21bc040a098227d73679de53.diff
LOG: [SVE] Add flag to specify SVE register size, using this to calculate legal vector types.
Adds aarch64-sve-vector-bits-{min,max} to allow the size of SVE
data registers (in bits) to be specified. This allows the code
generator to make assumptions it normally couldn't. As a starting
point this information is used to mark fixed length vector types
that can fit within the specified size as legal.
Reviewers: rengolin, efriedma
Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D80384
Added:
llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64Subtarget.cpp
llvm/lib/Target/AArch64/AArch64Subtarget.h
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6177a80d0aae..8c82d88f5a6f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -184,6 +184,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
+ if (useSVEForFixedLengthVectors()) {
+ for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
+ if (useSVEForFixedLengthVectorVT(VT))
+ addRegisterClass(VT, &AArch64::ZPRRegClass);
+
+ for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
+ if (useSVEForFixedLengthVectorVT(VT))
+ addRegisterClass(VT, &AArch64::ZPRRegClass);
+ }
+
for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) {
setOperationAction(ISD::SADDSAT, VT, Legal);
setOperationAction(ISD::UADDSAT, VT, Legal);
@@ -3474,6 +3484,51 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
}
}
+// Returns true when fixed length vector types should be lowered using SVE
+// registers, i.e. when SVE is available and the user-specified minimum SVE
+// register size is strictly larger than the 128 bits NEON already covers.
+bool AArch64TargetLowering::useSVEForFixedLengthVectors() const {
+ // Prefer NEON unless larger SVE registers are available.
+ return Subtarget->hasSVE() && Subtarget->getMinSVEVectorSizeInBits() >= 256;
+}
+
+// Returns true when the fixed length vector type VT should be lowered using
+// SVE registers. Only power-of-2 vectors with a supported integer/FP element
+// type, wider than 128 bits and no wider than the known minimum SVE register
+// size, qualify.
+bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(MVT VT) const {
+ assert(VT.isFixedLengthVector());
+ if (!useSVEForFixedLengthVectors())
+ return false;
+
+ // Fixed length predicates should be promoted to i8.
+ // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
+ if (VT.getVectorElementType() == MVT::i1)
+ return false;
+
+ // Don't use SVE for vectors we cannot scalarize if required.
+ switch (VT.getVectorElementType().SimpleTy) {
+ default:
+ return false;
+ case MVT::i8:
+ case MVT::i16:
+ case MVT::i32:
+ case MVT::i64:
+ case MVT::f16:
+ case MVT::f32:
+ case MVT::f64:
+ break;
+ }
+
+ // Ensure NEON MVTs only belong to a single register class.
+ if (VT.getSizeInBits() <= 128)
+ return false;
+
+ // Don't use SVE for types that don't fit.
+ if (VT.getSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
+ return false;
+
+ // TODO: Perhaps an artificial restriction, but worth having whilst getting
+ // the base fixed length SVE support in place.
+ if (!VT.isPow2VectorType())
+ return false;
+
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// Calling Convention Implementation
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6d47fc8e7483..0c0be2a3ebd2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -907,6 +907,9 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldLocalize(const MachineInstr &MI,
const TargetTransformInfo *TTI) const override;
+
+ bool useSVEForFixedLengthVectors() const;
+ bool useSVEForFixedLengthVectorVT(MVT VT) const;
};
namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 2eed3448558b..9bb533f8efde 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -47,6 +47,18 @@ static cl::opt<bool>
cl::desc("Call nonlazybind functions via direct GOT load"),
cl::init(false), cl::Hidden);
+static cl::opt<unsigned> SVEVectorBitsMax(
+ "aarch64-sve-vector-bits-max",
+ cl::desc("Assume SVE vector registers are at most this big, "
+ "with zero meaning no maximum size is assumed."),
+ cl::init(0), cl::Hidden);
+
+static cl::opt<unsigned> SVEVectorBitsMin(
+ "aarch64-sve-vector-bits-min",
+ cl::desc("Assume SVE vector registers are at least this big, "
+ "with zero meaning no minimum size is assumed."),
+ cl::init(0), cl::Hidden);
+
AArch64Subtarget &
AArch64Subtarget::initializeSubtargetDependencies(StringRef FS,
StringRef CPUString) {
@@ -329,3 +341,25 @@ void AArch64Subtarget::mirFileLoaded(MachineFunction &MF) const {
if (!MFI.isMaxCallFrameSizeComputed())
MFI.computeMaxCallFrameSize(MF);
}
+
+// Returns the assumed upper bound (in bits) on the size of an SVE data
+// register, as specified by -aarch64-sve-vector-bits-max. A result of 0 means
+// no maximum is assumed beyond what the architecture implies.
+unsigned AArch64Subtarget::getMaxSVEVectorSizeInBits() const {
+ assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+ assert(SVEVectorBitsMax % 128 == 0 &&
+ "SVE requires vector length in multiples of 128!");
+ assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+ "Minimum SVE vector size should not be larger than its maximum!");
+ if (SVEVectorBitsMax == 0)
+ return 0;
+ // Round down to a multiple of 128 bits. The std::max defensively keeps the
+ // result >= the specified minimum in release builds where the asserts above
+ // are compiled out.
+ return (std::max(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
+
+// Returns the assumed lower bound (in bits) on the size of an SVE data
+// register, as specified by -aarch64-sve-vector-bits-min. A result of 0 means
+// no minimum is assumed beyond what the architecture implies.
+unsigned AArch64Subtarget::getMinSVEVectorSizeInBits() const {
+ assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+ assert(SVEVectorBitsMin % 128 == 0 &&
+ "SVE requires vector length in multiples of 128!");
+ assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+ "Minimum SVE vector size should not be larger than its maximum!");
+ if (SVEVectorBitsMax == 0)
+ return (SVEVectorBitsMin / 128) * 128;
+ // Round down to a multiple of 128 bits. The std::min defensively clamps the
+ // result to the specified maximum in release builds where the asserts above
+ // are compiled out.
+ return (std::min(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index ba0660f72477..221c103cff8f 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -534,6 +534,12 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
}
void mirFileLoaded(MachineFunction &MF) const override;
+
+ // Return the known range for the bit length of SVE data registers. A value
+ // of 0 means nothing is known about that particular limit beyond what's
+ // implied by the architecture.
+ unsigned getMaxSVEVectorSizeInBits() const;
+ unsigned getMinSVEVectorSizeInBits() const;
};
} // End llvm namespace
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index f7233d364f71..8e85f9277678 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -98,6 +98,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
unsigned getRegisterBitWidth(bool Vector) const {
if (Vector) {
+ if (ST->hasSVE())
+ return std::max(ST->getMinSVEVectorSizeInBits(), 128u);
if (ST->hasNEON())
return 128;
return 0;
diff --git a/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
new file mode 100644
index 000000000000..ba7db1c4d904
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
@@ -0,0 +1,60 @@
+; RUN: opt < %s -cost-model -analyze | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=128 | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=256 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=384 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=512 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=640 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=768 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=896 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1024 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1152 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1280 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1408 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1536 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1664 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1792 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1920 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=2048 | FileCheck %s -D#VBITS=2048
+
+; VBITS represents the useful bit size of a vector register from the code
+; generator's point of view. It is clamped to power-of-2 values because
+; only power-of-2 vector lengths are considered legal, regardless of the
+; user specified vector length.
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Ensure the cost of legalisation is removed as the vector length grows.
+; NOTE: Assumes BaseCost_add=1, BaseCost_fadd=2.
+define void @add() #0 {
+; CHECK-LABEL: Printing analysis 'Cost Model Analysis' for function 'add':
+; CHECK: cost of [[#div(127,VBITS)+1]] for instruction: %add128 = add <4 x i32> undef, undef
+; CHECK: cost of [[#div(255,VBITS)+1]] for instruction: %add256 = add <8 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction: %add512 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(1023,VBITS)+1]] for instruction: %add1024 = add <32 x i32> undef, undef
+; CHECK: cost of [[#div(2047,VBITS)+1]] for instruction: %add2048 = add <64 x i32> undef, undef
+ %add128 = add <4 x i32> undef, undef
+ %add256 = add <8 x i32> undef, undef
+ %add512 = add <16 x i32> undef, undef
+ %add1024 = add <32 x i32> undef, undef
+ %add2048 = add <64 x i32> undef, undef
+
+; Using a single vector length, ensure all element types are recognised.
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction: %add512.i8 = add <64 x i8> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction: %add512.i16 = add <32 x i16> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction: %add512.i32 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction: %add512.i64 = add <8 x i64> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction: %add512.f16 = fadd <32 x half> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction: %add512.f32 = fadd <16 x float> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction: %add512.f64 = fadd <8 x double> undef, undef
+ %add512.i8 = add <64 x i8> undef, undef
+ %add512.i16 = add <32 x i16> undef, undef
+ %add512.i32 = add <16 x i32> undef, undef
+ %add512.i64 = add <8 x i64> undef, undef
+ %add512.f16 = fadd <32 x half> undef, undef
+ %add512.f32 = fadd <16 x float> undef, undef
+ %add512.f64 = fadd <8 x double> undef, undef
+
+ ret void
+}
+
+attributes #0 = { "target-features"="+sve" }
More information about the llvm-commits
mailing list