[llvm] [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (PR #104606)
Mikael Holmén via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 11 03:41:05 PDT 2024
================
@@ -2493,6 +2494,96 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
return true;
}
+/// Check if instruction depends on ZExt and this ZExt can be moved after the
+/// instruction. Move ZExt if it is profitable. For example:
+/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
+/// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
+/// Cost model calculations take into account if zext(x) has other users and
+/// whether it can be propagated through them too.
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+ Value *ZExted, *OtherOperand;
+ if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
+ m_Value(OtherOperand))) &&
+ !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
+ return false;
+
+ Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
+
+ auto *BigTy = cast<FixedVectorType>(I.getType());
+ auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+ unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+ // Check that the expression overall uses at most the same number of bits as
+ // ZExted
+ KnownBits KB = computeKnownBits(&I, *DL);
+ if (KB.countMaxActiveBits() > BW)
+ return false;
+
+ // Calculate costs of leaving current IR as it is and moving ZExt operation
+ // later, along with adding truncates if needed
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost ZExtCost = TTI.getCastInstrCost(
+ Instruction::ZExt, BigTy, SmallTy,
+ TargetTransformInfo::CastContextHint::None, CostKind);
+ InstructionCost CurrentCost = ZExtCost;
+ InstructionCost ShrinkCost = 0;
+
+ // Calculate total cost and check that we can propagate through all ZExt users
+ for (User *U : ZExtOperand->users()) {
+ auto *UI = cast<Instruction>(U);
+ if (UI == &I) {
+ CurrentCost +=
+ TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+ ShrinkCost +=
+ TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+ ShrinkCost += ZExtCost;
+ continue;
+ }
+
+ if (!Instruction::isBinaryOp(UI->getOpcode()))
+ return false;
+
+ // Check if we can propagate ZExt through its other users
+ KB = computeKnownBits(UI, *DL);
+ if (KB.countMaxActiveBits() > BW)
+ return false;
+
+ CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+ ShrinkCost +=
+ TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+ ShrinkCost += ZExtCost;
+ }
+
+ // If the other instruction operand is not a constant, we'll need to
+ // generate a truncate instruction. So we have to adjust cost
+ if (!isa<Constant>(OtherOperand))
+ ShrinkCost += TTI.getCastInstrCost(
+ Instruction::Trunc, SmallTy, BigTy,
+ TargetTransformInfo::CastContextHint::None, CostKind);
+
+ // If the cost of shrinking types and leaving the IR is the same, we'll lean
+ // towards modifying the IR because shrinking opens opportunities for other
+ // shrinking optimisations.
+ if (ShrinkCost > CurrentCost)
+ return false;
+
+ Value *Op0 = ZExted;
+ if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+ Builder.SetInsertPoint(OI->getNextNode());
----------------
mikaelholmen wrote:
I think this is currently broken.
What if OI->getNextNode() is a PHI?
For my out-of-tree target, the following fails:
```
opt -passes="vector-combine" bbi-99058.ll -o /dev/null
```
with
```
PHI nodes not grouped at top of basic block!
%vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
label %vector.body
LLVM ERROR: Broken module found, compilation aborted!
```
and I think it's because the trunc created on the next line is inserted immediately after the first PHI (OI), i.e. before the second PHI in the basic block.
If you simply comment out the
```
if (ShrinkCost > CurrentCost)
return false;
```
code at line 2567 above it happens in tree as well. (I'm sure the testcase can be modified in some way so it happens even with the cost comparison at 2567 for some target but I didn't manage right now.)
The file bbi-99058.ll in my example is:
```
; Reproducer for shrinkType inserting a trunc between PHI nodes.
; %1 = and(%vec.phi, zext(...)) matches the bitwise-logic pattern, and the
; non-zext operand %vec.phi is an instruction (not a constant), so a trunc of
; it is created at %vec.phi->getNextNode(). That next node is the PHI
; %vec.ind, so the trunc lands between the two PHIs and the verifier rejects
; the module ("PHI nodes not grouped at top of basic block!").
; NOTE(review): per the report, in-tree this only fires if the cost check
; (ShrinkCost > CurrentCost) is disabled; some targets may hit it directly.
define i64 @func_1() {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
; First PHI: this is OtherOperand, the value shrinkType truncates.
%vec.phi = phi <4 x i32> [ zeroinitializer, %entry ], [ %1, %vector.body ]
; Second PHI: getNextNode() of %vec.phi, so the new trunc is inserted before it.
%vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
%0 = zext <4 x i16> zeroinitializer to <4 x i32>
%1 = and <4 x i32> %vec.phi, %0
br label %vector.body
}
```
https://github.com/llvm/llvm-project/pull/104606
More information about the llvm-commits
mailing list