[llvm] [AArch64] Improve urem by constant costs (PR #122236)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 25 10:12:52 PST 2025


================
@@ -3545,20 +3545,58 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
       return Cost;
     }
     [[fallthrough]];
-  case ISD::UDIV: {
+  case ISD::UDIV:
+  case ISD::UREM: {
     auto VT = TLI->getValueType(DL, Ty);
-    if (Op2Info.isConstant() && Op2Info.isUniform()) {
+    if (Op2Info.isConstant()) {
+      // If the operand is a power of 2 we can use the shift or and cost.
+      if (ISD == ISD::UDIV && Op2Info.isPowerOf2())
+        return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind,
+                                      Op1Info.getNoProps(),
+                                      Op2Info.getNoProps());
+      if (ISD == ISD::UREM && Op2Info.isPowerOf2())
+        return getArithmeticInstrCost(Instruction::And, Ty, CostKind,
+                                      Op1Info.getNoProps(),
+                                      Op2Info.getNoProps());
+
+      if (ISD == ISD::UDIV || ISD == ISD::UREM) {
+        // Divides by a constant are expanded to MULHU + SUB + SRL + ADD + SRL.
+        // The MULHU will be expanded to UMULL for the types not listed below,
+        // and will become a pair of UMULL+MULL2 for 128bit vectors.
+        bool HasMULH = VT == MVT::i64 || LT.second == MVT::nxv2i64 ||
+                       LT.second == MVT::nxv4i32 || LT.second == MVT::nxv8i16 ||
+                       LT.second == MVT::nxv16i8;
+        bool Is128bit = LT.second.is128BitVector();
+
+        InstructionCost MulCost =
+            getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
+                                   Op1Info.getNoProps(), Op2Info.getNoProps());
+        InstructionCost AddCost =
+            getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
+                                   Op1Info.getNoProps(), Op2Info.getNoProps());
+        InstructionCost ShrCost =
+            getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
+                                   Op1Info.getNoProps(), Op2Info.getNoProps());
+        InstructionCost DivCost = MulCost * (Is128bit ? 2 : 1) + // UMULL/UMULH
+                                  (HasMULH ? 0 : ShrCost) +      // UMULL shift
+                                  AddCost * 2 + ShrCost;
+        return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
+      }
+
+      // TODOD: Fix SDIV and SREM costs, similar to the above.
       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT) &&
-          !VT.isScalableVector()) {
+          Op2Info.isUniform()) {
         // Vector signed division by constant are expanded to the
-        // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
-        // to MULHS + SUB + SRL + ADD + SRL.
-        InstructionCost MulCost = getArithmeticInstrCost(
-            Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
-        InstructionCost AddCost = getArithmeticInstrCost(
-            Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
-        InstructionCost ShrCost = getArithmeticInstrCost(
-            Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
+        // sequence MULHS + ADD/SUB + SRA + SRL + ADD.
+        InstructionCost MulCost =
+            getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
+                                   Op1Info.getNoProps(), Op2Info.getNoProps());
+        InstructionCost AddCost =
+            getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
+                                   Op1Info.getNoProps(), Op2Info.getNoProps());
+        InstructionCost ShrCost =
+            getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
+                                   Op1Info.getNoProps(), Op2Info.getNoProps());
         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
       }
----------------
davemgreen wrote:

Hi - this is inside the top-level `if (Op2Info.isConstant()) {` block, so I don't think that should be an issue. We are only trying to update the constant costs in this patch, to keep it simpler. I've added some uniform tests in f08824b935434b91f7352904a25f6309f2b3e6bd to check this.

I think this bit of code can probably be removed once sdiv costs are handled in the same way. For now I will re-add the `isScalableVector` check to make sure the sdiv costs don't change yet.

https://github.com/llvm/llvm-project/pull/122236


More information about the llvm-commits mailing list