[llvm] [VectorCombine] Enable transform 'scalarizeLoadExtract' for non constant indexes (PR #65445)

Ben Shi via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 6 21:35:02 PDT 2023


https://github.com/benshi001 updated https://github.com/llvm/llvm-project/pull/65445:

>From 74a625c836b060558b8589e81914937948a09983 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Fri, 1 Sep 2023 14:43:42 +0800
Subject: [PATCH 1/3] [VectorCombine][test] Supplement tests of the
 load-extractelement sequence

All of the newly added tests cover scalable vector types.
---
 .../load-extractelement-scalarization.ll      | 151 ++++++++++++++++++
 1 file changed, 151 insertions(+)

diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 1ed1476f6e2345e..7df4f49e095c96c 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -13,6 +13,17 @@ define i32 @load_extract_idx_0(ptr %x) {
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_0(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_0(
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i32 0
+  ret i32 %r
+}
+
 ; If the original load had a smaller alignment than the scalar type, the
 ; smaller alignment should be used.
 define i32 @load_extract_idx_0_small_alignment(ptr %x) {
@@ -48,6 +59,17 @@ define i32 @load_extract_idx_2(ptr %x) {
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_2(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_2(
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i32 2
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_3(ptr %x) {
 ; CHECK-LABEL: @load_extract_idx_3(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i32 3
@@ -72,6 +94,17 @@ define i32 @load_extract_idx_4(ptr %x) {
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_4(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_4(
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 4
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i32 4
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64(
 ; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
@@ -104,6 +137,25 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_assume(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
+; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    call void @maythrow()
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %cmp = icmp ult i64 %idx, 4
+  call void @llvm.assume(i1 %cmp)
+  %lv = load <vscale x 4 x i32>, ptr %x
+  call void @maythrow()
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+  ret i32 %r
+}
+
 declare i1 @cond()
 
 define i32 @load_extract_idx_var_i64_known_valid_by_assume_in_dominating_block(ptr %x, i64 %idx, i1 %c.1) {
@@ -213,6 +265,45 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_0(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 5
+; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %cmp = icmp ult i64 %idx, 5
+  call void @llvm.assume(i1 %cmp)
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+  ret i32 %r
+}
+
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_1(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_assume_1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VS:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VM:%.*]] = mul i64 [[VS]], 4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], [[VM]]
+; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %vs = call i64 @llvm.vscale.i64()
+  %vm = mul i64 %vs, 4
+  %cmp = icmp ult i64 %idx, %vm
+  call void @llvm.assume(i1 %cmp)
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+  ret i32 %r
+}
+
+declare i64 @llvm.vscale.i64()
 declare void @llvm.assume(i1)
 
 define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
@@ -230,6 +321,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = and i64 %idx, 3
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64_known_valid_by_and_noundef(ptr %x, i64 noundef %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and_noundef(
 ; CHECK-NEXT:  entry:
@@ -260,6 +366,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_and(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 4
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = and i64 %idx, 4
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
@@ -275,6 +396,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = urem i64 %idx, 4
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i64_known_valid_by_urem_noundef(ptr %x, i64 noundef %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem_noundef(
 ; CHECK-NEXT:  entry:
@@ -305,6 +441,21 @@ entry:
   ret i32 %r
 }
 
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_urem(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 5
+; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  %idx.clamped = urem i64 %idx, 5
+  %lv = load <vscale x 4 x i32>, ptr %x
+  %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_var_i32(ptr %x, i32 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i32(
 ; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16

>From db60c013379f33c5749d2ed6448d27b1c7487779 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Thu, 7 Sep 2023 10:57:45 +0800
Subject: [PATCH 2/3] [VectorCombine] Enable transform 'scalarizeLoadExtract'
 for scalable vector types

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check that the index is less than the minimum number of elements is
located at lines 1175~1180: 'scalarizeLoadExtract' calls 'canScalarizeAccess'
and uses the returned result to decide whether this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index is checked
against the vector type:
1. for a fixed vector type, it must be less than the number of elements;
2. for a scalable vector type, it must be less than the minimum number of
   elements.

Otherwise 'canScalarizeAccess' will return unsafe and this transform will be
prevented.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 25 +++++++++----------
 .../load-extractelement-scalarization.ll      | 12 ++++-----
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2e..4f95eaba8de7bd2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1134,14 +1134,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *FixedVT = cast<FixedVectorType>(I.getType());
+  auto *VecTy = cast<VectorType>(I.getType());
   auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
-  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
     return false;
 
   InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
                           LI->getPointerAddressSpace());
   InstructionCost ScalarizedCost = 0;
 
@@ -1172,7 +1172,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       LastCheckedInst = UI;
     }
 
-    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+    auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
     if (!ScalarIdx.isSafe()) {
       // TODO: Freeze index if it is safe to do so.
       ScalarIdx.discard();
@@ -1182,12 +1182,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     OriginalCost +=
-        TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+        TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
                                Index ? Index->getZExtValue() : -1);
     ScalarizedCost +=
-        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+        TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
                             Align(1), LI->getPointerAddressSpace());
-    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+    ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
   }
 
   if (ScalarizedCost >= OriginalCost)
@@ -1200,12 +1200,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
     Value *Idx = EI->getOperand(1);
     Value *GEP =
-        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+        Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
-        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+        VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
 
     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
-        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+        LI->getAlign(), VecTy->getElementType(), Idx, DL);
     NewLoad->setAlignment(ScalarOpAlignment);
 
     replaceValue(*EI, *NewLoad);
@@ -1727,9 +1727,6 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= widenSubvectorLoad(I);
         break;
-      case Instruction::Load:
-        MadeChange |= scalarizeLoadExtract(I);
-        break;
       default:
         break;
       }
@@ -1743,6 +1740,8 @@ bool VectorCombine::run() {
     if (Opcode == Instruction::Store)
       MadeChange |= foldSingleElementStore(I);
 
+    if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+      MadeChange |= scalarizeLoadExtract(I);
 
     // If this is an early pipeline invocation of this pass, we are done.
     if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7df4f49e095c96c..c7e5979aa9e7bd9 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
 
 define i32 @vscale_load_extract_idx_0(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
 
 define i32 @vscale_load_extract_idx_2(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
 ; CHECK-NEXT:    call void @maythrow()
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:

>From 75802ac068855cc9904c66c555a2a29390f7ee59 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Thu, 7 Sep 2023 12:32:02 +0800
Subject: [PATCH 3/3] [VectorCombine] Enable transform 'scalarizeLoadExtract'
 for non constant indexes

Enable the transform if a non-constant index is guaranteed to be safe
via a UREM/AND, inserting a 'freeze' of the index where needed.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 27 ++++++----
 .../load-extractelement-scalarization.ll      | 51 +++++++++++--------
 2 files changed, 48 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4f95eaba8de7bd2..1398f9e380e4813 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -969,7 +970,11 @@ class ScalarizationResult {
 public:
   ScalarizationResult(const ScalarizationResult &Other) = default;
   ~ScalarizationResult() {
-    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+    // The object may be copied to another scope if it is in state
+    // StatusTy::SafeWithFreeze.
+    if (Status != StatusTy::SafeWithFreeze)
+      assert(!ToFreeze &&
+             "freeze() or discard() not called with ToFreeze being set");
   }
 
   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
@@ -1147,6 +1152,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
   // Check if all users of the load are extracts with no memory modifications
   // between the load and the extract. Compute the cost of both the original
   // code and the scalarized version.
@@ -1155,9 +1161,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     if (!UI || UI->getParent() != LI->getParent())
       return false;
 
-    if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
-      return false;
-
     // Check if any instruction between the load and the extract may modify
     // memory.
     if (LastCheckedInst->comesBefore(UI)) {
@@ -1173,10 +1176,11 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     }
 
     auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
-    if (!ScalarIdx.isSafe()) {
-      // TODO: Freeze index if it is safe to do so.
-      ScalarIdx.discard();
+    if (ScalarIdx.isUnsafe()) {
       return false;
+    } else if (ScalarIdx.isSafeWithFreeze()) {
+      NeedFreeze.insert(std::make_pair(UI, ScalarIdx));
+      ScalarIdx.discard();
     }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
@@ -1196,9 +1200,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   // Replace extracts with narrow scalar loads.
   for (User *U : LI->users()) {
     auto *EI = cast<ExtractElementInst>(U);
-    Builder.SetInsertPoint(EI);
-
     Value *Idx = EI->getOperand(1);
+
+    // Insert 'freeze' if needed.
+    DenseMap<ExtractElementInst *, ScalarizationResult>::iterator It;
+    if ((It = NeedFreeze.find(EI)) != NeedFreeze.end())
+      It->second.freeze(Builder, *cast<Instruction>(Idx));
+
+    Builder.SetInsertPoint(EI);
     Value *GEP =
         Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index c7e5979aa9e7bd9..42b3f9afeb56ee8 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -309,9 +309,10 @@ declare void @llvm.assume(i1)
 define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -324,9 +325,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -384,9 +386,10 @@ entry:
 define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -399,9 +402,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -789,11 +793,14 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_only_first
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(ptr %x, i64 %idx.0, i64 %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
-; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
+; CHECK-NEXT:    [[IDX_1_FROZEN:%.*]] = freeze i64 [[IDX_1:%.*]]
+; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1_FROZEN]], 15
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
@@ -809,11 +816,13 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(ptr %x, i64 %idx.0, i64 noundef %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
 ; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;



More information about the llvm-commits mailing list