[llvm] [VectorCombine] Try to scalarize vector loads feeding bitcast instructions. (PR #164682)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 22 11:54:21 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-vectorizers
Author: Julian Nagele (juliannagele)
This change aims to convert vector loads into scalar loads when the loaded value is only converted to a scalar afterwards anyway, i.e., when every user of the load is a bitcast to a scalar type of the same bit width.
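For illustration, a minimal before/after sketch of the transformation, taken from the `load_v4i8_bitcast_to_i32` case in the new test file (the rewrite only fires when the TTI cost model favors the scalar load):

```llvm
; Before: the <4 x i8> load feeds only a full-width bitcast.
define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
  %lv = load <4 x i8>, ptr %x
  %r = bitcast <4 x i8> %lv to i32
  ret i32 %r
}

; After: load and bitcast collapse into one scalar load, which
; keeps the original load's alignment and metadata.
define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
  %lv.scalar = load i32, ptr %x, align 4
  ret i32 %lv.scalar
}
```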
---
Full diff: https://github.com/llvm/llvm-project/pull/164682.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+116-28)
- (added) llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll (+136)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d6eb00da11dc8..e045282c387fe 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -129,7 +129,9 @@ class VectorCombine {
bool foldExtractedCmps(Instruction &I);
bool foldBinopOfReductions(Instruction &I);
bool foldSingleElementStore(Instruction &I);
- bool scalarizeLoadExtract(Instruction &I);
+ bool scalarizeLoad(Instruction &I);
+ bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+ bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeExtExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
return false;
}
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
- if (!TTI.allowVectorElementIndexingUsingGEP())
- return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
Value *Ptr;
if (!match(&I, m_Load(m_Value(Ptr))))
return false;
auto *LI = cast<LoadInst>(&I);
-  auto *VecTy = cast<VectorType>(LI->getType());
-  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
+  auto *VecTy = dyn_cast<VectorType>(LI->getType());
+  if (!VecTy || LI->isVolatile() ||
+      !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
return false;
- InstructionCost OriginalCost =
- TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
- LI->getPointerAddressSpace(), CostKind);
- InstructionCost ScalarizedCost = 0;
-
+  // Check what type of users we have and ensure no memory modifications
+  // between the load and its users.
+ bool AllExtracts = true;
+ bool AllBitcasts = true;
Instruction *LastCheckedInst = LI;
unsigned NumInstChecked = 0;
- DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
- auto FailureGuard = make_scope_exit([&]() {
- // If the transform is aborted, discard the ScalarizationResults.
- for (auto &Pair : NeedFreeze)
- Pair.second.discard();
- });
- // Check if all users of the load are extracts with no memory modifications
- // between the load and the extract. Compute the cost of both the original
- // code and the scalarized version.
for (User *U : LI->users()) {
- auto *UI = dyn_cast<ExtractElementInst>(U);
- if (!UI || UI->getParent() != LI->getParent())
+ auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || UI->getParent() != LI->getParent())
return false;
- // If any extract is waiting to be erased, then bail out as this will
+ // If any user is waiting to be erased, then bail out as this will
// distort the cost calculation and possibly lead to infinite loops.
if (UI->use_empty())
return false;
- // Check if any instruction between the load and the extract may modify
- // memory.
+ if (!isa<ExtractElementInst>(UI))
+ AllExtracts = false;
+ if (!isa<BitCastInst>(UI))
+ AllBitcasts = false;
+
+ // Check if any instruction between the load and the user may modify memory.
if (LastCheckedInst->comesBefore(UI)) {
for (Instruction &I :
make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
}
LastCheckedInst = UI;
}
+ }
+
+ if (AllExtracts)
+ return scalarizeLoadExtract(LI, VecTy, Ptr);
+ if (AllBitcasts)
+ return scalarizeLoadBitcast(LI, VecTy, Ptr);
+ return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+ Value *Ptr) {
+ if (!TTI.allowVectorElementIndexingUsingGEP())
+ return false;
+
+ InstructionCost OriginalCost =
+ TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+ LI->getPointerAddressSpace(), CostKind);
+ InstructionCost ScalarizedCost = 0;
+
+ DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+ auto FailureGuard = make_scope_exit([&]() {
+ // If the transform is aborted, discard the ScalarizationResults.
+ for (auto &Pair : NeedFreeze)
+ Pair.second.discard();
+ });
+
+ for (User *U : LI->users()) {
+ auto *UI = cast<ExtractElementInst>(U);
auto ScalarIdx =
canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
nullptr, nullptr, CostKind);
}
- LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+ LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
<< "\n LoadExtractCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");
@@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+ Value *Ptr) {
+ InstructionCost OriginalCost =
+ TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+ LI->getPointerAddressSpace(), CostKind);
+
+ Type *TargetScalarType = nullptr;
+ unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+ for (User *U : LI->users()) {
+ auto *BC = cast<BitCastInst>(U);
+
+ Type *DestTy = BC->getDestTy();
+ if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+ return false;
+
+ unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+ if (DestBitWidth != VecBitWidth)
+ return false;
+
+ // All bitcasts should target the same scalar type.
+ if (!TargetScalarType)
+ TargetScalarType = DestTy;
+ else if (TargetScalarType != DestTy)
+ return false;
+
+ OriginalCost +=
+ TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+ TTI.getCastContextHint(BC), CostKind, BC);
+ }
+
+ if (!TargetScalarType || LI->user_empty())
+ return false;
+ InstructionCost ScalarizedCost =
+ TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+ LI->getPointerAddressSpace(), CostKind);
+
+ LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+ << "\n OriginalCost: " << OriginalCost
+ << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+ if (ScalarizedCost >= OriginalCost)
+ return false;
+
+ // Ensure we add the load back to the worklist BEFORE its users so they can
+  // be erased in the correct order.
+ Worklist.push(LI);
+
+ Builder.SetInsertPoint(LI);
+ auto *ScalarLoad =
+ Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+ ScalarLoad->setAlignment(LI->getAlign());
+ ScalarLoad->copyMetadata(*LI);
+
+ // Replace all bitcast users with the scalar load.
+ for (User *U : LI->users()) {
+ auto *BC = cast<BitCastInst>(U);
+ replaceValue(*BC, *ScalarLoad, false);
+ }
+
+ return true;
+}
+
bool VectorCombine::scalarizeExtExtract(Instruction &I) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;
@@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
if (IsVectorType) {
if (scalarizeOpOrCmp(I))
return true;
- if (scalarizeLoadExtract(I))
+ if (scalarizeLoad(I))
return true;
if (scalarizeExtExtract(I))
return true;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
new file mode 100644
index 0000000000000..464e5129262bc
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT: ret i32 [[R_SCALAR]]
+;
+ %lv = load <4 x i8>, ptr %x
+ %r = bitcast <4 x i8> %lv to i32
+ ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT: ret i64 [[R_SCALAR]]
+;
+ %lv = load <2 x i32>, ptr %x
+ %r = bitcast <2 x i32> %lv to i64
+ ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT: ret float [[R_SCALAR]]
+;
+ %lv = load <4 x i8>, ptr %x
+ %r = bitcast <4 x i8> %lv to float
+ ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT: ret float [[R_SCALAR]]
+;
+ %lv = load <2 x i16>, ptr %x
+ %r = bitcast <2 x i16> %lv to float
+ ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT: [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT: ret double [[R_SCALAR]]
+;
+ %lv = load <4 x i16>, ptr %x
+ %r = bitcast <4 x i16> %lv to double
+ ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT: [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT: ret double [[R_SCALAR]]
+;
+ %lv = load <2 x i32>, ptr %x
+ %r = bitcast <2 x i32> %lv to double
+ ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT: ret i32 [[ADD]]
+;
+ %lv = load <4 x i8>, ptr %x
+ %r1 = bitcast <4 x i8> %lv to i32
+ %r2 = bitcast <4 x i8> %lv to i32
+ %add = add i32 %r1, %r2
+ ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT: [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT: [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT: [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT: ret i32 [[ADD]]
+;
+ %lv = load <4 x i8>, ptr %x
+ %r1 = bitcast <4 x i8> %lv to i32
+ %r2 = bitcast <4 x i8> %lv to float
+ %r2.int = bitcast float %r2 to i32
+ %add = add i32 %r1, %r2.int
+ ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT: [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT: ret <2 x i16> [[R]]
+;
+ %lv = load <4 x i8>, ptr %x
+ %r = bitcast <4 x i8> %lv to <2 x i16>
+ ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT: [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT: [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT: [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT: ret i32 [[ADD]]
+;
+ %lv = load <4 x i8>, ptr %x
+ %r1 = bitcast <4 x i8> %lv to i32
+ %r2 = extractelement <4 x i8> %lv, i32 0
+ %r2.ext = zext i8 %r2 to i32
+ %add = add i32 %r1, %r2.ext
+ ret i32 %add
+}
``````````
https://github.com/llvm/llvm-project/pull/164682