[llvm] [AArch64] Improve urem by constant costs (PR #122236)
Sushant Gokhale via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 25 21:08:09 PST 2025
================
@@ -3545,20 +3545,58 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
return Cost;
}
[[fallthrough]];
- case ISD::UDIV: {
+ case ISD::UDIV:
+ case ISD::UREM: {
auto VT = TLI->getValueType(DL, Ty);
- if (Op2Info.isConstant() && Op2Info.isUniform()) {
+ if (Op2Info.isConstant()) {
+ // If the operand is a power of 2 we can use the shift or and cost.
+ if (ISD == ISD::UDIV && Op2Info.isPowerOf2())
+ return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind,
+ Op1Info.getNoProps(),
+ Op2Info.getNoProps());
+ if (ISD == ISD::UREM && Op2Info.isPowerOf2())
+ return getArithmeticInstrCost(Instruction::And, Ty, CostKind,
+ Op1Info.getNoProps(),
+ Op2Info.getNoProps());
+
+ if (ISD == ISD::UDIV || ISD == ISD::UREM) {
+ // Divides by a constant are expanded to MULHU + SUB + SRL + ADD + SRL.
+ // The MULHU will be expanded to UMULL for the types not listed below,
+ // and will become a pair of UMULL+MULL2 for 128bit vectors.
+ bool HasMULH = VT == MVT::i64 || LT.second == MVT::nxv2i64 ||
+ LT.second == MVT::nxv4i32 || LT.second == MVT::nxv8i16 ||
+ LT.second == MVT::nxv16i8;
+ bool Is128bit = LT.second.is128BitVector();
+
+ InstructionCost MulCost =
+ getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost AddCost =
+ getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost ShrCost =
+ getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost DivCost = MulCost * (Is128bit ? 2 : 1) + // UMULL/UMULH
+ (HasMULH ? 0 : ShrCost) + // UMULL shift
+ AddCost * 2 + ShrCost;
+ return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
+ }
+
+ // TODOD: Fix SDIV and SREM costs, similar to the above.
if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT) &&
- !VT.isScalableVector()) {
+ Op2Info.isUniform()) {
// Vector signed division by constant are expanded to the
- // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
- // to MULHS + SUB + SRL + ADD + SRL.
- InstructionCost MulCost = getArithmeticInstrCost(
- Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
- InstructionCost AddCost = getArithmeticInstrCost(
- Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
- InstructionCost ShrCost = getArithmeticInstrCost(
- Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
+ // sequence MULHS + ADD/SUB + SRA + SRL + ADD.
+ InstructionCost MulCost =
+ getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost AddCost =
+ getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
+ InstructionCost ShrCost =
+ getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
+ Op1Info.getNoProps(), Op2Info.getNoProps());
return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
}
----------------
sushgokh wrote:
OK, thanks. I will remove this `if () {...}` block when I update the SDIV/SREM patch.
https://github.com/llvm/llvm-project/pull/122236
More information about the llvm-commits mailing list