[llvm] [SROA] Only try additional vector type candidates when needed (PR #77678)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 11:55:01 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)

<details>
<summary>Changes</summary>

https://github.com/llvm/llvm-project/commit/f9c2a341b94ca71508dcefa109ece843459f7f13 causes regressions when we have a slice with integer vector type that is the same size as the partition, and a ptr load/store slice that is not the size of the element type.

Ref `vector-promotion.ll:ptrLoadStoreTys`. 

Before the patch, we would only consider `<4 x i32>` as a candidate type for vector promotion, and would find that it is a viable type for all the slices.

After the patch, we now add `<2 x ptr>` as a candidate type because of the slice whose user is `store ptr %val0, ptr %obj, align 8`, and we set the `HaveVecPtrTy` flag. The pre-existing handling of this flag removes the viable `<4 x i32>` and keeps only the unviable `<2 x ptr>`, which results in a failure to promote.

The end result is failing to promote an alloca that was previously promoted -- this does not appear to be the intent of that patch, which has the goal of increasing promotions by providing more promotion opportunities. 

This PR preserves this behavior via a simple reorganization of the implementation: first try the slice types with the same size as the partition, then, if there is no promotable type, try the `LoadStoreTys`.

Sort of RFC at the moment, since the previous promotion is perhaps questionable.

---
Full diff: https://github.com/llvm/llvm-project/pull/77678.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/SROA.cpp (+106-82) 
- (modified) llvm/test/Transforms/SROA/vector-promotion.ll (+38) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 75cddfa16d6db5..c18b6a18a00f7e 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -2108,8 +2108,9 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S,
 
 /// Test whether a vector type is viable for promotion.
 ///
-/// This implements the necessary checking for \c isVectorPromotionViable over
-/// all slices of the alloca for the given VectorType.
+/// This implements the necessary checking for \c checkVectorTypesForPromotion
+/// (and thus isVectorPromotionViable) over all slices of the alloca for the
+/// given VectorType.
 static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy,
                                         const DataLayout &DL) {
   uint64_t ElementSize =
@@ -2134,6 +2135,98 @@ static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy,
   return true;
 }
 
+/// Test whether any vector type in \p CandidateTys is viable for promotion.
+///
+/// This implements the necessary checking for \c isVectorPromotionViable over
+/// all slices of the alloca for the given VectorType.
+static VectorType *
+checkVectorTypesForPromotion(Partition &P, const DataLayout &DL,
+                             SmallVectorImpl<VectorType *> &CandidateTys,
+                             bool HaveCommonEltTy, Type *CommonEltTy,
+                             bool HaveVecPtrTy, bool HaveCommonVecPtrTy,
+                             VectorType *CommonVecPtrTy) {
+  // If we didn't find a vector type, nothing to do here.
+  if (CandidateTys.empty())
+    return nullptr;
+
+  // Pointer-ness is sticky, if we had a vector-of-pointers candidate type,
+  // then we should choose it, not some other alternative.
+  // But, we can't perform a no-op pointer address space change via bitcast,
+  // so if we didn't have a common pointer element type, bail.
+  if (HaveVecPtrTy && !HaveCommonVecPtrTy)
+    return nullptr;
+
+  // Try to pick the "best" element type out of the choices.
+  if (!HaveCommonEltTy && HaveVecPtrTy) {
+    // If there was a pointer element type, there's really only one choice.
+    CandidateTys.clear();
+    CandidateTys.push_back(CommonVecPtrTy);
+  } else if (!HaveCommonEltTy && !HaveVecPtrTy) {
+    // Integer-ify vector types.
+    for (VectorType *&VTy : CandidateTys) {
+      if (!VTy->getElementType()->isIntegerTy())
+        VTy = cast<VectorType>(VTy->getWithNewType(IntegerType::getIntNTy(
+            VTy->getContext(), VTy->getScalarSizeInBits())));
+    }
+
+    // Rank the remaining candidate vector types. This is easy because we know
+    // they're all integer vectors. We sort by ascending number of elements.
+    auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+      (void)DL;
+      assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
+                 DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
+             "Cannot have vector types of different sizes!");
+      assert(RHSTy->getElementType()->isIntegerTy() &&
+             "All non-integer types eliminated!");
+      assert(LHSTy->getElementType()->isIntegerTy() &&
+             "All non-integer types eliminated!");
+      return cast<FixedVectorType>(RHSTy)->getNumElements() <
+             cast<FixedVectorType>(LHSTy)->getNumElements();
+    };
+    auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+      (void)DL;
+      assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
+                 DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
+             "Cannot have vector types of different sizes!");
+      assert(RHSTy->getElementType()->isIntegerTy() &&
+             "All non-integer types eliminated!");
+      assert(LHSTy->getElementType()->isIntegerTy() &&
+             "All non-integer types eliminated!");
+      return cast<FixedVectorType>(RHSTy)->getNumElements() ==
+             cast<FixedVectorType>(LHSTy)->getNumElements();
+    };
+    llvm::sort(CandidateTys, RankVectorTypesComp);
+    CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(),
+                                   RankVectorTypesEq),
+                       CandidateTys.end());
+  } else {
+// The only way to have the same element type in every vector type is to
+// have the same vector type. Check that and remove all but one.
+#ifndef NDEBUG
+    for (VectorType *VTy : CandidateTys) {
+      assert(VTy->getElementType() == CommonEltTy &&
+             "Unaccounted for element type!");
+      assert(VTy == CandidateTys[0] &&
+             "Different vector types with the same element type!");
+    }
+#endif
+    CandidateTys.resize(1);
+  }
+
+  // FIXME: hack. Do we have a named constant for this?
+  // SDAG SDNode can't have more than 65535 operands.
+  llvm::erase_if(CandidateTys, [](VectorType *VTy) {
+    return cast<FixedVectorType>(VTy)->getNumElements() >
+           std::numeric_limits<unsigned short>::max();
+  });
+
+  for (VectorType *VTy : CandidateTys)
+    if (checkVectorTypeForPromotion(P, VTy, DL))
+      return VTy;
+
+  return nullptr;
+}
+
 /// Test whether the given alloca partitioning and range of slices can be
 /// promoted to a vector.
 ///
