[llvm] [vectorcombine] Pull sext/zext through reduce.or/and/xor (PR #99548)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 18 12:04:12 PDT 2024
https://github.com/preames updated https://github.com/llvm/llvm-project/pull/99548
>From 044926bef56b6999082f09295b295a0a2fb7f8fa Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 18 Jul 2024 10:49:45 -0700
Subject: [PATCH 1/2] [vectorcombine] Pull sext/zext through reduce.or/and/xor
This extends the existing foldTruncFromReductions transform to handle
sext and zext as well. This is only legal for the bitwise reductions
(and/or/xor) and not the arithmetic ones (add, mul). Use the same
costing decision to drive whether we do the transform.
---
.../Transforms/Vectorize/VectorCombine.cpp | 40 ++++++++++++-------
.../VectorCombine/RISCV/vecreduce-of-cast.ll | 30 ++++++++++----
2 files changed, 48 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 3a49f95d3f117..de60d80aeffa1 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -117,7 +117,7 @@ class VectorCombine {
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
- bool foldTruncFromReductions(Instruction &I);
+ bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
void replaceValue(Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
/// Determine if it's more efficient to fold:
/// reduce(trunc(x)) -> trunc(reduce(x)).
-bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+/// reduce(sext(x)) -> sext(reduce(x)).
+/// reduce(zext(x)) -> zext(reduce(x)).
+bool VectorCombine::foldCastFromReductions(Instruction &I) {
auto *II = dyn_cast<IntrinsicInst>(&I);
if (!II)
return false;
+ bool TruncOnly = false;
Intrinsic::ID IID = II->getIntrinsicID();
switch (IID) {
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_mul:
+ TruncOnly = true;
+ break;
case Intrinsic::vector_reduce_and:
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_xor:
@@ -2133,25 +2138,32 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
Value *ReductionSrc = I.getOperand(0);
- Value *TruncSrc;
- if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+ Value *Src;
+ if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
+ (TruncOnly ||
+ !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
return false;
- auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+ // Note: Only trunc has a constexpr, neither sext nor zext do.
+ auto CastOpc = Instruction::Trunc;
+ if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
+ CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
+
+ auto *SrcTy = cast<VectorType>(Src->getType());
auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
Type *ResultTy = I.getType();
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost OldCost = TTI.getArithmeticReductionCost(
ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
- if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
+ if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
OldCost +=
- TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
- TTI::CastContextHint::None, CostKind, Trunc);
+ TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+ TTI::CastContextHint::None, CostKind, Cast);
InstructionCost NewCost =
- TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+ TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
CostKind) +
- TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+ TTI.getCastInstrCost(CastOpc, ResultTy,
ReductionSrcTy->getScalarType(),
TTI::CastContextHint::None, CostKind);
@@ -2159,9 +2171,9 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
return false;
Value *NewReduction = Builder.CreateIntrinsic(
- TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
- Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
- replaceValue(I, *NewTruncation);
+ SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+ Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
+ replaceValue(I, *NewCast);
return true;
}
@@ -2559,7 +2571,7 @@ bool VectorCombine::run() {
switch (Opcode) {
case Instruction::Call:
MadeChange |= foldShuffleFromReductions(I);
- MadeChange |= foldTruncFromReductions(I);
+ MadeChange |= foldCastFromReductions(I);
break;
case Instruction::ICmp:
case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
index 9b1aa19f85c21..f04bcc90e5c35 100644
--- a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
@@ -74,8 +74,8 @@ define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0) {
define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0) {
; CHECK-LABEL: @reduce_or_sext_v8i8_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = sext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = sext i8 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = sext <8 x i8> %a0 to <8 x i32>
@@ -85,8 +85,8 @@ define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0) {
define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0) {
; CHECK-LABEL: @reduce_or_sext_v8i16_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = sext i16 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = sext <8 x i16> %a0 to <8 x i32>
@@ -96,8 +96,8 @@ define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0) {
define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0) {
; CHECK-LABEL: @reduce_or_zext_v8i8_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = zext i8 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = zext <8 x i8> %a0 to <8 x i32>
@@ -107,8 +107,8 @@ define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0) {
define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0) {
; CHECK-LABEL: @reduce_or_zext_v8i16_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = zext i16 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = zext <8 x i16> %a0 to <8 x i32>
@@ -116,6 +116,20 @@ define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0) {
ret i32 %red
}
+; Negative case - narrowing the reduce (to i8) is illegal.
+; TODO: We could narrow to i16 instead.
+define i32 @reduce_add_trunc_v8i8_to_v8i32(<8 x i8> %a0) {
+; CHECK-LABEL: @reduce_add_trunc_v8i8_to_v8i32(
+; CHECK-NEXT: [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %tr = zext <8 x i8> %a0 to <8 x i32>
+ %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+ ret i32 %red
+}
+
+
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)
>From 8b372b677210959e4f9e5759519b759073101646 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 18 Jul 2024 12:03:27 -0700
Subject: [PATCH 2/2] clang-format
---
llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 17 +++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index de60d80aeffa1..854bde1491a45 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2140,14 +2140,13 @@ bool VectorCombine::foldCastFromReductions(Instruction &I) {
Value *Src;
if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
- (TruncOnly ||
- !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
+ (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
return false;
// Note: Only trunc has a constexpr, neither sext nor zext do.
auto CastOpc = Instruction::Trunc;
if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
- CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
+ CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
auto *SrcTy = cast<VectorType>(Src->getType());
auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
@@ -2157,21 +2156,19 @@ bool VectorCombine::foldCastFromReductions(Instruction &I) {
InstructionCost OldCost = TTI.getArithmeticReductionCost(
ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
- OldCost +=
- TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
- TTI::CastContextHint::None, CostKind, Cast);
+ OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+ TTI::CastContextHint::None, CostKind, Cast);
InstructionCost NewCost =
TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
CostKind) +
- TTI.getCastInstrCost(CastOpc, ResultTy,
- ReductionSrcTy->getScalarType(),
+ TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
TTI::CastContextHint::None, CostKind);
if (OldCost <= NewCost || !NewCost.isValid())
return false;
- Value *NewReduction = Builder.CreateIntrinsic(
- SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+ Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
+ II->getIntrinsicID(), {Src});
Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
replaceValue(I, *NewCast);
return true;
More information about the llvm-commits
mailing list