[llvm] [VectorCombine] Enable transform 'scalarizeLoadExtract' for scalable vector types (PR #65443)

Ben Shi via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 6 20:03:49 PDT 2023


https://github.com/benshi001 updated https://github.com/llvm/llvm-project/pull/65443:

>From 74a625c836b060558b8589e81914937948a09983 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Fri, 1 Sep 2023 14:43:42 +0800
Subject: [PATCH 1/2] [VectorCombine][test] Supplement tests of the
 load-extractelement sequence

The newly added tests are all about scalable vector types.
---
 .../load-extractelement-scalarization.ll      | 151 ++++++++++++++++++
 1 file changed, 151 insertions(+)

diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 1ed1476f6e2345..7df4f49e095c96 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -13,6 +13,17 @@ define i32 @load_extract_idx_0(ptr %x) {
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_0(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_0(
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i32 0
+  ret i32 %r
+}
+
 ; If the original load had a smaller alignment than the scalar type, the
 ; smaller alignment should be used.
 define i32 @load_extract_idx_0_small_alignment(ptr %x) {
@@ -48,6 +59,17 @@ define i32 @load_extract_idx_2(ptr %x) {
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_2(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_2(
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i32 2
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_3(ptr %x) {
 ; CHECK-LABEL: @load_extract_idx_3(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i32 3
@@ -72,6 +94,17 @@ define i32 @load_extract_idx_4(ptr %x) {
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_4(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_4(
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 4
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i32 4
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64(
 ; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
@@ -104,6 +137,25 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_assume(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
+; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    call void @maythrow()
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %cmp = icmp ult i64 %idx, 4
+  call void @llvm.assume(i1 %cmp)
+  %lv = load <vscale x 4 x i32>, ptr %x
+  call void @maythrow()
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+  ret i32 %r
+}
+
 declare i1 @cond()
 
 define i32 @load_extract_idx_var_i64_known_valid_by_assume_in_dominating_block(ptr %x, i64 %idx, i1 %c.1) {
@@ -213,6 +265,45 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_0(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 5
+; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %cmp = icmp ult i64 %idx, 5
+  call void @llvm.assume(i1 %cmp)
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+  ret i32 %r
+}
+
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_1(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VS:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VM:%.*]] = mul i64 [[VS]], 4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], [[VM]]
+; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %vs = call i64 @llvm.vscale.i64()
+  %vm = mul i64 %vs, 4
+  %cmp = icmp ult i64 %idx, %vm
+  call void @llvm.assume(i1 %cmp)
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+  ret i32 %r
+}
+
+declare i64 @llvm.vscale.i64()
 declare void @llvm.assume(i1)
 
 define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
@@ -230,6 +321,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = and i64 %idx, 3
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64_known_valid_by_and_noundef(ptr %x, i64 noundef %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and_noundef(
 ; CHECK-NEXT:  entry:
@@ -260,6 +366,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_and(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 4
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = and i64 %idx, 4
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
@@ -275,6 +396,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = urem i64 %idx, 4
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64_known_valid_by_urem_noundef(ptr %x, i64 noundef %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem_noundef(
 ; CHECK-NEXT:  entry:
@@ -305,6 +441,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_urem(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 5
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = urem i64 %idx, 5
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i32(ptr %x, i32 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i32(
 ; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16

>From db60c013379f33c5749d2ed6448d27b1c7487779 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Thu, 7 Sep 2023 10:57:45 +0800
Subject: [PATCH 2/2] [VectorCombine] Enable transform 'scalarizeLoadExtract'
 for scalable vector types

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check that the index is less than the minimum number of elements
is located at lines 1175~1180: 'scalarizeLoadExtract' calls 'canScalarizeAccess'
and uses the returned result to decide whether this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index is checked:
1. for a fixed vector type, it must be less than the number of elements.
2. for a scalable vector type, it must be less than the minimum number of elements.

Otherwise 'canScalarizeAccess' reports the access as unsafe and this transform
is not applied.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 25 +++++++++----------
 .../load-extractelement-scalarization.ll      | 12 ++++-----
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2..4f95eaba8de7bd 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1134,14 +1134,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *FixedVT = cast<FixedVectorType>(I.getType());
+  auto *VecTy = cast<VectorType>(I.getType());
   auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
-  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
     return false;
 
   InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
                           LI->getPointerAddressSpace());
   InstructionCost ScalarizedCost = 0;
 
@@ -1172,7 +1172,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       LastCheckedInst = UI;
     }
 
-    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+    auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
     if (!ScalarIdx.isSafe()) {
       // TODO: Freeze index if it is safe to do so.
       ScalarIdx.discard();
@@ -1182,12 +1182,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     OriginalCost +=
-        TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+        TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
                                Index ? Index->getZExtValue() : -1);
     ScalarizedCost +=
-        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+        TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
                             Align(1), LI->getPointerAddressSpace());
-    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+    ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
   }
 
   if (ScalarizedCost >= OriginalCost)
@@ -1200,12 +1200,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
     Value *Idx = EI->getOperand(1);
     Value *GEP =
-        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+        Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
-        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+        VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
 
     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
-        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+        LI->getAlign(), VecTy->getElementType(), Idx, DL);
     NewLoad->setAlignment(ScalarOpAlignment);
 
     replaceValue(*EI, *NewLoad);
@@ -1727,9 +1727,6 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= widenSubvectorLoad(I);
         break;
-      case Instruction::Load:
-        MadeChange |= scalarizeLoadExtract(I);
-        break;
       default:
         break;
       }
@@ -1743,6 +1740,8 @@ bool VectorCombine::run() {
     if (Opcode == Instruction::Store)
       MadeChange |= foldSingleElementStore(I);
 
+    if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+      MadeChange |= scalarizeLoadExtract(I);
 
     // If this is an early pipeline invocation of this pass, we are done.
     if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7df4f49e095c96..c7e5979aa9e7bd 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
 
 define i32 @vscale_load_extract_idx_0(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
 
 define i32 @vscale_load_extract_idx_2(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
 ; CHECK-NEXT:    call void @maythrow()
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:



More information about the llvm-commits mailing list