@@ -2181,6 +2274,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
       }
     }
   };
+
   // Put load and store types into a set for de-duplication.
   for (const Slice &S : P) {
     Type *Ty;
@@ -2195,6 +2289,12 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     if (S.beginOffset() == P.beginOffset() && S.endOffset() == P.endOffset())
       CheckCandidateType(Ty);
   }
+
+  if (auto *VTy = checkVectorTypesForPromotion(
+          P, DL, CandidateTys, HaveCommonEltTy, CommonEltTy, HaveVecPtrTy,
+          HaveCommonVecPtrTy, CommonVecPtrTy))
+    return VTy;
+
   // Consider additional vector types where the element type size is a
   // multiple of load/store element size.
   for (Type *Ty : LoadStoreTys) {
@@ -2204,6 +2304,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     // Make a copy of CandidateTys and iterate through it, because we might
     // append to CandidateTys in the loop.
     SmallVector<VectorType *, 4> CandidateTysCopy = CandidateTys;
+    CandidateTys.clear();
     for (VectorType *&VTy : CandidateTysCopy) {
       unsigned VectorSize = DL.getTypeSizeInBits(VTy).getFixedValue();
       unsigned ElementSize =
@@ -2216,86 +2317,9 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     }
   }
 
-  // If we didn't find a vector type, nothing to do here.
-  if (CandidateTys.empty())
-    return nullptr;
-
-  // Pointer-ness is sticky, if we had a vector-of-pointers candidate type,
-  // then we should choose it, not some other alternative.
-  // But, we can't perform a no-op pointer address space change via bitcast,
-  // so if we didn't have a common pointer element type, bail.
-  if (HaveVecPtrTy && !HaveCommonVecPtrTy)
-    return nullptr;
-
-  // Try to pick the "best" element type out of the choices.
-  if (!HaveCommonEltTy && HaveVecPtrTy) {
-    // If there was a pointer element type, there's really only one choice.
-    CandidateTys.clear();
-    CandidateTys.push_back(CommonVecPtrTy);
-  } else if (!HaveCommonEltTy && !HaveVecPtrTy) {
-    // Integer-ify vector types.
-    for (VectorType *&VTy : CandidateTys) {
-      if (!VTy->getElementType()->isIntegerTy())
-        VTy = cast<VectorType>(VTy->getWithNewType(IntegerType::getIntNTy(
-            VTy->getContext(), VTy->getScalarSizeInBits())));
-    }
-
-    // Rank the remaining candidate vector types. This is easy because we know
-    // they're all integer vectors. We sort by ascending number of elements.
-    auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
-      (void)DL;
-      assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
-                 DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
-             "Cannot have vector types of different sizes!");
-      assert(RHSTy->getElementType()->isIntegerTy() &&
-             "All non-integer types eliminated!");
-      assert(LHSTy->getElementType()->isIntegerTy() &&
-             "All non-integer types eliminated!");
-      return cast<FixedVectorType>(RHSTy)->getNumElements() <
-             cast<FixedVectorType>(LHSTy)->getNumElements();
-    };
-    auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
-      (void)DL;
-      assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
-                 DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
-             "Cannot have vector types of different sizes!");
-      assert(RHSTy->getElementType()->isIntegerTy() &&
-             "All non-integer types eliminated!");
-      assert(LHSTy->getElementType()->isIntegerTy() &&
-             "All non-integer types eliminated!");
-      return cast<FixedVectorType>(RHSTy)->getNumElements() ==
-             cast<FixedVectorType>(LHSTy)->getNumElements();
-    };
-    llvm::sort(CandidateTys, RankVectorTypesComp);
-    CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(),
-                                   RankVectorTypesEq),
-                       CandidateTys.end());
-  } else {
-// The only way to have the same element type in every vector type is to
-// have the same vector type. Check that and remove all but one.
-#ifndef NDEBUG
-    for (VectorType *VTy : CandidateTys) {
-      assert(VTy->getElementType() == CommonEltTy &&
-             "Unaccounted for element type!");
-      assert(VTy == CandidateTys[0] &&
-             "Different vector types with the same element type!");
-    }
-#endif
-    CandidateTys.resize(1);
-  }
-
-  // FIXME: hack. Do we have a named constant for this?
-  // SDAG SDNode can't have more than 65535 operands.
-  llvm::erase_if(CandidateTys, [](VectorType *VTy) {
-    return cast<FixedVectorType>(VTy)->getNumElements() >
-           std::numeric_limits<unsigned short>::max();
-  });
-
-  for (VectorType *VTy : CandidateTys)
-    if (checkVectorTypeForPromotion(P, VTy, DL))
-      return VTy;
-
-  return nullptr;
+  return checkVectorTypesForPromotion(P, DL, CandidateTys, HaveCommonEltTy,
+                                      CommonEltTy, HaveVecPtrTy,
+                                      HaveCommonVecPtrTy, CommonVecPtrTy);
 }
 
 /// Test whether a slice of an alloca is valid for integer widening.
