[llvm] 21bca79 - [RISCV] Use switch in RISCVTargetTransformInfo::getShuffleCost [nfc]
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 13 08:41:08 PDT 2023
Author: Philip Reames
Date: 2023-03-13T08:40:47-07:00
New Revision: 21bca796d7f06006037bc57a578131751474adea
URL: https://github.com/llvm/llvm-project/commit/21bca796d7f06006037bc57a578131751474adea
DIFF: https://github.com/llvm/llvm-project/commit/21bca796d7f06006037bc57a578131751474adea.diff
LOG: [RISCV] Use switch in RISCVTargetTransformInfo::getShuffleCost [nfc]
Refactoring in advance of a semantic change.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index b68080dc4b18e..ebf80a3e13e9e 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -257,8 +257,9 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
TTI::TargetCostKind CostKind,
int Index, VectorType *SubTp,
ArrayRef<const Value *> Args) {
+ std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
+
if (isa<ScalableVectorType>(Tp)) {
- std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
switch (Kind) {
default:
// Fallthrough to generic handling.
@@ -287,71 +288,73 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
}
}
- if (isa<FixedVectorType>(Tp) && Kind == TargetTransformInfo::SK_Broadcast) {
- std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
- bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
- Instruction::InsertElement);
- if (LT.second.getScalarSizeInBits() == 1) {
+ if (isa<FixedVectorType>(Tp)) {
+ switch (Kind) {
+ default:
+ break;
+ case TargetTransformInfo::SK_Broadcast: {
+ bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
+ Instruction::InsertElement);
+ if (LT.second.getScalarSizeInBits() == 1) {
+ if (HasScalar) {
+ // Example sequence:
+ // andi a0, a0, 1
+ // vsetivli zero, 2, e8, mf8, ta, ma (ignored)
+ // vmv.v.x v8, a0
+ // vmsne.vi v0, v8, 0
+ return LT.first * getLMULCost(LT.second) * 3;
+ }
+ // Example sequence:
+ // vsetivli zero, 2, e8, mf8, ta, mu (ignored)
+ // vmv.v.i v8, 0
+ // vmerge.vim v8, v8, 1, v0
+ // vmv.x.s a0, v8
+ // andi a0, a0, 1
+ // vmv.v.x v8, a0
+ // vmsne.vi v0, v8, 0
+
+ return LT.first * getLMULCost(LT.second) * 6;
+ }
+
if (HasScalar) {
// Example sequence:
- // andi a0, a0, 1
- // vsetivli zero, 2, e8, mf8, ta, ma (ignored)
// vmv.v.x v8, a0
- // vmsne.vi v0, v8, 0
- return LT.first * getLMULCost(LT.second) * 3;
+ return LT.first * getLMULCost(LT.second);
}
- // Example sequence:
- // vsetivli zero, 2, e8, mf8, ta, mu (ignored)
- // vmv.v.i v8, 0
- // vmerge.vim v8, v8, 1, v0
- // vmv.x.s a0, v8
- // andi a0, a0, 1
- // vmv.v.x v8, a0
- // vmsne.vi v0, v8, 0
-
- return LT.first * getLMULCost(LT.second) * 6;
- }
- if (HasScalar) {
// Example sequence:
- // vmv.v.x v8, a0
+ // vrgather.vi v9, v8, 0
+ // TODO: vrgather could be slower than vmv.v.x. It is
+ // implementation-dependent.
return LT.first * getLMULCost(LT.second);
}
-
- // Example sequence:
- // vrgather.vi v9, v8, 0
- // TODO: vrgather could be slower than vmv.v.x. It is
- // implementation-dependent.
- return LT.first * getLMULCost(LT.second);
- }
-
- if (isa<FixedVectorType>(Tp) && Kind == TTI::SK_PermuteSingleSrc &&
- Mask.size() >= 2) {
- std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
- if (LT.second.isFixedLengthVector()) {
- MVT EltTp = LT.second.getVectorElementType();
- // If the size of the element is < ELEN then shuffles of interleaves and
- // deinterleaves of 2 vectors can be lowered into the following sequences
- if (EltTp.getScalarSizeInBits() < ST->getELEN()) {
- auto InterleaveMask = createInterleaveMask(Mask.size() / 2, 2);
- // Example sequence:
- // vsetivli zero, 4, e8, mf4, ta, ma (ignored)
- // vwaddu.vv v10, v8, v9
- // li a0, -1 (ignored)
- // vwmaccu.vx v10, a0, v9
- if (equal(InterleaveMask, Mask))
- return 2 * LT.first * getLMULCost(LT.second);
-
- if (Mask[0] == 0 || Mask[0] == 1) {
- auto DeinterleaveMask = createStrideMask(Mask[0], 2, Mask.size());
+ case TTI::SK_PermuteSingleSrc: {
+ if (Mask.size() >= 2 && LT.second.isFixedLengthVector()) {
+ MVT EltTp = LT.second.getVectorElementType();
+ // If the size of the element is < ELEN then shuffles of interleaves and
+ // deinterleaves of 2 vectors can be lowered into the following sequences
+ if (EltTp.getScalarSizeInBits() < ST->getELEN()) {
+ auto InterleaveMask = createInterleaveMask(Mask.size() / 2, 2);
// Example sequence:
- // vnsrl.wi v10, v8, 0
- if (equal(DeinterleaveMask, Mask))
- return LT.first * getLMULCost(LT.second);
+ // vsetivli zero, 4, e8, mf4, ta, ma (ignored)
+ // vwaddu.vv v10, v8, v9
+ // li a0, -1 (ignored)
+ // vwmaccu.vx v10, a0, v9
+ if (equal(InterleaveMask, Mask))
+ return 2 * LT.first * getLMULCost(LT.second);
+
+ if (Mask[0] == 0 || Mask[0] == 1) {
+ auto DeinterleaveMask = createStrideMask(Mask[0], 2, Mask.size());
+ // Example sequence:
+ // vnsrl.wi v10, v8, 0
+ if (equal(DeinterleaveMask, Mask))
+ return LT.first * getLMULCost(LT.second);
+ }
}
}
}
- }
+ }
+ };
return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
}
More information about the llvm-commits mailing list