[llvm] [VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (PR #104606)
Igor Kirillov via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 16 08:29:04 PDT 2024
https://github.com/igogo-x86 created https://github.com/llvm/llvm-project/pull/104606
Check that `binop(zext(value), other)` is possible and profitable to transform into `zext(binop(value, trunc(other)))`.
When a CPU architecture has an illegal scalar type iX but the vector type <N x iX> is legal, scalar expressions may be extended to a legal type iY before vectorisation. This extension can leave vector lanes underutilised, since more lanes could be processed per instruction with the narrower type. Vectorisers may not always recognise such type-shrinking opportunities, and this patch aims to address that limitation.
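As a minimal sketch (the function and value names below are hypothetical, and whether the rewrite actually fires depends on the target's cost model), the transform turns IR like

define <16 x i32> @before(<16 x i8> %x, <16 x i32> %y) {
  %zx = zext <16 x i8> %x to <16 x i32>    ; widen to the legal scalar type
  %r = and <16 x i32> %zx, %y
  ret <16 x i32> %r
}

into

define <16 x i32> @after(<16 x i8> %x, <16 x i32> %y) {
  %ty = trunc <16 x i32> %y to <16 x i8>   ; truncate the non-extended operand
  %r8 = and <16 x i8> %x, %ty              ; do the binop in the narrow type
  %zr = zext <16 x i8> %r8 to <16 x i32>   ; extend the result afterwards
  ret <16 x i32> %zr
}

so that, on a 128-bit SIMD target such as AArch64 NEON, the `and` operates on a single <16 x i8> register instead of four <16 x i32> registers.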
From 0afa95ef8d2e2fcf06229e732b1b600c6ae6b219 Mon Sep 17 00:00:00 2001
From: Igor Kirillov <igor.kirillov at arm.com>
Date: Fri, 16 Aug 2024 15:06:00 +0000
Subject: [PATCH] [VectorCombine] Add type shrinking and zext propagation for
fixed-width vector types
Check that binop(zext(value), other) is possible and profitable to transform
into: zext(binop(value, trunc(other))).
When a CPU architecture has an illegal scalar type iX but the vector type
<N x iX> is legal, scalar expressions may be extended to a legal type iY
before vectorisation. This extension can leave vector lanes underutilised,
since more lanes could be processed per instruction with the narrower type.
Vectorisers may not always recognise such type-shrinking opportunities, and
this patch aims to address that limitation.
---
.../Transforms/Vectorize/VectorCombine.cpp | 104 ++++++++++++++++++
.../VectorCombine/AArch64/shrink-types.ll | 76 +++++++++++++
2 files changed, 180 insertions(+)
create mode 100644 llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 99bd383ab0dead..c2f4315928a40c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
bool foldShuffleFromReductions(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+ bool shrinkType(Instruction &I);
void replaceValue(Value &Old, Value &New) {
Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,106 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
return true;
}
+/// Check whether the instruction depends on a ZExt that can be moved past the
+/// instruction, and move the ZExt if it is profitable to do so.
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+ Value *ZExted, *OtherOperand;
+ if (match(&I, m_c_BinOp(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand)))) {
+ if (I.getOpcode() != Instruction::And && I.getOpcode() != Instruction::Or &&
+ I.getOpcode() != Instruction::Xor && I.getOpcode() != Instruction::LShr)
+ return false;
+
+ // For LShr, the ZExt must be the shifted value (the first operand), not the
+ // shift amount
+ if (I.getOpcode() == Instruction::LShr && I.getOperand(1) != OtherOperand)
+ return false;
+
+ Instruction *ZExtOperand = cast<Instruction>(
+ I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+ auto *BigTy = cast<FixedVectorType>(I.getType());
+ auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+ auto BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+ // Check that the expression as a whole uses at most the same number of bits
+ // as the ZExted value
+ auto KB = computeKnownBits(&I, *DL);
+ auto IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+ if (IBW > BW)
+ return false;
+
+ bool HasUNZExtableUser = false;
+
+ // Compare the cost of leaving the current IR as it is with the cost of moving
+ // the ZExt later and adding truncates where needed
+ InstructionCost ZExtCost = TTI.getCastInstrCost(
+ Instruction::ZExt, BigTy, SmallTy,
+ TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+ InstructionCost CurrentCost = ZExtCost;
+ InstructionCost ShrinkCost = 0;
+
+ for (User *U : ZExtOperand->users()) {
+ auto *UI = cast<Instruction>(U);
+ if (UI == &I) {
+ CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+ ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+ ShrinkCost += ZExtCost;
+ continue;
+ }
+
+ if (!Instruction::isBinaryOp(UI->getOpcode())) {
+ HasUNZExtableUser = true;
+ continue;
+ }
+
+ // Check if we can propagate ZExt through its other users
+ auto KB = computeKnownBits(UI, *DL);
+ auto UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+ if (UBW <= BW) {
+ CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+ ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+ ShrinkCost += ZExtCost;
+ } else {
+ HasUNZExtableUser = true;
+ }
+ }
+
+ // The ZExt can't be removed, so account for its extra cost
+ if (HasUNZExtableUser)
+ ShrinkCost += ZExtCost;
+
+ // If the other operand is not a constant, we'll need to generate a truncate
+ // instruction, so we have to adjust the cost
+ if (!isa<Constant>(OtherOperand))
+ ShrinkCost += TTI.getCastInstrCost(
+ Instruction::Trunc, SmallTy, BigTy,
+ TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+
+ // If the cost of shrinking the types is the same as the cost of leaving the
+ // IR as it is, lean towards modifying the IR, because shrinking opens up
+ // opportunities for further shrinking optimisations.
+ if (ShrinkCost > CurrentCost)
+ return false;
+
+ auto *Op0 = ZExted;
+ if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+ Builder.SetInsertPoint(OI->getNextNode());
+ auto *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+ Builder.SetInsertPoint(&I);
+ // Keep the order of operands the same
+ if (I.getOperand(0) == OtherOperand)
+ std::swap(Op0, Op1);
+ auto *NewBinOp =
+ Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+ cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+ cast<Instruction>(NewBinOp)->copyMetadata(I);
+ auto *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+ replaceValue(I, *NewZExtr);
+ return true;
+ }
+ return false;
+}
+
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
@@ -2560,6 +2661,9 @@ bool VectorCombine::run() {
case Instruction::BitCast:
MadeChange |= foldBitcastShuffle(I);
break;
+ default:
+ MadeChange |= shrinkType(I);
+ break;
}
} else {
switch (Opcode) {
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
new file mode 100644
index 00000000000000..0166656cf734f5
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT: ret i32 [[TMP3]]
+;
+entry:
+ %wide.load = load <16 x i8>, ptr %b, align 1
+ %0 = zext <16 x i8> %wide.load to <16 x i32>
+ %1 = and <16 x i32> %0, %a
+ %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+ ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT: ret i32 [[TMP3]]
+;
+entry:
+ %wide.load = load <16 x i8>, ptr %b, align 1
+ %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+ %0 = zext <16 x i8> %wide.load to <16 x i32>
+ %1 = or <16 x i32> %0, %a.masked
+ %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+ ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT: [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT: ret i32 [[TMP9]]
+;
+entry:
+ %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+ %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+ %wide.load = load <16 x i8>, ptr %b, align 1
+ %0 = zext <16 x i8> %wide.load to <16 x i32>
+ %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+ %2 = or <16 x i32> %1, %v.masked
+ %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+ %4 = or <16 x i32> %3, %u.masked
+ %5 = add nuw nsw <16 x i32> %2, %4
+ %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+ ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)