[llvm] [InstCombine] linearize complexity of findDemandedEltsByAllUsers() (PR #161436)
Princeton Ferro via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 30 13:33:15 PDT 2025
https://github.com/Prince781 created https://github.com/llvm/llvm-project/pull/161436
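Don't accumulate the demanded elements with `APInt::operator|=()`. Each call to `findDemandedEltsBySingleUser()` builds a fresh vector-width `APInt` that is then OR-ed into the union, so every user costs time proportional to the vector width. For large vectors with many uses, the total work therefore grows as `{# uses} x {size of vector}`. Instead, have `findDemandedEltsBySingleUser()` set the demanded bits directly in the caller's accumulator with `setBit()`. An `extractelement` user with a constant index then costs constant time, and a `shufflevector` user costs time proportional to its mask length.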
From 426bcb1717372e180e299ff57920e0ae05842e91 Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Wed, 24 Sep 2025 20:16:53 -0700
Subject: [PATCH] linearize complexity of findDemandedEltsByAllUsers()
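Don't accumulate the demanded elements with APInt::operator|=(): each
user builds a fresh vector-width APInt that is then OR-ed into the
union, so for large vectors with many uses the total work grows as
{# uses} x {size of vector}. Instead, have
findDemandedEltsBySingleUser() set bits directly in the caller's
accumulator with setBit(), so that an extractelement user with a
constant index takes constant time.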
---
.../InstCombine/InstCombineVectorOps.cpp | 31 +++++++++++--------
1 file changed, 18 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 6ef30663bf3ce..a099aaf9e6223 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -320,11 +320,12 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
}
/// Find elements of V demanded by UserInstr.
-static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
- unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
+static void findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr,
+ APInt &UnionUsedElts) {
+ const unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
- // Conservatively assume that all elements are needed.
- APInt UsedElts(APInt::getAllOnes(VWidth));
+ // Whether we can determine the elements accessed at compile time.
+ bool KnownIndices = false;
switch (UserInstr->getOpcode()) {
case Instruction::ExtractElement: {
@@ -332,32 +333,36 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
assert(EEI->getVectorOperand() == V);
ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand());
if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) {
- UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue());
+ UnionUsedElts.setBit(EEIIndexC->getZExtValue());
+ KnownIndices = true;
}
break;
}
case Instruction::ShuffleVector: {
ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr);
- unsigned MaskNumElts =
+ const unsigned MaskNumElts =
cast<FixedVectorType>(UserInstr->getType())->getNumElements();
- UsedElts = APInt(VWidth, 0);
- for (unsigned i = 0; i < MaskNumElts; i++) {
- unsigned MaskVal = Shuffle->getMaskValue(i);
+ KnownIndices = true;
+ for (auto I : llvm::seq(MaskNumElts)) {
+ unsigned MaskVal = Shuffle->getMaskValue(I);
if (MaskVal == -1u || MaskVal >= 2 * VWidth)
continue;
if (Shuffle->getOperand(0) == V && (MaskVal < VWidth))
- UsedElts.setBit(MaskVal);
+ UnionUsedElts.setBit(MaskVal);
if (Shuffle->getOperand(1) == V &&
((MaskVal >= VWidth) && (MaskVal < 2 * VWidth)))
- UsedElts.setBit(MaskVal - VWidth);
+ UnionUsedElts.setBit(MaskVal - VWidth);
}
break;
}
default:
break;
}
- return UsedElts;
+
+ // Conservatively assume all elements are accessed if indices are unknown
+ if (!KnownIndices)
+ UnionUsedElts.setAllBits();
}
/// Find union of elements of V demanded by all its users.
@@ -370,7 +375,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
APInt UnionUsedElts(VWidth, 0);
for (const Use &U : V->uses()) {
if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
- UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
+ findDemandedEltsBySingleUser(V, I, UnionUsedElts);
} else {
UnionUsedElts = APInt::getAllOnes(VWidth);
break;
More information about the llvm-commits
mailing list