[llvm] [AMDGPU][CodeGenPrepare] Narrow 64 bit math to 32 bit if profitable (PR #130577)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 27 12:30:31 PDT 2025


================
@@ -1560,6 +1560,87 @@ void AMDGPUCodeGenPrepareImpl::expandDivRem64(BinaryOperator &I) const {
   llvm_unreachable("not a division");
 }
 
+/// Find the narrowest type, derived from \p I's type, whose integer width is
+/// at least \p MaxBitsNeeded and strictly less than \p OrigBit, and for which
+/// the target reports \p I's opcode as legal or custom.
+///
+/// Candidate widths start at \p MaxBitsNeeded and double each step. The
+/// caller is expected to pass a power of two (it rounds with bit_ceil);
+/// otherwise the candidates are power-of-two multiples of \p MaxBitsNeeded.
+///
+/// \returns the first legal/custom candidate type, or nullptr when
+/// \p MaxBitsNeeded is 0, is already >= \p OrigBit, or no candidate below
+/// \p OrigBit is legal/custom.
+static Type *findSmallestLegalBits(Instruction *I, unsigned OrigBit,
+                                   unsigned MaxBitsNeeded,
+                                   const TargetLowering *TLI,
+                                   const DataLayout &DL) {
+  // A zero width is not a valid type and would never grow when doubled
+  // (infinite loop). A width that is not below the original gains nothing.
+  if (MaxBitsNeeded == 0 || MaxBitsNeeded >= OrigBit)
+    return nullptr;
+
+  // Neither the ISD opcode nor the source type changes between candidates.
+  const int ISDOpc = TLI->InstructionOpcodeToISD(I->getOpcode());
+  Type *OrigTy = I->getType();
+
+  for (unsigned Bits = MaxBitsNeeded; Bits < OrigBit; Bits *= 2) {
+    Type *Candidate = OrigTy->getWithNewBitWidth(Bits);
+    if (TLI->isOperationLegalOrCustom(
+            ISDOpc, TLI->getValueType(DL, Candidate, /*AllowUnknown=*/true)))
+      return Candidate;
+  }
+  return nullptr;
+}
+
+static bool tryNarrowMathIfNoOverflow(Instruction *I, const TargetLowering *TLI,
+                                      const TargetTransformInfo &TTI,
+                                      const DataLayout &DL) {
+  unsigned Opc = I->getOpcode();
+  Type *OldType = I->getType();
+
+  if (Opc != Instruction::Add && Opc != Instruction::Mul)
+    return false;
+
+  unsigned OrigBit = OldType->getScalarSizeInBits();
+  unsigned MaxBitsNeeded = OrigBit;
+
+  switch (Opc) {
+  case Instruction::Add:
+    MaxBitsNeeded = KnownBits::add(computeKnownBits(I->getOperand(0), DL),
+                                   computeKnownBits(I->getOperand(1), DL))
+                        .countMaxActiveBits();
+    break;
+  case Instruction::Mul:
+    MaxBitsNeeded = KnownBits::mul(computeKnownBits(I->getOperand(0), DL),
+                                   computeKnownBits(I->getOperand(1), DL))
+                        .countMaxActiveBits();
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode, only valid for Instruction::Add and "
+                     "Instruction::Mul.");
+  }
+
+  MaxBitsNeeded = std::max<unsigned>(bit_ceil(MaxBitsNeeded), 8);
+  Type *NewType = findSmallestLegalBits(I, OrigBit, MaxBitsNeeded, TLI, DL);
+
+  if (!NewType)
+    return false;
+
+  // Old cost
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(Opc, OldType, TTI::TCK_RecipThroughput);
+  // New cost of new op
+  InstructionCost NewCost =
+      TTI.getArithmeticInstrCost(Opc, NewType, TTI::TCK_RecipThroughput);
+  // New cost of narrowing 2 operands (use trunc)
+  NewCost += 2 * TTI.getCastInstrCost(Instruction::Trunc, NewType, OldType,
+                                      TTI.getCastContextHint(I),
+                                      TTI::TCK_RecipThroughput);
+  // New cost of zext narrowed result to original type
+  NewCost +=
+      TTI.getCastInstrCost(Instruction::ZExt, OldType, NewType,
+                           TTI.getCastContextHint(I), TTI::TCK_RecipThroughput);
+  if (NewCost >= OldCost)
+    return false;
----------------
LU-JOHN wrote:

I think this cost makes the transformation too conservative.  Usually the truncs will be removed in the final code.  Also, it does not include the benefit of using fewer registers with the narrower operations.

https://github.com/llvm/llvm-project/pull/130577


More information about the llvm-commits mailing list