[llvm] [VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x)) iff cost effective (PR #81852)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 19 03:16:58 PST 2024
RKSimon (https://github.com/RKSimon) updated https://github.com/llvm/llvm-project/pull/81852
From 7de4f825c1e91c6f87d48b4fea51da93598816b8 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 15 Feb 2024 12:43:21 +0000
Subject: [PATCH 1/2] [VectorCombine] Add test coverage for reduce(trunc(X)) ->
trunc(reduce(X)) folds
---
.../X86/reduction-of-truncations.ll | 121 ++++++++++++++++++
1 file changed, 121 insertions(+)
create mode 100644 llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
diff --git a/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
new file mode 100644
index 00000000000000..fdb0f5ce6f89f6
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v2 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v3 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s --check-prefixes=CHECK,AVX512
+
+;
+; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
+;
+
+; TODO: Cheap AVX512 v8i64 -> v8i32 truncation
+define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0) {
+; CHECK-LABEL: @reduce_add_trunc_v8i64_i32(
+; CHECK-NEXT: [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %tr = trunc <8 x i64> %a0 to <8 x i32>
+ %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+ ret i32 %red
+}
+declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
+
+; TODO: No legal vXi8 multiplication so vXi16 is always cheaper
+define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0) {
+; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
+; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT: ret i8 [[RED]]
+;
+ %tr = trunc <16 x i16> %a0 to <16 x i8>
+ %red = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> %tr)
+ ret i8 %red
+}
+declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)
+
+define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0) {
+; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
+; CHECK-NEXT: [[TR:%.*]] = trunc <8 x i32> [[A0:%.*]] to <8 x i8>
+; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[TR]])
+; CHECK-NEXT: ret i8 [[RED]]
+;
+ %tr = trunc <8 x i32> %a0 to <8 x i8>
+ %red = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)
+ ret i8 %red
+}
+declare i8 @llvm.vector.reduce.or.v8i8(<8 x i8>)
+
+define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0) {
+; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i8>
+; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT: ret i8 [[RED]]
+;
+ %tr = trunc <16 x i64> %a0 to <16 x i8>
+ %red = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> %tr)
+ ret i8 %red
+}
+declare i8 @llvm.vector.reduce.xor.v16i8(<16 x i8>)
+
+; Truncation source has other uses - OK to truncate reduction
+define i16 @reduce_and_trunc_v16i64_i16(<16 x i64> %a0) {
+; CHECK-LABEL: @reduce_and_trunc_v16i64_i16(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i16>
+; CHECK-NEXT: [[RED:%.*]] = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT: call void @use_v16i64(<16 x i64> [[A0]])
+; CHECK-NEXT: ret i16 [[RED]]
+;
+ %tr = trunc <16 x i64> %a0 to <16 x i16>
+ %red = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> %tr)
+ call void @use_v16i64(<16 x i64> %a0)
+ ret i16 %red
+}
+declare i16 @llvm.vector.reduce.and.v16i16(<16 x i16>)
+
+; Negative Test: vXi16 multiply is much cheaper than vXi64
+define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0) {
+; CHECK-LABEL: @reduce_mul_trunc_v8i64_i16(
+; CHECK-NEXT: [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i16>
+; CHECK-NEXT: [[RED:%.*]] = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> [[TR]])
+; CHECK-NEXT: ret i16 [[RED]]
+;
+ %tr = trunc <8 x i64> %a0 to <8 x i16>
+ %red = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> %tr)
+ ret i16 %red
+}
+declare i16 @llvm.vector.reduce.mul.v8i16(<8 x i16>)
+
+; Negative Test: min/max reductions can't use pre-truncated types.
+define i8 @reduce_smin_trunc_v16i16_i8(<16 x i16> %a0) {
+; CHECK-LABEL: @reduce_smin_trunc_v16i16_i8(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
+; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT: ret i8 [[RED]]
+;
+ %tr = trunc <16 x i16> %a0 to <16 x i8>
+ %red = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> %tr)
+ ret i8 %red
+}
+declare i8 @llvm.vector.reduce.smin.v16i8(<16 x i8>)
+
+; Negative Test: Truncation has other uses.
+define i16 @reduce_and_trunc_v16i64_i16_multiuse(<16 x i64> %a0) {
+; CHECK-LABEL: @reduce_and_trunc_v16i64_i16_multiuse(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i16>
+; CHECK-NEXT: [[RED:%.*]] = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT: call void @use_v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT: ret i16 [[RED]]
+;
+ %tr = trunc <16 x i64> %a0 to <16 x i16>
+ %red = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> %tr)
+ call void @use_v16i16(<16 x i16> %tr)
+ ret i16 %red
+}
+
+declare void @use_v16i64(<16 x i64>)
+declare void @use_v16i16(<16 x i16>)
+
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; AVX512: {{.*}}
+; X64: {{.*}}
From d30236d70b6f7c0c22e0d8eff748ca53208c9e28 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 16 Feb 2024 12:11:03 +0000
Subject: [PATCH 2/2] [VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x))
iff cost effective
Vector truncations can be pretty expensive, especially on X86, whilst scalar truncations are often free.
If performing the add/mul/and/or/xor reduction on the pre-truncated type is cheap enough, then avoid the vector truncation entirely.
Fixes #81469
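A minimal before/after sketch of the rewrite, taken from the reduce_or_trunc_v8i32_i8 test in this patch (SSA names abbreviated for illustration; the pass only performs the rewrite when the TTI cost comparison favours the wide reduction):

  ; Before: vector truncate, then reduce the narrow <8 x i8> vector.
  %tr = trunc <8 x i32> %a0 to <8 x i8>
  %red = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)

  ; After: reduce the original <8 x i32> vector, then truncate the scalar result.
  %wide = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
  %red = trunc i32 %wide to i8

The final scalar trunc is typically free, so the fold is profitable whenever the reduction itself is no more expensive on the wider element type.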
---
.../Transforms/Vectorize/VectorCombine.cpp | 57 +++++++++++++++++++
.../X86/reduction-of-truncations.ll | 36 ++++++------
2 files changed, 76 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index f18711ba30b708..dc669e314a0d9a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -29,6 +29,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
#include <numeric>
#include <queue>
@@ -111,6 +112,7 @@ class VectorCombine {
bool scalarizeLoadExtract(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
+ bool foldTruncFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
void replaceValue(Value &Old, Value &New) {
@@ -1526,6 +1528,60 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
return foldSelectShuffle(*Shuffle, true);
}
+/// Determine if it's more efficient to fold:
+/// reduce(trunc(x)) -> trunc(reduce(x)).
+bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+ auto *II = dyn_cast<IntrinsicInst>(&I);
+ if (!II)
+ return false;
+
+ Intrinsic::ID IID = II->getIntrinsicID();
+ switch (IID) {
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ break;
+ default:
+ return false;
+ }
+
+ unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
+ Value *ReductionSrc = I.getOperand(0);
+
+ Value *TruncSrc;
+ if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+ return false;
+
+ auto *Trunc = cast<CastInst>(ReductionSrc);
+ auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+ auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
+ Type *ResultTy = I.getType();
+
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
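+  // Compare the cost of the current pattern (vector trunc feeding a reduction
+  // on the narrow type) against the folded form (reduction on the wide
+  // pre-truncated type followed by a scalar trunc of the result).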
+ InstructionCost OldCost =
+ TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
+ TTI::CastContextHint::None, CostKind, Trunc) +
+ TTI.getArithmeticReductionCost(ReductionOpc, ReductionSrcTy, std::nullopt,
+ CostKind);
+ InstructionCost NewCost =
+ TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+ CostKind) +
+ TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+                           TruncSrcTy->getScalarType(),
+ TTI::CastContextHint::None, CostKind);
+
+ if (OldCost <= NewCost || !NewCost.isValid())
+ return false;
+
+ Value *NewReduction = Builder.CreateIntrinsic(
+ TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
+ Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
+ replaceValue(I, *NewTruncation);
+ return true;
+}
+
/// This method looks for groups of shuffles acting on binops, of the form:
/// %x = shuffle ...
/// %y = shuffle ...
@@ -1917,6 +1973,7 @@ bool VectorCombine::run() {
switch (Opcode) {
case Instruction::Call:
MadeChange |= foldShuffleFromReductions(I);
+ MadeChange |= foldTruncFromReductions(I);
break;
case Instruction::ICmp:
case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
index fdb0f5ce6f89f6..d5dd4cf0e34f92 100644
--- a/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
@@ -8,12 +8,17 @@
; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
;
-; TODO: Cheap AVX512 v8i64 -> v8i32 truncation
+; Cheap AVX512 v8i64 -> v8i32 truncation
define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0) {
-; CHECK-LABEL: @reduce_add_trunc_v8i64_i32(
-; CHECK-NEXT: [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
-; CHECK-NEXT: ret i32 [[RED]]
+; X64-LABEL: @reduce_add_trunc_v8i64_i32(
+; X64-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[A0:%.*]])
+; X64-NEXT: [[RED:%.*]] = trunc i64 [[TMP1]] to i32
+; X64-NEXT: ret i32 [[RED]]
+;
+; AVX512-LABEL: @reduce_add_trunc_v8i64_i32(
+; AVX512-NEXT: [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; AVX512-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; AVX512-NEXT: ret i32 [[RED]]
;
%tr = trunc <8 x i64> %a0 to <8 x i32>
%red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
@@ -21,11 +26,11 @@ define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0) {
}
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
-; TODO: No legal vXi8 multiplication so vXi16 is always cheaper
+; No legal vXi8 multiplication so vXi16 is always cheaper
define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0) {
; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
-; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
-; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i16 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[RED]]
;
%tr = trunc <16 x i16> %a0 to <16 x i8>
@@ -36,8 +41,8 @@ declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)
define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0) {
; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
-; CHECK-NEXT: [[TR:%.*]] = trunc <8 x i32> [[A0:%.*]] to <8 x i8>
-; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i32 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[RED]]
;
%tr = trunc <8 x i32> %a0 to <8 x i8>
@@ -48,8 +53,8 @@ declare i32 @llvm.vector.reduce.or.v8i8(<8 x i8>)
define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0) {
; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
-; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i8>
-; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.xor.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i64 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[RED]]
;
%tr = trunc <16 x i64> %a0 to <16 x i8>
@@ -61,8 +66,8 @@ declare i8 @llvm.vector.reduce.xor.v16i8(<16 x i8>)
; Truncation source has other uses - OK to truncate reduction
define i16 @reduce_and_trunc_v16i64_i16(<16 x i64> %a0) {
; CHECK-LABEL: @reduce_and_trunc_v16i64_i16(
-; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i16>
-; CHECK-NEXT: [[RED:%.*]] = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.and.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i64 [[TMP1]] to i16
; CHECK-NEXT: call void @use_v16i64(<16 x i64> [[A0]])
; CHECK-NEXT: ret i16 [[RED]]
;
@@ -116,6 +121,3 @@ define i16 @reduce_and_trunc_v16i64_i16_multiuse(<16 x i64> %a0) {
declare void @use_v16i64(<16 x i64>)
declare void @use_v16i16(<16 x i16>)
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; AVX512: {{.*}}
-; X64: {{.*}}