[llvm] 2ed2a3a - [Transforms][Utils] Add helpers to map between Reduction IntrinsicID and Arithmetic Instruction Opcode and MinMax IntrinsicID / RecurKind
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 16 03:20:47 PST 2024
Author: Simon Pilgrim
Date: 2024-02-16T11:20:34Z
New Revision: 2ed2a3ad90934efac12cbeb01cf73afebc01d963
URL: https://github.com/llvm/llvm-project/commit/2ed2a3ad90934efac12cbeb01cf73afebc01d963
DIFF: https://github.com/llvm/llvm-project/commit/2ed2a3ad90934efac12cbeb01cf73afebc01d963.diff
LOG: [Transforms][Utils] Add helpers to map between Reduction IntrinsicID and Arithmetic Instruction Opcode and MinMax IntrinsicID / RecurKind
Noticed on #81852
Added:
Modified:
llvm/include/llvm/CodeGen/BasicTTIImpl.h
llvm/include/llvm/Transforms/Utils/LoopUtils.h
llvm/lib/CodeGen/ExpandReductions.cpp
llvm/lib/Transforms/Utils/LoopUtils.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bb17298daba03a..61f6564e8cd79b 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -48,6 +48,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -2013,50 +2014,27 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
CostKind);
}
case Intrinsic::vector_reduce_add:
- return thisT()->getArithmeticReductionCost(Instruction::Add, VecOpTy,
- std::nullopt, CostKind);
case Intrinsic::vector_reduce_mul:
- return thisT()->getArithmeticReductionCost(Instruction::Mul, VecOpTy,
- std::nullopt, CostKind);
case Intrinsic::vector_reduce_and:
- return thisT()->getArithmeticReductionCost(Instruction::And, VecOpTy,
- std::nullopt, CostKind);
case Intrinsic::vector_reduce_or:
- return thisT()->getArithmeticReductionCost(Instruction::Or, VecOpTy,
- std::nullopt, CostKind);
case Intrinsic::vector_reduce_xor:
- return thisT()->getArithmeticReductionCost(Instruction::Xor, VecOpTy,
- std::nullopt, CostKind);
+ return thisT()->getArithmeticReductionCost(
+ getArithmeticReductionInstruction(IID), VecOpTy, std::nullopt,
+ CostKind);
case Intrinsic::vector_reduce_fadd:
- return thisT()->getArithmeticReductionCost(Instruction::FAdd, VecOpTy,
- FMF, CostKind);
case Intrinsic::vector_reduce_fmul:
- return thisT()->getArithmeticReductionCost(Instruction::FMul, VecOpTy,
- FMF, CostKind);
+ return thisT()->getArithmeticReductionCost(
+ getArithmeticReductionInstruction(IID), VecOpTy, FMF, CostKind);
case Intrinsic::vector_reduce_smax:
- return thisT()->getMinMaxReductionCost(Intrinsic::smax, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_smin:
- return thisT()->getMinMaxReductionCost(Intrinsic::smin, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_umax:
- return thisT()->getMinMaxReductionCost(Intrinsic::umax, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_umin:
- return thisT()->getMinMaxReductionCost(Intrinsic::umin, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_fmax:
- return thisT()->getMinMaxReductionCost(Intrinsic::maxnum, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_fmin:
- return thisT()->getMinMaxReductionCost(Intrinsic::minnum, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_fmaximum:
- return thisT()->getMinMaxReductionCost(Intrinsic::maximum, VecOpTy,
- ICA.getFlags(), CostKind);
case Intrinsic::vector_reduce_fminimum:
- return thisT()->getMinMaxReductionCost(Intrinsic::minimum, VecOpTy,
- ICA.getFlags(), CostKind);
+ return thisT()->getMinMaxReductionCost(getMinMaxReductionIntrinsicOp(IID),
+ VecOpTy, ICA.getFlags(), CostKind);
case Intrinsic::abs: {
// abs(X) = select(icmp(X,0),X,sub(0,X))
Type *CondTy = RetTy->getWithNewBitWidth(1);
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 5a1385d01d8e44..187ace3a0cbedf 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -357,9 +357,18 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
SinkAndHoistLICMFlags &LICMFlags,
OptimizationRemarkEmitter *ORE = nullptr);
+/// Returns the arithmetic opcode for a reduction (ICmp/FCmp for min/max).
+unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
+
+/// Returns the min/max intrinsic for a vector_reduce_* min/max intrinsic ID.
+Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID);
+
/// Returns the min/max intrinsic used when expanding a min/max reduction.
Intrinsic::ID getMinMaxReductionIntrinsicOp(RecurKind RK);
+/// Returns the recurrence kind of a min/max reduction, or RecurKind::None.
+RecurKind getMinMaxReductionRecurKind(Intrinsic::ID RdxID);
+
/// Returns the comparison predicate used when expanding a min/max reduction.
CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
diff --git a/llvm/lib/CodeGen/ExpandReductions.cpp b/llvm/lib/CodeGen/ExpandReductions.cpp
index 79b6dc9154b3fc..0b1504e51b1bbb 100644
--- a/llvm/lib/CodeGen/ExpandReductions.cpp
+++ b/llvm/lib/CodeGen/ExpandReductions.cpp
@@ -26,54 +26,6 @@ using namespace llvm;
namespace {
-unsigned getOpcode(Intrinsic::ID ID) {
- switch (ID) {
- case Intrinsic::vector_reduce_fadd:
- return Instruction::FAdd;
- case Intrinsic::vector_reduce_fmul:
- return Instruction::FMul;
- case Intrinsic::vector_reduce_add:
- return Instruction::Add;
- case Intrinsic::vector_reduce_mul:
- return Instruction::Mul;
- case Intrinsic::vector_reduce_and:
- return Instruction::And;
- case Intrinsic::vector_reduce_or:
- return Instruction::Or;
- case Intrinsic::vector_reduce_xor:
- return Instruction::Xor;
- case Intrinsic::vector_reduce_smax:
- case Intrinsic::vector_reduce_smin:
- case Intrinsic::vector_reduce_umax:
- case Intrinsic::vector_reduce_umin:
- return Instruction::ICmp;
- case Intrinsic::vector_reduce_fmax:
- case Intrinsic::vector_reduce_fmin:
- return Instruction::FCmp;
- default:
- llvm_unreachable("Unexpected ID");
- }
-}
-
-RecurKind getRK(Intrinsic::ID ID) {
- switch (ID) {
- case Intrinsic::vector_reduce_smax:
- return RecurKind::SMax;
- case Intrinsic::vector_reduce_smin:
- return RecurKind::SMin;
- case Intrinsic::vector_reduce_umax:
- return RecurKind::UMax;
- case Intrinsic::vector_reduce_umin:
- return RecurKind::UMin;
- case Intrinsic::vector_reduce_fmax:
- return RecurKind::FMax;
- case Intrinsic::vector_reduce_fmin:
- return RecurKind::FMin;
- default:
- return RecurKind::None;
- }
-}
-
bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
bool Changed = false;
SmallVector<IntrinsicInst *, 4> Worklist;
@@ -106,7 +58,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
FastMathFlags FMF =
isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
Intrinsic::ID ID = II->getIntrinsicID();
- RecurKind RK = getRK(ID);
+ RecurKind RK = getMinMaxReductionRecurKind(ID);
Value *Rdx = nullptr;
IRBuilder<> Builder(II);
@@ -120,16 +72,16 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
// and it can't be handled by generating a shuffle sequence.
Value *Acc = II->getArgOperand(0);
Value *Vec = II->getArgOperand(1);
+ unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
if (!FMF.allowReassoc())
- Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
+ Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
else {
if (!isPowerOf2_32(
cast<FixedVectorType>(Vec->getType())->getNumElements()))
continue;
-
- Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
- Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
- Acc, Rdx, "bin.rdx");
+ Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
+ Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
+ "bin.rdx");
}
break;
}
@@ -159,8 +111,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
}
break;
}
-
- Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
+ unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+ Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
break;
}
case Intrinsic::vector_reduce_add:
@@ -174,8 +126,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
if (!isPowerOf2_32(
cast<FixedVectorType>(Vec->getType())->getNumElements()))
continue;
-
- Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
+ unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+ Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
break;
}
case Intrinsic::vector_reduce_fmax:
@@ -187,8 +139,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
!FMF.noNaNs())
continue;
-
- Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
+ unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+ Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
break;
}
}
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 59485126b280ab..002bc90c9b5677 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -917,6 +917,65 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
return true;
}
+/// Maps a vector_reduce_* intrinsic to the opcode used to expand it.
+/// Min/max reductions have no arithmetic opcode of their own and report the
+/// compare opcode (ICmp/FCmp) instead; any other ID is unreachable.
+unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
+ switch (RdxID) {
+ case Intrinsic::vector_reduce_fadd:
+ return Instruction::FAdd;
+ case Intrinsic::vector_reduce_fmul:
+ return Instruction::FMul;
+ case Intrinsic::vector_reduce_add:
+ return Instruction::Add;
+ case Intrinsic::vector_reduce_mul:
+ return Instruction::Mul;
+ case Intrinsic::vector_reduce_and:
+ return Instruction::And;
+ case Intrinsic::vector_reduce_or:
+ return Instruction::Or;
+ case Intrinsic::vector_reduce_xor:
+ return Instruction::Xor;
+ case Intrinsic::vector_reduce_smax:
+ case Intrinsic::vector_reduce_smin:
+ case Intrinsic::vector_reduce_umax:
+ case Intrinsic::vector_reduce_umin:
+ return Instruction::ICmp;
+ case Intrinsic::vector_reduce_fmax:
+ case Intrinsic::vector_reduce_fmin:
+ case Intrinsic::vector_reduce_fmaximum:
+ case Intrinsic::vector_reduce_fminimum:
+ return Instruction::FCmp;
+ default:
+ llvm_unreachable("Unexpected ID");
+ }
+}
+
+/// Maps a vector_reduce_* min/max intrinsic to the matching binary min/max
+/// intrinsic (e.g. vector_reduce_umin -> umin).
+Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
+ switch (RdxID) {
+ default:
+ llvm_unreachable("Unexpected min/max reduction intrinsic");
+ case Intrinsic::vector_reduce_umin:
+ return Intrinsic::umin;
+ case Intrinsic::vector_reduce_umax:
+ return Intrinsic::umax;
+ case Intrinsic::vector_reduce_smin:
+ return Intrinsic::smin;
+ case Intrinsic::vector_reduce_smax:
+ return Intrinsic::smax;
+ case Intrinsic::vector_reduce_fmin:
+ return Intrinsic::minnum;
+ case Intrinsic::vector_reduce_fmax:
+ return Intrinsic::maxnum;
+ case Intrinsic::vector_reduce_fminimum:
+ return Intrinsic::minimum;
+ case Intrinsic::vector_reduce_fmaximum:
+ return Intrinsic::maximum;
+ }
+}
+
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
switch (RK) {
default:
@@ -940,6 +999,31 @@ Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
}
}
+/// Maps a vector_reduce_* min/max intrinsic to its RecurKind, or returns
+/// RecurKind::None for any other intrinsic.
+RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) {
+ switch (RdxID) {
+ case Intrinsic::vector_reduce_smax:
+ return RecurKind::SMax;
+ case Intrinsic::vector_reduce_smin:
+ return RecurKind::SMin;
+ case Intrinsic::vector_reduce_umax:
+ return RecurKind::UMax;
+ case Intrinsic::vector_reduce_umin:
+ return RecurKind::UMin;
+ case Intrinsic::vector_reduce_fmax:
+ return RecurKind::FMax;
+ case Intrinsic::vector_reduce_fmin:
+ return RecurKind::FMin;
+ case Intrinsic::vector_reduce_fmaximum:
+ return RecurKind::FMaximum;
+ case Intrinsic::vector_reduce_fminimum:
+ return RecurKind::FMinimum;
+ default:
+ return RecurKind::None;
+ }
+}
+
CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
switch (RK) {
default:
More information about the llvm-commits
mailing list