[llvm] [VectorCombine] Enable transform 'scalarizeLoadExtract' for non constant indexes (PR #65445)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 13 04:08:34 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
<details>
<summary>Changes</summary>
Enable the transform if a non-constant index is guaranteed to be safe
(in bounds) via a UREM/AND.
This PR is stacked on https://github.com/llvm/llvm-project/pull/65443
--
Full diff: https://github.com/llvm/llvm-project/pull/65445.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+30-22)
- (modified) llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll (+36-27)
<pre>
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2e..830804ddd3b8024 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -969,7 +970,11 @@ class ScalarizationResult {
public:
ScalarizationResult(const ScalarizationResult &Other) = default;
~ScalarizationResult() {
- assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+ // The object may be copied to another scope if it is in state
+ // StatusTy::SafeWithFreeze.
+ if (Status != StatusTy::SafeWithFreeze)
+ assert(!ToFreeze &&
+ "freeze() or discard() not called with ToFreeze being set");
}
static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
@@ -1134,19 +1139,20 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!match(&I, m_Load(m_Value(Ptr))))
return false;
- auto *FixedVT = cast<FixedVectorType>(I.getType());
+ auto *VecTy = cast<VectorType>(I.getType());
auto *LI = cast<LoadInst>(&I);
const DataLayout &DL = I.getModule()->getDataLayout();
- if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+ if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
return false;
InstructionCost OriginalCost =
- TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+ TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace());
InstructionCost ScalarizedCost = 0;
Instruction *LastCheckedInst = LI;
unsigned NumInstChecked = 0;
+ DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
// Check if all users of the load are extracts with no memory modifications
// between the load and the extract. Compute the cost of both the original
// code and the scalarized version.
@@ -1155,9 +1161,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!UI || UI->getParent() != LI->getParent())
return false;
- if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
- return false;
-
// Check if any instruction between the load and the extract may modify
// memory.
if (LastCheckedInst->comesBefore(UI)) {
@@ -1172,22 +1175,23 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
LastCheckedInst = UI;
}
- auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
- if (!ScalarIdx.isSafe()) {
- // TODO: Freeze index if it is safe to do so.
- ScalarIdx.discard();
+ auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
+ if (ScalarIdx.isUnsafe()) {
return false;
+ } else if (ScalarIdx.isSafeWithFreeze()) {
+ NeedFreeze.insert(std::make_pair(UI, ScalarIdx));
+ ScalarIdx.discard();
}
auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
OriginalCost +=
- TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+ TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
Index ? Index->getZExtValue() : -1);
ScalarizedCost +=
- TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+ TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
Align(1), LI->getPointerAddressSpace());
- ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+ ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
}
if (ScalarizedCost >= OriginalCost)
@@ -1196,16 +1200,21 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
// Replace extracts with narrow scalar loads.
for (User *U : LI->users()) {
auto *EI = cast<ExtractElementInst>(U);
- Builder.SetInsertPoint(EI);
-
Value *Idx = EI->getOperand(1);
+
+ // Insert 'freeze' for poison indexes.
+ DenseMap<ExtractElementInst *, ScalarizationResult>::iterator It;
+ if ((It = NeedFreeze.find(EI)) != NeedFreeze.end())
+ It->second.freeze(Builder, *cast<Instruction>(Idx));
+
+ Builder.SetInsertPoint(EI);
Value *GEP =
- Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+ Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
- FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+ VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
Align ScalarOpAlignment = computeAlignmentAfterScalarization(
- LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+ LI->getAlign(), VecTy->getElementType(), Idx, DL);
NewLoad->setAlignment(ScalarOpAlignment);
replaceValue(*EI, *NewLoad);
@@ -1727,9 +1736,6 @@ bool VectorCombine::run() {
case Instruction::ShuffleVector:
MadeChange |= widenSubvectorLoad(I);
break;
- case Instruction::Load:
- MadeChange |= scalarizeLoadExtract(I);
- break;
default:
break;
}
@@ -1743,6 +1749,8 @@ bool VectorCombine::run() {
if (Opcode == Instruction::Store)
MadeChange |= foldSingleElementStore(I);
+ if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+ MadeChange |= scalarizeLoadExtract(I);
// If this is an early pipeline invocation of this pass, we are done.
if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7df4f49e095c96c..42b3f9afeb56ee8 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
define i32 @vscale_load_extract_idx_0(ptr %x) {
; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 16
; CHECK-NEXT: ret i32 [[R]]
;
%lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
define i32 @vscale_load_extract_idx_2(ptr %x) {
; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 8
; CHECK-NEXT: ret i32 [[R]]
;
%lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
; CHECK-NEXT: call void @maythrow()
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: ret i32 [[R]]
;
entry:
@@ -309,9 +309,10 @@ declare void @llvm.assume(i1)
define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: ret i32 [[R]]
;
entry:
@@ -324,9 +325,10 @@ entry:
define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: ret i32 [[R]]
;
entry:
@@ -384,9 +386,10 @@ entry:
define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: ret i32 [[R]]
;
entry:
@@ -399,9 +402,10 @@ entry:
define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: ret i32 [[R]]
;
entry:
@@ -789,11 +793,14 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_only_first
define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(ptr %x, i64 %idx.0, i64 %idx.1) {
; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(
-; CHECK-NEXT: [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
-; CHECK-NEXT: [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT: [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT: [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT: [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT: [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT: [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
+; CHECK-NEXT: [[IDX_1_FROZEN:%.*]] = freeze i64 [[IDX_1:%.*]]
+; CHECK-NEXT: [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1_FROZEN]], 15
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT: [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT: [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
; CHECK-NEXT: [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
; CHECK-NEXT: ret i32 [[RES]]
;
@@ -809,11 +816,13 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_
define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(ptr %x, i64 %idx.0, i64 noundef %idx.1) {
; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(
-; CHECK-NEXT: [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
+; CHECK-NEXT: [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT: [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
; CHECK-NEXT: [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT: [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT: [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT: [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT: [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT: [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
; CHECK-NEXT: [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
; CHECK-NEXT: ret i32 [[RES]]
;
</pre>
</details>
https://github.com/llvm/llvm-project/pull/65445
More information about the llvm-commits
mailing list