[llvm] [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (PR #104606)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 08:29:38 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Igor Kirillov (igogo-x86)

<details>
<summary>Changes</summary>

Check that `binop(zext(value), other)` is possible and profitable to transform into: `zext(binop(value, trunc(other)))`.
When a CPU architecture has an illegal scalar type iX but the vector type <N * iX> is legal, scalar expressions may be extended to a legal type iY before vectorisation. This extension can result in underutilization of vector lanes, since more lanes could be processed by a single instruction with the narrower type. Vectorisers may not always recognize opportunities for type shrinking, and this patch aims to address that limitation.

---
Full diff: https://github.com/llvm/llvm-project/pull/104606.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+104) 
- (added) llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll (+76) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 99bd383ab0dead..c2f4315928a40c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
   bool foldShuffleFromReductions(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+  bool shrinkType(Instruction &I);
 
   void replaceValue(Value &Old, Value &New) {
     Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,106 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   return true;
 }
 
+/// Check if instruction depends on ZExt and this ZExt can be moved after the
+/// instruction. Move ZExt if it is profitable
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+  Value *ZExted, *OtherOperand;
+  if (match(&I, m_c_BinOp(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand)))) {
+    if (I.getOpcode() != Instruction::And && I.getOpcode() != Instruction::Or &&
+        I.getOpcode() != Instruction::Xor && I.getOpcode() != Instruction::LShr)
+      return false;
+
+    // In case of LShr extraction, ZExtOperand should be applied to the first
+    // operand
+    if (I.getOpcode() == Instruction::LShr && I.getOperand(1) != OtherOperand)
+      return false;
+
+    Instruction *ZExtOperand = cast<Instruction>(
+        I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+    auto *BigTy = cast<FixedVectorType>(I.getType());
+    auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+    auto BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+    // Check that the expression overall uses at most the same number of bits as
+    // ZExted
+    auto KB = computeKnownBits(&I, *DL);
+    auto IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+    if (IBW > BW)
+      return false;
+
+    bool HasUNZExtableUser = false;
+
+    // Calculate costs of leaving current IR as it is and moving ZExt operation
+    // later, along with adding truncates if needed
+    InstructionCost ZExtCost = TTI.getCastInstrCost(
+        Instruction::ZExt, BigTy, SmallTy,
+        TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+    InstructionCost CurrentCost = ZExtCost;
+    InstructionCost ShrinkCost = 0;
+
+    for (User *U : ZExtOperand->users()) {
+      auto *UI = cast<Instruction>(U);
+      if (UI == &I) {
+        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+        ShrinkCost += ZExtCost;
+        continue;
+      }
+
+      if (!Instruction::isBinaryOp(UI->getOpcode())) {
+        HasUNZExtableUser = true;
+        continue;
+      }
+
+      // Check if we can propagate ZExt through its other users
+      auto KB = computeKnownBits(UI, *DL);
+      auto UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+      if (UBW <= BW) {
+        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+        ShrinkCost += ZExtCost;
+      } else {
+        HasUNZExtableUser = true;
+      }
+    }
+
+    // ZExt can't remove, add extra cost
+    if (HasUNZExtableUser)
+      ShrinkCost += ZExtCost;
+
+    // If the other instruction operand is not a constant, we'll need to
+    // generate a truncate instruction. So we have to adjust cost
+    if (!isa<Constant>(OtherOperand))
+      ShrinkCost += TTI.getCastInstrCost(
+          Instruction::Trunc, SmallTy, BigTy,
+          TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+
+    // If the cost of shrinking types and leaving the IR is the same, we'll lean
+    // towards modifying the IR because shrinking opens opportunities for other
+    // shrinking optimisations.
+    if (ShrinkCost > CurrentCost)
+      return false;
+
+    auto *Op0 = ZExted;
+    if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+      Builder.SetInsertPoint(OI->getNextNode());
+    auto *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+    Builder.SetInsertPoint(&I);
+    // Keep the order of operands the same
+    if (I.getOperand(0) == OtherOperand)
+      std::swap(Op0, Op1);
+    auto *NewBinOp =
+        Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+    cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+    cast<Instruction>(NewBinOp)->copyMetadata(I);
+    auto *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+    replaceValue(I, *NewZExtr);
+    return true;
+  }
+  return false;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -2560,6 +2661,9 @@ bool VectorCombine::run() {
       case Instruction::BitCast:
         MadeChange |= foldBitcastShuffle(I);
         break;
+      default:
+        MadeChange |= shrinkType(I);
+        break;
       }
     } else {
       switch (Opcode) {
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
new file mode 100644
index 00000000000000..0166656cf734f5
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = and <16 x i32> %0, %a
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = or <16 x i32> %0, %a.masked
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT:    ret i32 [[TMP9]]
+;
+entry:
+  %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+  %2 = or <16 x i32> %1, %v.masked
+  %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+  %4 = or <16 x i32> %3, %u.masked
+  %5 = add nuw nsw <16 x i32> %2, %4
+  %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/104606


More information about the llvm-commits mailing list