[llvm] [SROA] Only try additional vector type candidates when needed (PR #77678)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 12 09:13:44 PST 2024
https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/77678
>From f21673bbb55dafa4b939bd8655fcc64027a3c0ac Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Fri, 12 Jan 2024 09:00:08 -0800
Subject: [PATCH 1/2] [SROA] NFC: Extract code to checkVectorTypesForPromotion
Change-Id: Ib6f237cc791a097f8f2411bc1d6502f11d4a748e
---
llvm/lib/Transforms/Scalar/SROA.cpp | 181 +++++++++++++++-------------
1 file changed, 99 insertions(+), 82 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 75cddfa16d6db5..64b1fd198354b7 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -2108,8 +2108,9 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S,
/// Test whether a vector type is viable for promotion.
///
-/// This implements the necessary checking for \c isVectorPromotionViable over
-/// all slices of the alloca for the given VectorType.
+/// This implements the necessary checking for \c checkVectorTypesForPromotion
+/// (and thus isVectorPromotionViable) over all slices of the alloca for the
+/// given VectorType.
static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy,
const DataLayout &DL) {
uint64_t ElementSize =
@@ -2134,6 +2135,98 @@ static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy,
return true;
}
+/// Test whether any vector type in \p CandidateTys is viable for promotion
+/// over all slices of \p P (used by \c isVectorPromotionViable). Returns the
+/// first viable candidate, or nullptr. Note: \p CandidateTys is rewritten in
+/// place (cleared, integer-ified, sorted/uniqued, truncated, filtered).
+static VectorType *
+checkVectorTypesForPromotion(Partition &P, const DataLayout &DL,
+ SmallVectorImpl<VectorType *> &CandidateTys,
+ bool HaveCommonEltTy, Type *CommonEltTy,
+ bool HaveVecPtrTy, bool HaveCommonVecPtrTy,
+ VectorType *CommonVecPtrTy) {
+ // If we didn't find a vector type, nothing to do here.
+ if (CandidateTys.empty())
+ return nullptr;
+
+ // Pointer-ness is sticky, if we had a vector-of-pointers candidate type,
+ // then we should choose it, not some other alternative.
+ // But, we can't perform a no-op pointer address space change via bitcast,
+ // so if we didn't have a common pointer element type, bail.
+ if (HaveVecPtrTy && !HaveCommonVecPtrTy)
+ return nullptr;
+
+ // Try to pick the "best" element type out of the choices.
+ if (!HaveCommonEltTy && HaveVecPtrTy) {
+ // If there was a pointer element type, there's really only one choice.
+ CandidateTys.clear();
+ CandidateTys.push_back(CommonVecPtrTy);
+ } else if (!HaveCommonEltTy && !HaveVecPtrTy) {
+ // Integer-ify vector types so all candidates can be ranked by lane count.
+ for (VectorType *&VTy : CandidateTys) {
+ if (!VTy->getElementType()->isIntegerTy())
+ VTy = cast<VectorType>(VTy->getWithNewType(IntegerType::getIntNTy(
+ VTy->getContext(), VTy->getScalarSizeInBits())));
+ }
+
+ // Rank the remaining (now all-integer) candidate vector types by ascending
+ // element count, then drop entries that have the same element count.
+ auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+ (void)DL;
+ assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
+ DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
+ "Cannot have vector types of different sizes!");
+ assert(RHSTy->getElementType()->isIntegerTy() &&
+ "All non-integer types eliminated!");
+ assert(LHSTy->getElementType()->isIntegerTy() &&
+ "All non-integer types eliminated!");
+ return cast<FixedVectorType>(RHSTy)->getNumElements() <
+ cast<FixedVectorType>(LHSTy)->getNumElements();
+ };
+ auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+ (void)DL;
+ assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
+ DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
+ "Cannot have vector types of different sizes!");
+ assert(RHSTy->getElementType()->isIntegerTy() &&
+ "All non-integer types eliminated!");
+ assert(LHSTy->getElementType()->isIntegerTy() &&
+ "All non-integer types eliminated!");
+ return cast<FixedVectorType>(RHSTy)->getNumElements() ==
+ cast<FixedVectorType>(LHSTy)->getNumElements();
+ };
+ llvm::sort(CandidateTys, RankVectorTypesComp);
+ CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(),
+ RankVectorTypesEq),
+ CandidateTys.end());
+ } else {
+// The only way to have the same element type in every vector type is to
+// have the same vector type. Check that and remove all but one.
+#ifndef NDEBUG
+ for (VectorType *VTy : CandidateTys) {
+ assert(VTy->getElementType() == CommonEltTy &&
+ "Unaccounted for element type!");
+ assert(VTy == CandidateTys[0] &&
+ "Different vector types with the same element type!");
+ }
+#endif
+ CandidateTys.resize(1);
+ }
+
+ // FIXME: hack. Do we have a named constant for this?
+ // SDAG SDNode can't have more than 65535 operands.
+ llvm::erase_if(CandidateTys, [](VectorType *VTy) {
+ return cast<FixedVectorType>(VTy)->getNumElements() >
+ std::numeric_limits<unsigned short>::max();
+ });
+ // Return the first remaining candidate (in list order) viable for every slice.
+ for (VectorType *VTy : CandidateTys)
+ if (checkVectorTypeForPromotion(P, VTy, DL))
+ return VTy;
+
+ return nullptr;
+}
+
/// Test whether the given alloca partitioning and range of slices can be
/// promoted to a vector.
///
@@ -2181,6 +2274,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
}
}
};
+
// Put load and store types into a set for de-duplication.
for (const Slice &S : P) {
Type *Ty;
@@ -2216,86 +2310,9 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
}
}
- // If we didn't find a vector type, nothing to do here.
- if (CandidateTys.empty())
- return nullptr;
-
- // Pointer-ness is sticky, if we had a vector-of-pointers candidate type,
- // then we should choose it, not some other alternative.
- // But, we can't perform a no-op pointer address space change via bitcast,
- // so if we didn't have a common pointer element type, bail.
- if (HaveVecPtrTy && !HaveCommonVecPtrTy)
- return nullptr;
-
- // Try to pick the "best" element type out of the choices.
- if (!HaveCommonEltTy && HaveVecPtrTy) {
- // If there was a pointer element type, there's really only one choice.
- CandidateTys.clear();
- CandidateTys.push_back(CommonVecPtrTy);
- } else if (!HaveCommonEltTy && !HaveVecPtrTy) {
- // Integer-ify vector types.
- for (VectorType *&VTy : CandidateTys) {
- if (!VTy->getElementType()->isIntegerTy())
- VTy = cast<VectorType>(VTy->getWithNewType(IntegerType::getIntNTy(
- VTy->getContext(), VTy->getScalarSizeInBits())));
- }
-
- // Rank the remaining candidate vector types. This is easy because we know
- // they're all integer vectors. We sort by ascending number of elements.
- auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
- (void)DL;
- assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
- DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
- "Cannot have vector types of different sizes!");
- assert(RHSTy->getElementType()->isIntegerTy() &&
- "All non-integer types eliminated!");
- assert(LHSTy->getElementType()->isIntegerTy() &&
- "All non-integer types eliminated!");
- return cast<FixedVectorType>(RHSTy)->getNumElements() <
- cast<FixedVectorType>(LHSTy)->getNumElements();
- };
- auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
- (void)DL;
- assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
- DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
- "Cannot have vector types of different sizes!");
- assert(RHSTy->getElementType()->isIntegerTy() &&
- "All non-integer types eliminated!");
- assert(LHSTy->getElementType()->isIntegerTy() &&
- "All non-integer types eliminated!");
- return cast<FixedVectorType>(RHSTy)->getNumElements() ==
- cast<FixedVectorType>(LHSTy)->getNumElements();
- };
- llvm::sort(CandidateTys, RankVectorTypesComp);
- CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(),
- RankVectorTypesEq),
- CandidateTys.end());
- } else {
-// The only way to have the same element type in every vector type is to
-// have the same vector type. Check that and remove all but one.
-#ifndef NDEBUG
- for (VectorType *VTy : CandidateTys) {
- assert(VTy->getElementType() == CommonEltTy &&
- "Unaccounted for element type!");
- assert(VTy == CandidateTys[0] &&
- "Different vector types with the same element type!");
- }
-#endif
- CandidateTys.resize(1);
- }
-
- // FIXME: hack. Do we have a named constant for this?
- // SDAG SDNode can't have more than 65535 operands.
- llvm::erase_if(CandidateTys, [](VectorType *VTy) {
- return cast<FixedVectorType>(VTy)->getNumElements() >
- std::numeric_limits<unsigned short>::max();
- });
-
- for (VectorType *VTy : CandidateTys)
- if (checkVectorTypeForPromotion(P, VTy, DL))
- return VTy;
-
- return nullptr;
+ return checkVectorTypesForPromotion(P, DL, CandidateTys, HaveCommonEltTy,
+ CommonEltTy, HaveVecPtrTy,
+ HaveCommonVecPtrTy, CommonVecPtrTy);
}
/// Test whether a slice of an alloca is valid for integer widening.
>From d4f38d84cb1e9fd17d9adb22c6a3fc49514343f9 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Wed, 10 Jan 2024 08:11:24 -0800
Subject: [PATCH 2/2] [SROA] Only try additional vector type candidates when
needed
Change-Id: I06f3026b616ddc03d09ec6c416ad4cc15d837d96
---
llvm/lib/Transforms/Scalar/SROA.cpp | 7 ++
llvm/test/Transforms/SROA/vector-promotion.ll | 80 +++++++++++++++++++
2 files changed, 87 insertions(+)
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 64b1fd198354b7..c18b6a18a00f7e 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -2289,6 +2289,12 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
if (S.beginOffset() == P.beginOffset() && S.endOffset() == P.endOffset())
CheckCandidateType(Ty);
}
+
+ if (auto *VTy = checkVectorTypesForPromotion(
+ P, DL, CandidateTys, HaveCommonEltTy, CommonEltTy, HaveVecPtrTy,
+ HaveCommonVecPtrTy, CommonVecPtrTy))
+ return VTy;
+
// Consider additional vector types where the element type size is a
// multiple of load/store element size.
for (Type *Ty : LoadStoreTys) {
@@ -2298,6 +2304,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// Make a copy of CandidateTys and iterate through it, because we might
// append to CandidateTys in the loop.
SmallVector<VectorType *, 4> CandidateTysCopy = CandidateTys;
+ CandidateTys.clear();
for (VectorType *&VTy : CandidateTysCopy) {
unsigned VectorSize = DL.getTypeSizeInBits(VTy).getFixedValue();
unsigned ElementSize =
diff --git a/llvm/test/Transforms/SROA/vector-promotion.ll b/llvm/test/Transforms/SROA/vector-promotion.ll
index 9643a51064f049..77bf0b1cb1cb75 100644
--- a/llvm/test/Transforms/SROA/vector-promotion.ll
+++ b/llvm/test/Transforms/SROA/vector-promotion.ll
@@ -1227,6 +1227,86 @@ define void @swap-15bytes(ptr %x, ptr %y) {
ret void
}
+define <4 x i32> @ptrLoadStoreTys(ptr %init, i32 %val2) {
+; CHECK-LABEL: @ptrLoadStoreTys(
+; CHECK-NEXT: [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[VAL0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast i64 [[TMP1]] to <2 x i32>
+; CHECK-NEXT: [[OBJ_0_VEC_EXPAND:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT: [[OBJ_0_VECBLEND:%.*]] = select <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i32> [[OBJ_0_VEC_EXPAND]], <4 x i32> zeroinitializer
+; CHECK-NEXT: [[OBJ_8_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_0_VECBLEND]], i32 [[VAL2:%.*]], i32 2
+; CHECK-NEXT: [[OBJ_12_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_8_VEC_INSERT]], i32 131072, i32 3
+; CHECK-NEXT: ret <4 x i32> [[OBJ_12_VEC_INSERT]]
+;
+; DEBUG-LABEL: @ptrLoadStoreTys(
+; DEBUG-NEXT: [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8, !dbg [[DBG492:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr [[VAL0]], metadata [[META487:![0-9]+]], metadata !DIExpression()), !dbg [[DBG492]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr undef, metadata [[META488:![0-9]+]], metadata !DIExpression()), !dbg [[DBG493:![0-9]+]]
+; DEBUG-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[VAL0]] to i64, !dbg [[DBG494:![0-9]+]]
+; DEBUG-NEXT: [[TMP2:%.*]] = bitcast i64 [[TMP1]] to <2 x i32>, !dbg [[DBG494]]
+; DEBUG-NEXT: [[OBJ_0_VEC_EXPAND:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>, !dbg [[DBG494]]
+; DEBUG-NEXT: [[OBJ_0_VECBLEND:%.*]] = select <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i32> [[OBJ_0_VEC_EXPAND]], <4 x i32> zeroinitializer, !dbg [[DBG494]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr undef, metadata [[META489:![0-9]+]], metadata !DIExpression()), !dbg [[DBG495:![0-9]+]]
+; DEBUG-NEXT: [[OBJ_8_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_0_VECBLEND]], i32 [[VAL2:%.*]], i32 2, !dbg [[DBG496:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr undef, metadata [[META490:![0-9]+]], metadata !DIExpression()), !dbg [[DBG497:![0-9]+]]
+; DEBUG-NEXT: [[OBJ_12_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_8_VEC_INSERT]], i32 131072, i32 3, !dbg [[DBG498:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata <4 x i32> [[OBJ_12_VEC_INSERT]], metadata [[META491:![0-9]+]], metadata !DIExpression()), !dbg [[DBG499:![0-9]+]]
+; DEBUG-NEXT: ret <4 x i32> [[OBJ_12_VEC_INSERT]], !dbg [[DBG500:![0-9]+]]
+;
+ %val0 = load ptr, ptr %init, align 8
+ %obj = alloca <4 x i32>, align 16
+ store <4 x i32> zeroinitializer, ptr %obj, align 16
+ store ptr %val0, ptr %obj, align 8 ; ptr store over lanes 0-1 of the i32 vector; CHECKs above expect the alloca promoted to <4 x i32> ops
+ %ptr2 = getelementptr inbounds i8, ptr %obj, i64 8
+ store i32 %val2, ptr %ptr2, align 4
+ %ptr3 = getelementptr inbounds i8, ptr %obj, i64 12
+ store i32 131072, ptr %ptr3, align 4
+ %sroaval = load <4 x i32>, ptr %obj, align 16
+ ret <4 x i32> %sroaval
+}
+
+define <4 x float> @ptrLoadStoreTysFloat(ptr %init, float %val2) {
+; CHECK-LABEL: @ptrLoadStoreTysFloat(
+; CHECK-NEXT: [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8
+; CHECK-NEXT: [[OBJ:%.*]] = alloca <4 x float>, align 16
+; CHECK-NEXT: store <4 x float> zeroinitializer, ptr [[OBJ]], align 16
+; CHECK-NEXT: store ptr [[VAL0]], ptr [[OBJ]], align 16
+; CHECK-NEXT: [[OBJ_8_PTR2_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 8
+; CHECK-NEXT: store float [[VAL2:%.*]], ptr [[OBJ_8_PTR2_SROA_IDX]], align 8
+; CHECK-NEXT: [[OBJ_12_PTR3_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 12
+; CHECK-NEXT: store float 1.310720e+05, ptr [[OBJ_12_PTR3_SROA_IDX]], align 4
+; CHECK-NEXT: [[OBJ_0_SROAVAL:%.*]] = load <4 x float>, ptr [[OBJ]], align 16
+; CHECK-NEXT: ret <4 x float> [[OBJ_0_SROAVAL]]
+;
+; DEBUG-LABEL: @ptrLoadStoreTysFloat(
+; DEBUG-NEXT: [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8, !dbg [[DBG508:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr [[VAL0]], metadata [[META503:![0-9]+]], metadata !DIExpression()), !dbg [[DBG508]]
+; DEBUG-NEXT: [[OBJ:%.*]] = alloca <4 x float>, align 16, !dbg [[DBG509:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr [[OBJ]], metadata [[META504:![0-9]+]], metadata !DIExpression()), !dbg [[DBG509]]
+; DEBUG-NEXT: store <4 x float> zeroinitializer, ptr [[OBJ]], align 16, !dbg [[DBG510:![0-9]+]]
+; DEBUG-NEXT: store ptr [[VAL0]], ptr [[OBJ]], align 16, !dbg [[DBG511:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr undef, metadata [[META505:![0-9]+]], metadata !DIExpression()), !dbg [[DBG512:![0-9]+]]
+; DEBUG-NEXT: [[OBJ_8_PTR2_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 8, !dbg [[DBG513:![0-9]+]]
+; DEBUG-NEXT: store float [[VAL2:%.*]], ptr [[OBJ_8_PTR2_SROA_IDX]], align 8, !dbg [[DBG513]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata ptr undef, metadata [[META506:![0-9]+]], metadata !DIExpression()), !dbg [[DBG514:![0-9]+]]
+; DEBUG-NEXT: [[OBJ_12_PTR3_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 12, !dbg [[DBG515:![0-9]+]]
+; DEBUG-NEXT: store float 1.310720e+05, ptr [[OBJ_12_PTR3_SROA_IDX]], align 4, !dbg [[DBG515]]
+; DEBUG-NEXT: [[OBJ_0_SROAVAL:%.*]] = load <4 x float>, ptr [[OBJ]], align 16, !dbg [[DBG516:![0-9]+]]
+; DEBUG-NEXT: call void @llvm.dbg.value(metadata <4 x float> [[OBJ_0_SROAVAL]], metadata [[META507:![0-9]+]], metadata !DIExpression()), !dbg [[DBG516]]
+; DEBUG-NEXT: ret <4 x float> [[OBJ_0_SROAVAL]], !dbg [[DBG517:![0-9]+]]
+;
+ %val0 = load ptr, ptr %init, align 8
+ %obj = alloca <4 x float>, align 16
+ store <4 x float> zeroinitializer, ptr %obj, align 16
+ store ptr %val0, ptr %obj, align 8 ; same ptr store, but into a float vector; CHECKs above expect the alloca to be kept (not promoted)
+ %ptr2 = getelementptr inbounds i8, ptr %obj, i64 8
+ store float %val2, ptr %ptr2, align 4
+ %ptr3 = getelementptr inbounds i8, ptr %obj, i64 12
+ store float 131072.0, ptr %ptr3, align 4
+ %sroaval = load <4 x float>, ptr %obj, align 16
+ ret <4 x float> %sroaval
+}
+
declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
declare void @llvm.lifetime.end.p0(i64, ptr)
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
More information about the llvm-commits
mailing list