[llvm] [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (PR #104606)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 11:16:33 PDT 2024


================
@@ -2493,6 +2494,91 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   return true;
 }
 
+/// Check if instruction depends on ZExt and this ZExt can be moved after the
+/// instruction. Move ZExt if it is profitable
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+  Value *ZExted, *OtherOperand;
+  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
+                                  m_Value(OtherOperand))) &&
+      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
+    return false;
+
+  Instruction *ZExtOperand =
+      cast<Instruction>(I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+  auto *BigTy = cast<FixedVectorType>(I.getType());
+  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+  // Check that the expression overall uses at most the same number of bits as
+  // ZExted
+  KnownBits KB = computeKnownBits(&I, *DL);
+  unsigned IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+  if (IBW > BW)
+    return false;
+
+  // Calculate costs of leaving current IR as it is and moving ZExt operation
+  // later, along with adding truncates if needed
+  InstructionCost ZExtCost = TTI.getCastInstrCost(
+      Instruction::ZExt, BigTy, SmallTy,
+      TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+  InstructionCost CurrentCost = ZExtCost;
+  InstructionCost ShrinkCost = 0;
+
+  // Calculate total cost and check that we can propagate through all ZExt users
+  for (User *U : ZExtOperand->users()) {
+    auto *UI = cast<Instruction>(U);
+    if (UI == &I) {
+      CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+      ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+      ShrinkCost += ZExtCost;
+      continue;
+    }
+
+    if (!Instruction::isBinaryOp(UI->getOpcode()))
+      return false;
+
+    // Check if we can propagate ZExt through its other users
+    KB = computeKnownBits(UI, *DL);
+    unsigned UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+    if (UBW > BW)
+      return false;
+
+    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+    ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
----------------
RKSimon wrote:

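(style) Always specify the cost kind explicitly: pass `TTI::TCK_RecipThroughput` to these `getArithmeticInstrCost` calls, as is already done for the `getCastInstrCost` call above.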

https://github.com/llvm/llvm-project/pull/104606


More information about the llvm-commits mailing list