[llvm] [AArch64][CostModel] Alter sdiv/srem cost where the divisor is constant (PR #123552)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 1 07:45:35 PST 2025
================
@@ -3526,23 +3526,103 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
default:
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
Op2Info);
+ case ISD::SREM:
case ISD::SDIV:
- if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
- // On AArch64, scalar signed division by constants power-of-two are
- // normally expanded to the sequence ADD + CMP + SELECT + SRA.
- // The OperandValue properties many not be same as that of previous
- // operation; conservatively assume OP_None.
- InstructionCost Cost = getArithmeticInstrCost(
- Instruction::Add, Ty, CostKind,
- Op1Info.getNoProps(), Op2Info.getNoProps());
- Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
- Op1Info.getNoProps(), Op2Info.getNoProps());
- Cost += getArithmeticInstrCost(
- Instruction::Select, Ty, CostKind,
- Op1Info.getNoProps(), Op2Info.getNoProps());
- Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
- Op1Info.getNoProps(), Op2Info.getNoProps());
- return Cost;
+ /*
+ Notes for sdiv/srem specific costs:
+ 1. This only considers the cases where the divisor is constant, uniform and
+ (pow-of-2/non-pow-of-2). Other cases are not important since they either
+ result in some form of (ldr + adrp), corresponding to constant vectors, or
+ scalarization of the division operation.
+ 2. Constant divisors, either negative in whole or partially, don't result in
+ significantly different codegen as compared to positive constant divisors.
+ So, we don't consider negative divisors separately.
+ 3. If the codegen is significantly different with SVE, it has been indicated
+ using comments at appropriate places.
+
+ sdiv specific cases:
+ -----------------------------------------------------------------------
+ codegen | pow-of-2 | Type
+ -----------------------------------------------------------------------
+ add + cmp + csel + asr | Y | i64
+ add + cmp + csel + asr | Y | i32
+ -----------------------------------------------------------------------
+
+ srem specific cases:
+ -----------------------------------------------------------------------
+ codegen | pow-of-2 | Type
+ -----------------------------------------------------------------------
+ negs + and + and + csneg | Y | i64
+ negs + and + and + csneg | Y | i32
+ -----------------------------------------------------------------------
+
+ other sdiv/srem cases:
+ -------------------------------------------------------------------------
+ common codegen | + srem | + sdiv | pow-of-2 | Type
+ -------------------------------------------------------------------------
+ smulh + asr + add + add | - | - | N | i64
+ smull + lsr + add + add | - | - | N | i32
+ usra | and + sub | sshr | Y | <2 x i64>
+ 2 * (scalar code) | - | - | N | <2 x i64>
+ usra | bic + sub | sshr + neg | Y | <4 x i32>
+ smull2 + smull + uzp2 | mls | - | N | <4 x i32>
+ + sshr + usra | | | |
+ -------------------------------------------------------------------------
+ */
+ if (Op2Info.isConstant() && Op2Info.isUniform()) {
+ InstructionCost AddCost =
+ getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost AsrCost =
+ getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost MulCost =
+ getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ // add/cmp/csel/csneg should have similar cost while asr/negs/and should
+ // have similar cost.
+ if (LT.second.isScalarInteger()) {
+ if (Op2Info.isPowerOf2()) {
+ return ISD == ISD::SDIV ? (3 * AddCost + AsrCost)
+ : (3 * AsrCost + AddCost);
+ } else {
+ return MulCost + AsrCost + 2 * AddCost;
+ }
+ } else {
+ InstructionCost UsraCost = 2 * AsrCost;
+ if (Op2Info.isPowerOf2()) {
+ // Division with scalable types corresponds to native 'asrd'
+ // instruction when SVE is available.
+ // e.g. %1 = sdiv <vscale x 4 x i32> %a, splat (i32 8)
+ if (Ty->isScalableTy() && ST->hasSVE())
+ return 2 * AsrCost;
+ return UsraCost +
+ (ISD == ISD::SDIV
+ ? (LT.second.getScalarType() == MVT::i64 ? 1 : 2) *
+ AsrCost
+ : 2 * AddCost);
+ } else if (LT.second.is128BitVector() &&
+ LT.second.getScalarType() == MVT::i64) {
----------------
davemgreen wrote:
I believe they should be equivalent and v2i64 is a bit shorter / simpler.
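
To illustrate why the two spellings of the check should agree, here is a minimal standalone sketch. It does not use LLVM's `MVT`. `ToyVT`, `sampleTypes`, and `checksAgree` are hypothetical names invented for this example, and the model assumes that only fixed-width vectors are considered. Under that assumption, "128 bits wide with i64 elements" can only mean two i64 lanes, which is what `v2i64` names.

```cpp
#include <vector>

// Hypothetical stand-in for a fixed-width vector type; this is NOT LLVM's MVT.
struct ToyVT {
  unsigned NumElts;  // number of lanes
  unsigned EltBits;  // bits per lane
  bool IsInteger;    // integer lanes (true) or floating-point lanes (false)

  // Total width is exactly 128 bits.
  bool is128BitVector() const { return NumElts * EltBits == 128; }
  // Lanes are 64-bit integers.
  bool hasI64Elements() const { return IsInteger && EltBits == 64; }
  // Exactly two 64-bit integer lanes.
  bool isV2I64() const { return IsInteger && NumElts == 2 && EltBits == 64; }
};

// A small assumed sample of fixed-width types, including near misses.
std::vector<ToyVT> sampleTypes() {
  return {
      {2, 64, true},   // v2i64: both forms accept
      {4, 32, true},   // v4i32: 128-bit, wrong element type
      {1, 64, true},   // v1i64: right element type, only 64 bits wide
      {4, 64, true},   // v4i64: right element type, 256 bits wide
      {2, 64, false},  // v2f64: 128-bit, floating-point lanes
      {16, 8, true},   // v16i8: 128-bit, wrong element type
  };
}

// Returns true when the long form and the short form give the same answer
// for every type in the list.
bool checksAgree(const std::vector<ToyVT> &Types) {
  for (const ToyVT &VT : Types) {
    bool LongForm = VT.is128BitVector() && VT.hasI64Elements();
    bool ShortForm = VT.isV2I64();
    if (LongForm != ShortForm)
      return false;
  }
  return true;
}
```

Calling `checksAgree(sampleTypes())` compares the two forms over the sample. In the real code the short form would presumably be a direct comparison of `LT.second` against `MVT::v2i64`.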
https://github.com/llvm/llvm-project/pull/123552