[llvm] [SROA] Only try additional vector type candidates when needed (PR #77678)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 16:13:04 PST 2024


https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/77678

>From 082325213f3a4ad1585d8b9548ed2810f27097ff Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Wed, 10 Jan 2024 08:11:24 -0800
Subject: [PATCH] [SROA] Only try additional vector type candidates when needed

Change-Id: I7e26c52643c98137fa4443ac95522c43f6481882
---
 llvm/lib/Transforms/Scalar/SROA.cpp           |  7 ++
 llvm/test/Transforms/SROA/vector-promotion.ll | 80 +++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 10c25e2a0322971..bdbaf4f55c96d08 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -2319,6 +2319,12 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     if (S.beginOffset() == P.beginOffset() && S.endOffset() == P.endOffset())
       CheckCandidateType(Ty);
   }
+
+  if (auto *VTy = checkVectorTypesForPromotion(
+          P, DL, CandidateTys, HaveCommonEltTy, CommonEltTy, HaveVecPtrTy,
+          HaveCommonVecPtrTy, CommonVecPtrTy))
+    return VTy;
+
   // Consider additional vector types where the element type size is a
   // multiple of load/store element size.
   for (Type *Ty : LoadStoreTys) {
@@ -2328,6 +2334,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     // Make a copy of CandidateTys and iterate through it, because we might
     // append to CandidateTys in the loop.
     SmallVector<VectorType *, 4> CandidateTysCopy = CandidateTys;
+    CandidateTys.clear();
     for (VectorType *&VTy : CandidateTysCopy) {
       unsigned VectorSize = DL.getTypeSizeInBits(VTy).getFixedValue();
       unsigned ElementSize =
diff --git a/llvm/test/Transforms/SROA/vector-promotion.ll b/llvm/test/Transforms/SROA/vector-promotion.ll
index 9643a51064f049f..77bf0b1cb1cb75b 100644
--- a/llvm/test/Transforms/SROA/vector-promotion.ll
+++ b/llvm/test/Transforms/SROA/vector-promotion.ll
@@ -1227,6 +1227,86 @@ define void @swap-15bytes(ptr %x, ptr %y) {
   ret void
 }
 
+define <4 x i32> @ptrLoadStoreTys(ptr %init, i32 %val2) {
+; CHECK-LABEL: @ptrLoadStoreTys(
+; CHECK-NEXT:    [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[VAL0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast i64 [[TMP1]] to <2 x i32>
+; CHECK-NEXT:    [[OBJ_0_VEC_EXPAND:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[OBJ_0_VECBLEND:%.*]] = select <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i32> [[OBJ_0_VEC_EXPAND]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[OBJ_8_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_0_VECBLEND]], i32 [[VAL2:%.*]], i32 2
+; CHECK-NEXT:    [[OBJ_12_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_8_VEC_INSERT]], i32 131072, i32 3
+; CHECK-NEXT:    ret <4 x i32> [[OBJ_12_VEC_INSERT]]
+;
+; DEBUG-LABEL: @ptrLoadStoreTys(
+; DEBUG-NEXT:    [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8, !dbg [[DBG492:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr [[VAL0]], metadata [[META487:![0-9]+]], metadata !DIExpression()), !dbg [[DBG492]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META488:![0-9]+]], metadata !DIExpression()), !dbg [[DBG493:![0-9]+]]
+; DEBUG-NEXT:    [[TMP1:%.*]] = ptrtoint ptr [[VAL0]] to i64, !dbg [[DBG494:![0-9]+]]
+; DEBUG-NEXT:    [[TMP2:%.*]] = bitcast i64 [[TMP1]] to <2 x i32>, !dbg [[DBG494]]
+; DEBUG-NEXT:    [[OBJ_0_VEC_EXPAND:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>, !dbg [[DBG494]]
+; DEBUG-NEXT:    [[OBJ_0_VECBLEND:%.*]] = select <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i32> [[OBJ_0_VEC_EXPAND]], <4 x i32> zeroinitializer, !dbg [[DBG494]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META489:![0-9]+]], metadata !DIExpression()), !dbg [[DBG495:![0-9]+]]
+; DEBUG-NEXT:    [[OBJ_8_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_0_VECBLEND]], i32 [[VAL2:%.*]], i32 2, !dbg [[DBG496:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META490:![0-9]+]], metadata !DIExpression()), !dbg [[DBG497:![0-9]+]]
+; DEBUG-NEXT:    [[OBJ_12_VEC_INSERT:%.*]] = insertelement <4 x i32> [[OBJ_8_VEC_INSERT]], i32 131072, i32 3, !dbg [[DBG498:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata <4 x i32> [[OBJ_12_VEC_INSERT]], metadata [[META491:![0-9]+]], metadata !DIExpression()), !dbg [[DBG499:![0-9]+]]
+; DEBUG-NEXT:    ret <4 x i32> [[OBJ_12_VEC_INSERT]], !dbg [[DBG500:![0-9]+]]
+;
+  %val0 = load ptr, ptr %init, align 8 ; pointer value that is stored into the vector slot below
+  %obj = alloca <4 x i32>, align 16 ; per the expected output above, SROA fully promotes this slot to <4 x i32> values
+  store <4 x i32> zeroinitializer, ptr %obj, align 16 ; initialize all four i32 lanes
+  store ptr %val0, ptr %obj, align 8 ; covers bytes 0-7 (lanes 0-1); expected as ptrtoint to i64 + bitcast to <2 x i32>
+  %ptr2 = getelementptr inbounds i8, ptr %obj, i64 8
+  store i32 %val2, ptr %ptr2, align 4 ; byte offset 8 = lane 2
+  %ptr3 = getelementptr inbounds i8, ptr %obj, i64 12
+  store i32 131072, ptr %ptr3, align 4 ; byte offset 12 = lane 3
+  %sroaval = load <4 x i32>, ptr %obj, align 16 ; reload the whole vector
+  ret <4 x i32> %sroaval
+}
+
+define <4 x float> @ptrLoadStoreTysFloat(ptr %init, float %val2) {
+; CHECK-LABEL: @ptrLoadStoreTysFloat(
+; CHECK-NEXT:    [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8
+; CHECK-NEXT:    [[OBJ:%.*]] = alloca <4 x float>, align 16
+; CHECK-NEXT:    store <4 x float> zeroinitializer, ptr [[OBJ]], align 16
+; CHECK-NEXT:    store ptr [[VAL0]], ptr [[OBJ]], align 16
+; CHECK-NEXT:    [[OBJ_8_PTR2_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 8
+; CHECK-NEXT:    store float [[VAL2:%.*]], ptr [[OBJ_8_PTR2_SROA_IDX]], align 8
+; CHECK-NEXT:    [[OBJ_12_PTR3_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 12
+; CHECK-NEXT:    store float 1.310720e+05, ptr [[OBJ_12_PTR3_SROA_IDX]], align 4
+; CHECK-NEXT:    [[OBJ_0_SROAVAL:%.*]] = load <4 x float>, ptr [[OBJ]], align 16
+; CHECK-NEXT:    ret <4 x float> [[OBJ_0_SROAVAL]]
+;
+; DEBUG-LABEL: @ptrLoadStoreTysFloat(
+; DEBUG-NEXT:    [[VAL0:%.*]] = load ptr, ptr [[INIT:%.*]], align 8, !dbg [[DBG508:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr [[VAL0]], metadata [[META503:![0-9]+]], metadata !DIExpression()), !dbg [[DBG508]]
+; DEBUG-NEXT:    [[OBJ:%.*]] = alloca <4 x float>, align 16, !dbg [[DBG509:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr [[OBJ]], metadata [[META504:![0-9]+]], metadata !DIExpression()), !dbg [[DBG509]]
+; DEBUG-NEXT:    store <4 x float> zeroinitializer, ptr [[OBJ]], align 16, !dbg [[DBG510:![0-9]+]]
+; DEBUG-NEXT:    store ptr [[VAL0]], ptr [[OBJ]], align 16, !dbg [[DBG511:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META505:![0-9]+]], metadata !DIExpression()), !dbg [[DBG512:![0-9]+]]
+; DEBUG-NEXT:    [[OBJ_8_PTR2_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 8, !dbg [[DBG513:![0-9]+]]
+; DEBUG-NEXT:    store float [[VAL2:%.*]], ptr [[OBJ_8_PTR2_SROA_IDX]], align 8, !dbg [[DBG513]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata ptr undef, metadata [[META506:![0-9]+]], metadata !DIExpression()), !dbg [[DBG514:![0-9]+]]
+; DEBUG-NEXT:    [[OBJ_12_PTR3_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[OBJ]], i64 12, !dbg [[DBG515:![0-9]+]]
+; DEBUG-NEXT:    store float 1.310720e+05, ptr [[OBJ_12_PTR3_SROA_IDX]], align 4, !dbg [[DBG515]]
+; DEBUG-NEXT:    [[OBJ_0_SROAVAL:%.*]] = load <4 x float>, ptr [[OBJ]], align 16, !dbg [[DBG516:![0-9]+]]
+; DEBUG-NEXT:    call void @llvm.dbg.value(metadata <4 x float> [[OBJ_0_SROAVAL]], metadata [[META507:![0-9]+]], metadata !DIExpression()), !dbg [[DBG516]]
+; DEBUG-NEXT:    ret <4 x float> [[OBJ_0_SROAVAL]], !dbg [[DBG517:![0-9]+]]
+;
+  %val0 = load ptr, ptr %init, align 8 ; pointer value that is stored into the vector slot below
+  %obj = alloca <4 x float>, align 16 ; float-lane variant of @ptrLoadStoreTys; per the expected output above this alloca is NOT promoted
+  store <4 x float> zeroinitializer, ptr %obj, align 16 ; initialize all four float lanes
+  store ptr %val0, ptr %obj, align 8 ; covers bytes 0-7 (float lanes 0-1)
+  %ptr2 = getelementptr inbounds i8, ptr %obj, i64 8
+  store float %val2, ptr %ptr2, align 4 ; byte offset 8 = lane 2
+  %ptr3 = getelementptr inbounds i8, ptr %obj, i64 12
+  store float 131072.0, ptr %ptr3, align 4 ; byte offset 12 = lane 3
+  %sroaval = load <4 x float>, ptr %obj, align 16 ; reload the whole vector
+  ret <4 x float> %sroaval
+}
+
 declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
 declare void @llvm.lifetime.end.p0(i64, ptr)
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:



More information about the llvm-commits mailing list