diff --git a/llvm/test/Transforms/SROA/vector-promotion.ll b/llvm/test/Transforms/SROA/vector-promotion.ll
index 9643a51064f049..00a7beee12b66e 100644
--- a/llvm/test/Transforms/SROA/vector-promotion.ll
+++ b/llvm/test/Transforms/SROA/vector-promotion.ll
@@ -1227,6 +1227,44 @@ define void @swap-15bytes(ptr %x, ptr %y) {
   ret void
 }
 
+define <4 x i32> @ptrLoadStoreTys(ptr %init, i32 %val2) {
+; CHECK-LABEL: @ptrLoadStoreTys(
+; CHECK-NEXT:    [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[VAL0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast i64 [[TMP1]] to <2 x i32>
+; CHECK-NEXT:    [[OBJ_0_VEC_EXPAND:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[OBJ_0_VECBLEND:%.*]] = select <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i32> [[OBJ_0_VEC_EXPAND]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[OBJ_8_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_0_VECBLEND]], i32 [[VAL2:%.*]], i32 2
+; CHECK-NEXT:    [[OBJ_12_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_8_VEC_INSERT]], i32 131072, i32 3
+; CHECK-NEXT:    ret <4 x i32> [[OBJ_12_VEC_INSERT]]
+;
+; DEBUG-LABEL: @ptrLoadStoreTys(
+; DEBUG-NEXT:    [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8, !dbg [[DBG492:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr [[VAL0]], metadata [[META487:![0-9]+]], metadata !DIExpression()), !dbg [[DBG492]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META488:![0-9]+]], metadata !DIExpression()), !dbg [[DBG493:![0-9]+]]
+; DEBUG-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[VAL0]] to i64, !dbg [[DBG494:![0-9]+]]
+; DEBUG-NEXT:    [[TMP2:%.*]] = bitcast i64 [[TMP1]] to <2 x i32>, !dbg [[DBG494]]
+; DEBUG-NEXT:    [[OBJ_0_VEC_EXPAND:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>, !dbg [[DBG494]]
+; DEBUG-NEXT:    [[OBJ_0_VECBLEND:%.*]] = select <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i32> [[OBJ_0_VEC_EXPAND]], <4 x i32> zeroinitializer, !dbg [[DBG494]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META489:![0-9]+]], metadata !DIExpression()), !dbg [[DBG495:![0-9]+]]
+; DEBUG-NEXT:    [[OBJ_8_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_0_VECBLEND]], i32 [[VAL2:%.*]], i32 2, !dbg [[DBG496:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META490:![0-9]+]], metadata !DIExpression()), !dbg [[DBG497:![0-9]+]]
+; DEBUG-NEXT:    [[OBJ_12_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_8_VEC_INSERT]], i32 131072, i32 3, !dbg [[DBG498:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata <4 x i32> [[OBJ_12_VEC_INSERT]], metadata [[META491:![0-9]+]], metadata !DIExpression()), !dbg [[DBG499:![0-9]+]]
+; DEBUG-NEXT:    ret <4 x i32> [[OBJ_12_VEC_INSERT]], !dbg [[DBG500:![0-9]+]]
+;
+  %val0 = load ptr, ptr %init, align 8
+  %obj = alloca <4 x i32>, align 16
+  store <4 x i32> zeroinitializer, ptr %obj, align 16
+  store ptr %val0, ptr %obj, align 8
+  %ptr2 = getelementptr inbounds i8, ptr %obj, i64 8
+  store i32 %val2, ptr %ptr2, align 4
+  %ptr3 = getelementptr inbounds i8, ptr %obj, i64 12
+  store i32 131072, ptr %ptr3, align 4
+  %sroaval = load <4 x i32>, ptr %obj, align 16
+  ret <4 x i32> %sroaval
+}
+
 declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
 declare void @llvm.lifetime.end.p0(i64, ptr)
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:

``````````

</details>


https://github.com/llvm/llvm-project/pull/77678


More information about the llvm-commits mailing list