[llvm] [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (PR #104606)
Igor Kirillov via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 11 06:56:12 PDT 2024
================
@@ -2493,6 +2494,96 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
return true;
}
+/// Check if instruction depends on ZExt and this ZExt can be moved after the
+/// instruction. Move ZExt if it is profitable. For example:
+/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
+/// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
+/// The cost model takes into account whether zext(x) has other users and
+/// whether it can be propagated through them too.
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+ Value *ZExted, *OtherOperand;
+ if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
+ m_Value(OtherOperand))) &&
+ !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
+ return false;
+
+ Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
+
+ auto *BigTy = cast<FixedVectorType>(I.getType());
+ auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+ unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+ // Check that the expression overall uses at most the same number of bits as
+ // ZExted
+ KnownBits KB = computeKnownBits(&I, *DL);
+ if (KB.countMaxActiveBits() > BW)
+ return false;
+
+ // Calculate costs of leaving current IR as it is and moving ZExt operation
+ // later, along with adding truncates if needed
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost ZExtCost = TTI.getCastInstrCost(
+ Instruction::ZExt, BigTy, SmallTy,
+ TargetTransformInfo::CastContextHint::None, CostKind);
+ InstructionCost CurrentCost = ZExtCost;
+ InstructionCost ShrinkCost = 0;
+
+ // Calculate total cost and check that we can propagate through all ZExt users
+ for (User *U : ZExtOperand->users()) {
+ auto *UI = cast<Instruction>(U);
+ if (UI == &I) {
+ CurrentCost +=
+ TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+ ShrinkCost +=
+ TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+ ShrinkCost += ZExtCost;
+ continue;
+ }
+
+ if (!Instruction::isBinaryOp(UI->getOpcode()))
+ return false;
+
+ // Check if we can propagate ZExt through its other users
+ KB = computeKnownBits(UI, *DL);
+ if (KB.countMaxActiveBits() > BW)
+ return false;
+
+ CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+ ShrinkCost +=
+ TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+ ShrinkCost += ZExtCost;
+ }
+
+ // If the other instruction operand is not a constant, we'll need to
+ // generate a truncate instruction. So we have to adjust cost
+ if (!isa<Constant>(OtherOperand))
+ ShrinkCost += TTI.getCastInstrCost(
+ Instruction::Trunc, SmallTy, BigTy,
+ TargetTransformInfo::CastContextHint::None, CostKind);
+
+ // If the cost of shrinking types and leaving the IR is the same, we'll lean
+ // towards modifying the IR because shrinking opens opportunities for other
+ // shrinking optimisations.
+ if (ShrinkCost > CurrentCost)
+ return false;
+
+ Value *Op0 = ZExted;
+ if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+ Builder.SetInsertPoint(OI->getNextNode());
----------------
igogo-x86 wrote:
Created a fix - https://github.com/llvm/llvm-project/pull/108228
https://github.com/llvm/llvm-project/pull/104606
More information about the llvm-commits
mailing list