[llvm] [VectorCombine] Try to scalarize vector loads feeding bitcast instructions. (PR #164682)

Julian Nagele via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 22 11:53:36 PDT 2025


https://github.com/juliannagele created https://github.com/llvm/llvm-project/pull/164682

This change aims to convert vector loads into scalar loads when the loaded vector is only ever bitcast to a scalar afterwards anyway.

>From 6fcd3779ac2298f9f38e66900351c2d83aa22330 Mon Sep 17 00:00:00 2001
From: Julian Nagele <j_nagele at apple.com>
Date: Wed, 22 Oct 2025 19:48:58 +0100
Subject: [PATCH] [VectorCombine] Try to scalarize vector loads feeding bitcast
 instructions.

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 144 ++++++++++++++----
 .../AArch64/load-bitcast-scalarization.ll     | 136 +++++++++++++++++
 2 files changed, 252 insertions(+), 28 deletions(-)
 create mode 100644 llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d6eb00da11dc8..e045282c387fe 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -129,7 +129,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
-  if (!TTI.allowVectorElementIndexingUsingGEP())
-    return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
   auto *LI = cast<LoadInst>(&I);
   auto *VecTy = cast<VectorType>(LI->getType());
-  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
+  if (!VecTy || LI->isVolatile() ||
+      !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  // Check what type of users we have and ensure no memory modifications between
+  // the load and its users.
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
-    if (!UI || UI->getParent() != LI->getParent())
+    auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
       return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
       return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I :
            make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
                                                     nullptr, nullptr, CostKind);
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n  LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts should target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType || LI->user_empty())
+    return false;
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+                    << "\n  OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // be erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad, false);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   if (!TTI.allowVectorElementIndexingUsingGEP())
     return false;
@@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       if (scalarizeOpOrCmp(I))
         return true;
-      if (scalarizeLoadExtract(I))
+      if (scalarizeLoad(I))
         return true;
       if (scalarizeExtExtract(I))
         return true;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
new file mode 100644
index 0000000000000..464e5129262bc
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    ret i32 [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to i32
+  ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT:    ret i64 [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to i64
+  ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to float
+  ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <2 x i16>, ptr %x
+  %r = bitcast <2 x i16> %lv to float
+  ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <4 x i16>, ptr %x
+  %r = bitcast <4 x i16> %lv to double
+  ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to double
+  ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to i32
+  %add = add i32 %r1, %r2
+  ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT:    [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to float
+  %r2.int = bitcast float %r2 to i32
+  %add = add i32 %r1, %r2.int
+  ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[R]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to <2 x i16>
+  ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT:    [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = extractelement <4 x i8> %lv, i32 0
+  %r2.ext = zext i8 %r2 to i32
+  %add = add i32 %r1, %r2.ext
+  ret i32 %add
+}



More information about the llvm-commits mailing list