[llvm] [VectorCombine] Enable transform 'scalarizeLoadExtract' for non constant indexes (PR #65445)

Ben Shi via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 13 04:07:35 PDT 2023


https://github.com/benshi001 updated https://github.com/llvm/llvm-project/pull/65445:

>From 4063c28369bdaa570e6fd8bb64243e5d9eed531e Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Thu, 7 Sep 2023 10:57:45 +0800
Subject: [PATCH 1/2] [VectorCombine] Enable transform 'scalarizeLoadExtract'
 for scalable vector types

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check that the index is less than the minimum number of elements
is located at lines 1175-1180: 'scalarizeLoadExtract' calls 'canScalarizeAccess'
and checks the returned result to determine whether this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index is checked to
ensure that it is:
1. less than the number of elements, for a fixed vector type; or
2. less than the minimum number of elements, for a scalable vector type.

Otherwise, 'canScalarizeAccess' returns unsafe and the transform is not
applied.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 25 +++++++++----------
 .../load-extractelement-scalarization.ll      | 12 ++++-----
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2e..4f95eaba8de7bd2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1134,14 +1134,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *FixedVT = cast<FixedVectorType>(I.getType());
+  auto *VecTy = cast<VectorType>(I.getType());
   auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
-  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
     return false;
 
   InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
                           LI->getPointerAddressSpace());
   InstructionCost ScalarizedCost = 0;
 
@@ -1172,7 +1172,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       LastCheckedInst = UI;
     }
 
-    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+    auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
     if (!ScalarIdx.isSafe()) {
       // TODO: Freeze index if it is safe to do so.
       ScalarIdx.discard();
@@ -1182,12 +1182,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     OriginalCost +=
-        TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+        TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
                                Index ? Index->getZExtValue() : -1);
     ScalarizedCost +=
-        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+        TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
                             Align(1), LI->getPointerAddressSpace());
-    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+    ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
   }
 
   if (ScalarizedCost >= OriginalCost)
@@ -1200,12 +1200,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
     Value *Idx = EI->getOperand(1);
     Value *GEP =
-        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+        Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
-        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+        VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
 
     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
-        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+        LI->getAlign(), VecTy->getElementType(), Idx, DL);
     NewLoad->setAlignment(ScalarOpAlignment);
 
     replaceValue(*EI, *NewLoad);
@@ -1727,9 +1727,6 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= widenSubvectorLoad(I);
         break;
-      case Instruction::Load:
-        MadeChange |= scalarizeLoadExtract(I);
-        break;
       default:
         break;
       }
@@ -1743,6 +1740,8 @@ bool VectorCombine::run() {
     if (Opcode == Instruction::Store)
       MadeChange |= foldSingleElementStore(I);
 
+    if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+      MadeChange |= scalarizeLoadExtract(I);
 
     // If this is an early pipeline invocation of this pass, we are done.
     if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7df4f49e095c96c..c7e5979aa9e7bd9 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
 
 define i32 @vscale_load_extract_idx_0(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
 
 define i32 @vscale_load_extract_idx_2(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
 ; CHECK-NEXT:    call void @maythrow()
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:

>From 6bfabf3ebf93791f996e74cd7d5e67a67e5f1306 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Thu, 7 Sep 2023 12:32:02 +0800
Subject: [PATCH 2/2] [VectorCombine] Enable transform 'scalarizeLoadExtract'
 for non constant indexes

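Enable the transform for a non-constant index when the index is guaranteed
to be in bounds by a UREM/AND. If the index may be poison, it is frozen
before the scalarized load is created.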
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 27 ++++++----
 .../load-extractelement-scalarization.ll      | 51 +++++++++++--------
 2 files changed, 48 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4f95eaba8de7bd2..830804ddd3b8024 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -969,7 +970,11 @@ class ScalarizationResult {
 public:
   ScalarizationResult(const ScalarizationResult &Other) = default;
   ~ScalarizationResult() {
-    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+    // The object may be copied to another scope if it is in state
+    // StatusTy::SafeWithFreeze.
+    if (Status != StatusTy::SafeWithFreeze)
+      assert(!ToFreeze &&
+             "freeze() or discard() not called with ToFreeze being set");
   }
 
   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
@@ -1147,6 +1152,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
   // Check if all users of the load are extracts with no memory modifications
   // between the load and the extract. Compute the cost of both the original
   // code and the scalarized version.
@@ -1155,9 +1161,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     if (!UI || UI->getParent() != LI->getParent())
       return false;
 
-    if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
-      return false;
-
     // Check if any instruction between the load and the extract may modify
     // memory.
     if (LastCheckedInst->comesBefore(UI)) {
@@ -1173,10 +1176,11 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     }
 
     auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
-    if (!ScalarIdx.isSafe()) {
-      // TODO: Freeze index if it is safe to do so.
-      ScalarIdx.discard();
+    if (ScalarIdx.isUnsafe()) {
       return false;
+    } else if (ScalarIdx.isSafeWithFreeze()) {
+      NeedFreeze.insert(std::make_pair(UI, ScalarIdx));
+      ScalarIdx.discard();
     }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
@@ -1196,9 +1200,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   // Replace extracts with narrow scalar loads.
   for (User *U : LI->users()) {
     auto *EI = cast<ExtractElementInst>(U);
-    Builder.SetInsertPoint(EI);
-
     Value *Idx = EI->getOperand(1);
+
+    // Insert 'freeze' for poison indexes.
+    DenseMap<ExtractElementInst *, ScalarizationResult>::iterator It;
+    if ((It = NeedFreeze.find(EI)) != NeedFreeze.end())
+      It->second.freeze(Builder, *cast<Instruction>(Idx));
+
+    Builder.SetInsertPoint(EI);
     Value *GEP =
         Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index c7e5979aa9e7bd9..42b3f9afeb56ee8 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -309,9 +309,10 @@ declare void @llvm.assume(i1)
 define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -324,9 +325,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -384,9 +386,10 @@ entry:
 define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -399,9 +402,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -789,11 +793,14 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_only_first
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(ptr %x, i64 %idx.0, i64 %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
-; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
+; CHECK-NEXT:    [[IDX_1_FROZEN:%.*]] = freeze i64 [[IDX_1:%.*]]
+; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1_FROZEN]], 15
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
@@ -809,11 +816,13 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(ptr %x, i64 %idx.0, i64 noundef %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
 ; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;



More information about the llvm-commits mailing list