[llvm] [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (PR #104606)

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 09:01:35 PDT 2024


https://github.com/igogo-x86 updated https://github.com/llvm/llvm-project/pull/104606

>From 6b57b3941e3b732eae504b7b957cf264af54c80f Mon Sep 17 00:00:00 2001
From: Igor Kirillov <igor.kirillov at arm.com>
Date: Fri, 16 Aug 2024 15:06:00 +0000
Subject: [PATCH 1/2] [VectorCombine] Add type shrinking and zext propagation
 for fixed-width vector types

Check whether binop(zext(value), other) can legally and profitably be
transformed into zext(binop(value, trunc(other))).
When the target architecture has an illegal scalar type iX but a legal vector
type <N x iX>, scalar expressions may be extended to a legal type iY before
vectorisation. This extension can leave vector lanes underutilised, since a
single instruction could process more lanes with the narrower type.
Vectorisers do not always recognise such type-shrinking opportunities; this
patch aims to address that limitation.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 104 ++++++++++++++++++
 .../VectorCombine/AArch64/shrink-types.ll     |  76 +++++++++++++
 2 files changed, 180 insertions(+)
 create mode 100644 llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 99bd383ab0dead..c2f4315928a40c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
   bool foldShuffleFromReductions(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+  bool shrinkType(Instruction &I);
 
   void replaceValue(Value &Old, Value &New) {
     Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,106 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   return true;
 }
 
+/// Check if instruction depends on ZExt and this ZExt can be moved after the
+/// instruction. Move ZExt if it is profitable
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+  Value *ZExted, *OtherOperand;
+  if (match(&I, m_c_BinOp(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand)))) {
+    if (I.getOpcode() != Instruction::And && I.getOpcode() != Instruction::Or &&
+        I.getOpcode() != Instruction::Xor && I.getOpcode() != Instruction::LShr)
+      return false;
+
+    // In case of LShr extraction, ZExtOperand should be applied to the first
+    // operand
+    if (I.getOpcode() == Instruction::LShr && I.getOperand(1) != OtherOperand)
+      return false;
+
+    Instruction *ZExtOperand = cast<Instruction>(
+        I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+    auto *BigTy = cast<FixedVectorType>(I.getType());
+    auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+    auto BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+    // Check that the expression overall uses at most the same number of bits as
+    // ZExted
+    auto KB = computeKnownBits(&I, *DL);
+    auto IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+    if (IBW > BW)
+      return false;
+
+    bool HasUNZExtableUser = false;
+
+    // Calculate costs of leaving current IR as it is and moving ZExt operation
+    // later, along with adding truncates if needed
+    InstructionCost ZExtCost = TTI.getCastInstrCost(
+        Instruction::ZExt, BigTy, SmallTy,
+        TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+    InstructionCost CurrentCost = ZExtCost;
+    InstructionCost ShrinkCost = 0;
+
+    for (User *U : ZExtOperand->users()) {
+      auto *UI = cast<Instruction>(U);
+      if (UI == &I) {
+        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+        ShrinkCost += ZExtCost;
+        continue;
+      }
+
+      if (!Instruction::isBinaryOp(UI->getOpcode())) {
+        HasUNZExtableUser = true;
+        continue;
+      }
+
+      // Check if we can propagate ZExt through its other users
+      auto KB = computeKnownBits(UI, *DL);
+      auto UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+      if (UBW <= BW) {
+        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+        ShrinkCost += ZExtCost;
+      } else {
+        HasUNZExtableUser = true;
+      }
+    }
+
+    // ZExt can't remove, add extra cost
+    if (HasUNZExtableUser)
+      ShrinkCost += ZExtCost;
+
+    // If the other instruction operand is not a constant, we'll need to
+    // generate a truncate instruction. So we have to adjust cost
+    if (!isa<Constant>(OtherOperand))
+      ShrinkCost += TTI.getCastInstrCost(
+          Instruction::Trunc, SmallTy, BigTy,
+          TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+
+    // If the cost of shrinking types and leaving the IR is the same, we'll lean
+    // towards modifying the IR because shrinking opens opportunities for other
+    // shrinking optimisations.
+    if (ShrinkCost > CurrentCost)
+      return false;
+
+    auto *Op0 = ZExted;
+    if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+      Builder.SetInsertPoint(OI->getNextNode());
+    auto *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+    Builder.SetInsertPoint(&I);
+    // Keep the order of operands the same
+    if (I.getOperand(0) == OtherOperand)
+      std::swap(Op0, Op1);
+    auto *NewBinOp =
+        Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+    cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+    cast<Instruction>(NewBinOp)->copyMetadata(I);
+    auto *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+    replaceValue(I, *NewZExtr);
+    return true;
+  }
+  return false;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -2560,6 +2661,9 @@ bool VectorCombine::run() {
       case Instruction::BitCast:
         MadeChange |= foldBitcastShuffle(I);
         break;
+      default:
+        MadeChange |= shrinkType(I);
+        break;
       }
     } else {
       switch (Opcode) {
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
new file mode 100644
index 00000000000000..0166656cf734f5
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = and <16 x i32> %0, %a
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = or <16 x i32> %0, %a.masked
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT:    ret i32 [[TMP9]]
+;
+entry:
+  %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+  %2 = or <16 x i32> %1, %v.masked
+  %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+  %4 = or <16 x i32> %3, %u.masked
+  %5 = add nuw nsw <16 x i32> %2, %4
+  %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

>From 049d02affc06f5fd6174ac357b63c13f92c0b38a Mon Sep 17 00:00:00 2001
From: Igor Kirillov <igor.kirillov at arm.com>
Date: Mon, 19 Aug 2024 15:51:54 +0000
Subject: [PATCH 2/2] Refactor code and fix failing
 PhaseOrdering/X86/pr50555.ll

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 155 ++++++++----------
 1 file changed, 70 insertions(+), 85 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index c2f4315928a40c..9aab88d8a173d5 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2498,100 +2498,85 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
 /// instruction. Move ZExt if it is profitable
 bool VectorCombine::shrinkType(llvm::Instruction &I) {
   Value *ZExted, *OtherOperand;
-  if (match(&I, m_c_BinOp(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand)))) {
-    if (I.getOpcode() != Instruction::And && I.getOpcode() != Instruction::Or &&
-        I.getOpcode() != Instruction::Xor && I.getOpcode() != Instruction::LShr)
-      return false;
+  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
+                                  m_Value(OtherOperand))) &&
+      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
+    return false;
+
+  Instruction *ZExtOperand =
+      cast<Instruction>(I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+  auto *BigTy = cast<FixedVectorType>(I.getType());
+  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+  // The result must fit in BW bits; also, lshr by >= BW is poison in SmallTy.
+  KnownBits KB = computeKnownBits(&I, *DL);
+  if (KB.countMaxActiveBits() > BW ||
+      (I.getOpcode() == Instruction::LShr &&
+       computeKnownBits(I.getOperand(1), *DL).getMaxValue().uge(BW)))
+    return false;
+
+  // Calculate costs of leaving current IR as it is and moving ZExt operation
+  // later, along with adding truncates if needed
+  InstructionCost ZExtCost = TTI.getCastInstrCost(
+      Instruction::ZExt, BigTy, SmallTy,
+      TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+  InstructionCost CurrentCost = ZExtCost;
+  InstructionCost ShrinkCost = 0;
+
+  // Calculate total cost and check that we can propagate through all ZExt users
+  for (User *U : ZExtOperand->users()) {
+    auto *UI = cast<Instruction>(U);
+    if (UI == &I) {
+      CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+      ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+      ShrinkCost += ZExtCost;
+      continue;
+    }
 
-    // In case of LShr extraction, ZExtOperand should be applied to the first
-    // operand
-    if (I.getOpcode() == Instruction::LShr && I.getOperand(1) != OtherOperand)
+    if (!Instruction::isBinaryOp(UI->getOpcode()))
       return false;
 
-    Instruction *ZExtOperand = cast<Instruction>(
-        I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
-
-    auto *BigTy = cast<FixedVectorType>(I.getType());
-    auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
-    auto BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
-
-    // Check that the expression overall uses at most the same number of bits as
-    // ZExted
-    auto KB = computeKnownBits(&I, *DL);
-    auto IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
-    if (IBW > BW)
+    // Check if we can propagate ZExt through its other users
+    KB = computeKnownBits(UI, *DL);
+    unsigned UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+    if (UBW > BW)
       return false;
 
-    bool HasUNZExtableUser = false;
+    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+    ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+    ShrinkCost += ZExtCost;
+  }
 
-    // Calculate costs of leaving current IR as it is and moving ZExt operation
-    // later, along with adding truncates if needed
-    InstructionCost ZExtCost = TTI.getCastInstrCost(
-        Instruction::ZExt, BigTy, SmallTy,
+  // If the other instruction operand is not a constant, we'll need to
+  // generate a truncate instruction. So we have to adjust cost
+  if (!isa<Constant>(OtherOperand))
+    ShrinkCost += TTI.getCastInstrCost(
+        Instruction::Trunc, SmallTy, BigTy,
         TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
-    InstructionCost CurrentCost = ZExtCost;
-    InstructionCost ShrinkCost = 0;
-
-    for (User *U : ZExtOperand->users()) {
-      auto *UI = cast<Instruction>(U);
-      if (UI == &I) {
-        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
-        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
-        ShrinkCost += ZExtCost;
-        continue;
-      }
-
-      if (!Instruction::isBinaryOp(UI->getOpcode())) {
-        HasUNZExtableUser = true;
-        continue;
-      }
-
-      // Check if we can propagate ZExt through its other users
-      auto KB = computeKnownBits(UI, *DL);
-      auto UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
-      if (UBW <= BW) {
-        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
-        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
-        ShrinkCost += ZExtCost;
-      } else {
-        HasUNZExtableUser = true;
-      }
-    }
 
-    // ZExt can't remove, add extra cost
-    if (HasUNZExtableUser)
-      ShrinkCost += ZExtCost;
-
-    // If the other instruction operand is not a constant, we'll need to
-    // generate a truncate instruction. So we have to adjust cost
-    if (!isa<Constant>(OtherOperand))
-      ShrinkCost += TTI.getCastInstrCost(
-          Instruction::Trunc, SmallTy, BigTy,
-          TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
-
-    // If the cost of shrinking types and leaving the IR is the same, we'll lean
-    // towards modifying the IR because shrinking opens opportunities for other
-    // shrinking optimisations.
-    if (ShrinkCost > CurrentCost)
-      return false;
+  // If the cost of shrinking types and leaving the IR is the same, we'll lean
+  // towards modifying the IR because shrinking opens opportunities for other
+  // shrinking optimisations.
+  if (ShrinkCost > CurrentCost)
+    return false;
 
-    auto *Op0 = ZExted;
-    if (auto *OI = dyn_cast<Instruction>(OtherOperand))
-      Builder.SetInsertPoint(OI->getNextNode());
-    auto *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
-    Builder.SetInsertPoint(&I);
-    // Keep the order of operands the same
-    if (I.getOperand(0) == OtherOperand)
-      std::swap(Op0, Op1);
-    auto *NewBinOp =
-        Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
-    cast<Instruction>(NewBinOp)->copyIRFlags(&I);
-    cast<Instruction>(NewBinOp)->copyMetadata(I);
-    auto *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
-    replaceValue(I, *NewZExtr);
-    return true;
-  }
-  return false;
+  auto *OI = dyn_cast<Instruction>(OtherOperand); // PHI/invoke: insert at I.
+  Builder.SetInsertPoint(
+      OI && !isa<PHINode>(OI) && !OI->isTerminator() ? OI->getNextNode() : &I);
+  Value *Op0 = ZExted, *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+  Builder.SetInsertPoint(&I);
+  // Keep the order of operands the same
+  if (I.getOperand(0) == OtherOperand)
+    std::swap(Op0, Op1);
+  Value *NewBinOp =
+      Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+  cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+  cast<Instruction>(NewBinOp)->copyMetadata(I);
+  Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+  replaceValue(I, *NewZExtr);
+  return true;
 }
 
 /// This is the entry point for all transforms. Pass manager differences are



More information about the llvm-commits mailing list