[llvm] [VectorCombine] Enable transform 'scalarizeLoadExtract' for scalable vector types (PR #65443)
Ben Shi via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 6 00:03:23 PDT 2023
https://github.com/benshi001 created https://github.com/llvm/llvm-project/pull/65443:
This PR is stacked on top of https://github.com/llvm/llvm-project/pull/65442.
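For context, the transform turns a wide vector load that only feeds an
extractelement into a single scalar load of the addressed lane. A minimal
before/after sketch for a scalable type (function names are illustrative;
the exact output is what the updated FileCheck lines in the second patch
verify):

  ; Before: the whole <vscale x 4 x i32> is loaded to read one lane.
  define i32 @load_extract_example(ptr %x) {
    %lv = load <vscale x 4 x i32>, ptr %x
    %r = extractelement <vscale x 4 x i32> %lv, i32 2
    ret i32 %r
  }

  ; After 'opt -passes=vector-combine': only element 2 is loaded.
  define i32 @load_extract_example(ptr %x) {
    %gep = getelementptr inbounds <vscale x 4 x i32>, ptr %x, i32 0, i32 2
    %r = load i32, ptr %gep, align 8
    ret i32 %r
  }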
From 8e7e9cece7748508cce5af2116fcf7b331216d9f Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Fri, 1 Sep 2023 14:43:42 +0800
Subject: [PATCH 1/2] [VectorCombine][test] Supplement tests of the
load-extractelement sequence
The newly added tests all cover scalable vector types.
---
.../load-extractelement-scalarization.ll | 129 ++++++++++++++++++
1 file changed, 129 insertions(+)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 1ed1476f6e2345e..946a9196064fde4 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -13,6 +13,17 @@ define i32 @load_extract_idx_0(ptr %x) {
ret i32 %r
}
+define i32 @vscale_load_extract_idx_0(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_0(
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i32 0
+ ret i32 %r
+}
+
; If the original load had a smaller alignment than the scalar type, the
; smaller alignment should be used.
define i32 @load_extract_idx_0_small_alignment(ptr %x) {
@@ -48,6 +59,17 @@ define i32 @load_extract_idx_2(ptr %x) {
ret i32 %r
}
+define i32 @vscale_load_extract_idx_2(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_2(
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i32 2
+ ret i32 %r
+}
+
define i32 @load_extract_idx_3(ptr %x) {
; CHECK-LABEL: @load_extract_idx_3(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i32 3
@@ -72,6 +94,17 @@ define i32 @load_extract_idx_4(ptr %x) {
ret i32 %r
}
+define i32 @vscale_load_extract_idx_4(ptr %x) {
+; CHECK-LABEL: @vscale_load_extract_idx_4(
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 4
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i32 4
+ ret i32 %r
+}
+
define i32 @load_extract_idx_var_i64(ptr %x, i64 %idx) {
; CHECK-LABEL: @load_extract_idx_var_i64(
; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
@@ -104,6 +137,25 @@ entry:
ret i32 %r
}
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_assume(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
+; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: call void @maythrow()
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+entry:
+ %cmp = icmp ult i64 %idx, 4
+ call void @llvm.assume(i1 %cmp)
+ %lv = load <vscale x 4 x i32>, ptr %x
+ call void @maythrow()
+ %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+ ret i32 %r
+}
+
declare i1 @cond()
define i32 @load_extract_idx_var_i64_known_valid_by_assume_in_dominating_block(ptr %x, i64 %idx, i1 %c.1) {
@@ -213,6 +265,23 @@ entry:
ret i32 %r
}
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_assume(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_assume(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 5
+; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+entry:
+ %cmp = icmp ult i64 %idx, 5
+ call void @llvm.assume(i1 %cmp)
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i64 %idx
+ ret i32 %r
+}
+
declare void @llvm.assume(i1)
define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
@@ -230,6 +299,21 @@ entry:
ret i32 %r
}
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+entry:
+ %idx.clamped = and i64 %idx, 3
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+ ret i32 %r
+}
+
define i32 @load_extract_idx_var_i64_known_valid_by_and_noundef(ptr %x, i64 noundef %idx) {
; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and_noundef(
; CHECK-NEXT: entry:
@@ -260,6 +344,21 @@ entry:
ret i32 %r
}
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_and(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_and(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 4
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+entry:
+ %idx.clamped = and i64 %idx, 4
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+ ret i32 %r
+}
+
define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
; CHECK-NEXT: entry:
@@ -275,6 +374,21 @@ entry:
ret i32 %r
}
+define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+entry:
+ %idx.clamped = urem i64 %idx, 4
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+ ret i32 %r
+}
+
define i32 @load_extract_idx_var_i64_known_valid_by_urem_noundef(ptr %x, i64 noundef %idx) {
; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem_noundef(
; CHECK-NEXT: entry:
@@ -305,6 +419,21 @@ entry:
ret i32 %r
}
+define i32 @vscale_load_extract_idx_var_i64_not_known_valid_by_urem(ptr %x, i64 %idx) {
+; CHECK-LABEL: @vscale_load_extract_idx_var_i64_not_known_valid_by_urem(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 5
+; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+entry:
+ %idx.clamped = urem i64 %idx, 5
+ %lv = load <vscale x 4 x i32>, ptr %x
+ %r = extractelement <vscale x 4 x i32> %lv, i64 %idx.clamped
+ ret i32 %r
+}
+
define i32 @load_extract_idx_var_i32(ptr %x, i32 %idx) {
; CHECK-LABEL: @load_extract_idx_var_i32(
; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
From c71b28fdacdaac94b719e209621a1ed0d7c3c284 Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Sun, 3 Sep 2023 15:11:53 +0800
Subject: [PATCH 2/2] [VectorCombine] Enable transform 'scalarizeLoadExtract'
for scalable vector types
The transform 'scalarizeLoadExtract' can be applied to scalable
vector types when the extract index is known to be less than the
minimum number of vector elements, since those lanes exist for
every value of vscale.
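To illustrate the bound (hypothetical function name, mirroring the updated
tests): for <vscale x 4 x i32> the known minimum element count is 4, so
lanes 0-3 exist for every vscale >= 1 and a constant index below 4 is
always in bounds. Index 4 only exists when vscale >= 2, which is why the
vscale_load_extract_idx_4 test keeps its load + extractelement form:

  define i32 @not_scalarized(ptr %x) {
    %lv = load <vscale x 4 x i32>, ptr %x
    ; index 4 is out of bounds when vscale == 1, so the combine must not fire
    %r = extractelement <vscale x 4 x i32> %lv, i32 4
    ret i32 %r
  }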
---
.../Transforms/Vectorize/VectorCombine.cpp | 25 +++++++++----------
.../load-extractelement-scalarization.ll | 12 ++++-----
2 files changed, 18 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2e..4f95eaba8de7bd2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1134,14 +1134,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!match(&I, m_Load(m_Value(Ptr))))
return false;
- auto *FixedVT = cast<FixedVectorType>(I.getType());
+ auto *VecTy = cast<VectorType>(I.getType());
auto *LI = cast<LoadInst>(&I);
const DataLayout &DL = I.getModule()->getDataLayout();
- if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+ if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
return false;
InstructionCost OriginalCost =
- TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+ TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace());
InstructionCost ScalarizedCost = 0;
@@ -1172,7 +1172,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
LastCheckedInst = UI;
}
- auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+ auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
if (!ScalarIdx.isSafe()) {
// TODO: Freeze index if it is safe to do so.
ScalarIdx.discard();
@@ -1182,12 +1182,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
OriginalCost +=
- TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+ TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
Index ? Index->getZExtValue() : -1);
ScalarizedCost +=
- TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+ TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
Align(1), LI->getPointerAddressSpace());
- ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+ ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
}
if (ScalarizedCost >= OriginalCost)
@@ -1200,12 +1200,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
Value *Idx = EI->getOperand(1);
Value *GEP =
- Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+ Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
- FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+ VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
Align ScalarOpAlignment = computeAlignmentAfterScalarization(
- LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+ LI->getAlign(), VecTy->getElementType(), Idx, DL);
NewLoad->setAlignment(ScalarOpAlignment);
replaceValue(*EI, *NewLoad);
@@ -1727,9 +1727,6 @@ bool VectorCombine::run() {
case Instruction::ShuffleVector:
MadeChange |= widenSubvectorLoad(I);
break;
- case Instruction::Load:
- MadeChange |= scalarizeLoadExtract(I);
- break;
default:
break;
}
@@ -1743,6 +1740,8 @@ bool VectorCombine::run() {
if (Opcode == Instruction::Store)
MadeChange |= foldSingleElementStore(I);
+ if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+ MadeChange |= scalarizeLoadExtract(I);
// If this is an early pipeline invocation of this pass, we are done.
if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 946a9196064fde4..c9f9ef0661ace26 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
define i32 @vscale_load_extract_idx_0(ptr %x) {
; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 16
; CHECK-NEXT: ret i32 [[R]]
;
%lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
define i32 @vscale_load_extract_idx_2(ptr %x) {
; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 8
; CHECK-NEXT: ret i32 [[R]]
;
%lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
; CHECK-NEXT: call void @maythrow()
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: ret i32 [[R]]
;
entry: