[llvm] bf69484 - [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (#104606)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 02:09:08 PDT 2024


Author: Igor Kirillov
Date: 2024-09-10T10:09:03+01:00
New Revision: bf694841f5b986f677e4fbe2a7ee93c77690d765

URL: https://github.com/llvm/llvm-project/commit/bf694841f5b986f677e4fbe2a7ee93c77690d765
DIFF: https://github.com/llvm/llvm-project/commit/bf694841f5b986f677e4fbe2a7ee93c77690d765.diff

LOG: [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (#104606)

Check whether `binop(zext(value), other)` can be profitably transformed
into `zext(binop(value, trunc(other)))`.
When a CPU architecture has an illegal scalar type iX but the vector type
<N x iX> is legal, scalar expressions may be extended to a legal type iY
before vectorisation. This extension can lead to underutilization of the
vector lanes, as more lanes could be processed per instruction with the
narrower type. Vectorisers may not always recognize such opportunities
for type shrinking, and this patch aims to address that limitation.
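
For illustration, a minimal IR sketch of the intended rewrite (the value
names are made up for this example, and a target is assumed where the
<16 x i8> operations are cheaper than their <16 x i32> counterparts):

  ; Before: the 'and' is performed on the widened <16 x i32> values.
  %ext = zext <16 x i8> %x to <16 x i32>
  %res = and <16 x i32> %ext, %y

  ; After: the 'and' is performed at i8 width, the zext is moved past it,
  ; and a trunc of the non-constant operand %y is introduced.
  %narrow = trunc <16 x i32> %y to <16 x i8>
  %op = and <16 x i8> %x, %narrow
  %res = zext <16 x i8> %op to <16 x i32>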

Added: 
    llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 99bd383ab0dead..54f6de34a76c93 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
   bool foldShuffleFromReductions(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+  bool shrinkType(Instruction &I);
 
   void replaceValue(Value &Old, Value &New) {
     Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,96 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   return true;
 }
 
+/// Check if the instruction depends on a ZExt that can be moved after the
+/// instruction, and move the ZExt if it is profitable. For example:
+///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
+///     lshr(zext(x),y)  -> zext(lshr(x,trunc(y)))
+/// The cost model takes into account whether zext(x) has other users and
+/// whether the ZExt can be propagated through them too.
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+  Value *ZExted, *OtherOperand;
+  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
+                                  m_Value(OtherOperand))) &&
+      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
+    return false;
+
+  Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
+
+  auto *BigTy = cast<FixedVectorType>(I.getType());
+  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+  // Check that the expression overall uses at most the same number of bits as
+  // ZExted does.
+  KnownBits KB = computeKnownBits(&I, *DL);
+  if (KB.countMaxActiveBits() > BW)
+    return false;
+
+  // Calculate the cost of leaving the current IR as it is versus moving the
+  // ZExt operation later, adding truncates if needed.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost ZExtCost = TTI.getCastInstrCost(
+      Instruction::ZExt, BigTy, SmallTy,
+      TargetTransformInfo::CastContextHint::None, CostKind);
+  InstructionCost CurrentCost = ZExtCost;
+  InstructionCost ShrinkCost = 0;
+
+  // Calculate total cost and check that we can propagate through all ZExt users
+  for (User *U : ZExtOperand->users()) {
+    auto *UI = cast<Instruction>(U);
+    if (UI == &I) {
+      CurrentCost +=
+          TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+      ShrinkCost +=
+          TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+      ShrinkCost += ZExtCost;
+      continue;
+    }
+
+    if (!Instruction::isBinaryOp(UI->getOpcode()))
+      return false;
+
+    // Check if we can propagate ZExt through its other users
+    KB = computeKnownBits(UI, *DL);
+    if (KB.countMaxActiveBits() > BW)
+      return false;
+
+    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+    ShrinkCost +=
+        TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+    ShrinkCost += ZExtCost;
+  }
+
+  // If the other instruction operand is not a constant, we'll need to
+  // generate a truncate instruction, so we have to adjust the cost.
+  if (!isa<Constant>(OtherOperand))
+    ShrinkCost += TTI.getCastInstrCost(
+        Instruction::Trunc, SmallTy, BigTy,
+        TargetTransformInfo::CastContextHint::None, CostKind);
+
+  // If shrinking the types costs the same as leaving the IR unchanged, we
+  // still lean towards modifying the IR, because it opens opportunities for
+  // further shrinking optimisations.
+  if (ShrinkCost > CurrentCost)
+    return false;
+
+  Value *Op0 = ZExted;
+  if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+    Builder.SetInsertPoint(OI->getNextNode());
+  Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+  Builder.SetInsertPoint(&I);
+  // Keep the order of operands the same
+  if (I.getOperand(0) == OtherOperand)
+    std::swap(Op0, Op1);
+  Value *NewBinOp =
+      Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+  cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+  cast<Instruction>(NewBinOp)->copyMetadata(I);
+  Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+  replaceValue(I, *NewZExtr);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -2560,6 +2651,9 @@ bool VectorCombine::run() {
       case Instruction::BitCast:
         MadeChange |= foldBitcastShuffle(I);
         break;
+      default:
+        MadeChange |= shrinkType(I);
+        break;
       }
     } else {
       switch (Opcode) {

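As a hypothetical aside (illustrative only, not part of the committed test
below): the fold is expected to bail out when the zext has a user that is
not a binary operator, since the cost-model walk over the zext's users
gives up on such instructions. For example:

  %wide.load = load <16 x i8>, ptr %b, align 1
  %ext = zext <16 x i8> %wide.load to <16 x i32>
  %masked = and <16 x i32> %ext, %a
  ; This non-binary-operator use of %ext stops shrinkType from rewriting
  ; the 'and' above.
  %splat = shufflevector <16 x i32> %ext, <16 x i32> poison, <16 x i32> zeroinitializer
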
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
new file mode 100644
index 00000000000000..0166656cf734f5
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = and <16 x i32> %0, %a
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = or <16 x i32> %0, %a.masked
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT:    ret i32 [[TMP9]]
+;
+entry:
+  %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+  %2 = or <16 x i32> %1, %v.masked
+  %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+  %4 = or <16 x i32> %3, %u.masked
+  %5 = add nuw nsw <16 x i32> %2, %4
+  %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

More information about the llvm-commits mailing list