[llvm] Re-apply #130577: narrow math for `and` operand (PR #133896)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 1 21:41:14 PDT 2025


================
@@ -1561,6 +1561,87 @@ void AMDGPUCodeGenPrepareImpl::expandDivRem64(BinaryOperator &I) const {
   llvm_unreachable("not a division");
 }
 
+/// Search for a narrower type on which the opcode of \p I is legal or custom
+/// for the target described by \p TLI.
+///
+/// Candidate widths start at \p MaxBitsNeeded and are doubled on each attempt
+/// for as long as they stay strictly below \p OrigBit. Each candidate type is
+/// derived from the type of \p I via getWithNewBitWidth().
+///
+/// \returns the first candidate type that passes the legality query, or
+/// nullptr when \p MaxBitsNeeded is not smaller than \p OrigBit or when no
+/// candidate below \p OrigBit passes.
+Type *findSmallestLegalBits(Instruction *I, int OrigBit, int MaxBitsNeeded,
+                            const TargetLowering *TLI, const DataLayout &DL) {
+  // A candidate that is not strictly narrower than the original is useless.
+  if (MaxBitsNeeded >= OrigBit)
+    return nullptr;
+
+  // Both of these are the same for every candidate width; look them up once.
+  Type *const WideTy = I->getType();
+  const int ISDOpc = TLI->InstructionOpcodeToISD(I->getOpcode());
+
+  for (int Width = MaxBitsNeeded; Width < OrigBit; Width *= 2) {
+    Type *Candidate = WideTy->getWithNewBitWidth(Width);
+    if (TLI->isOperationLegalOrCustom(ISDOpc,
+                                      TLI->getValueType(DL, Candidate, true)))
+      return Candidate;
+  }
+
+  // Every width below the original was rejected.
+  return nullptr;
+}
+
+static bool tryNarrowMathIfNoOverflow(Instruction *I, const TargetLowering *TLI,
+                                      const TargetTransformInfo &TTI,
+                                      const DataLayout &DL) {
+  unsigned Opc = I->getOpcode();
+  Type *OldType = I->getType();
+
+  if (Opc != Instruction::Add && Opc != Instruction::Mul)
+    return false;
+
+  unsigned OrigBit = OldType->getScalarSizeInBits();
+  unsigned MaxBitsNeeded = OrigBit;
+
+  switch (Opc) {
+  case Instruction::Add:
+    MaxBitsNeeded = KnownBits::add(computeKnownBits(I->getOperand(0), DL),
+                                   computeKnownBits(I->getOperand(1), DL))
+                        .countMaxActiveBits();
+    break;
+  case Instruction::Mul:
+    MaxBitsNeeded = KnownBits::mul(computeKnownBits(I->getOperand(0), DL),
+                                   computeKnownBits(I->getOperand(1), DL))
+                        .countMaxActiveBits();
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode, only valid for Instruction::Add and "
+                     "Instruction::Mul.");
+  }
+
+  MaxBitsNeeded = std::max<unsigned>(bit_ceil(MaxBitsNeeded), 8);
+  Type *NewType = findSmallestLegalBits(I, OrigBit, MaxBitsNeeded, TLI, DL);
+
+  if (!NewType)
+    return false;
+
+  // Old cost
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(Opc, OldType, TTI::TCK_RecipThroughput);
+  // New cost of new op
+  InstructionCost NewCost =
+      TTI.getArithmeticInstrCost(Opc, NewType, TTI::TCK_RecipThroughput);
+  // New cost of narrowing 2 operands (use trunc)
+  NewCost += 2 * TTI.getCastInstrCost(Instruction::Trunc, NewType, OldType,
+                                      TTI.getCastContextHint(I),
+                                      TTI::TCK_RecipThroughput);
+  // New cost of zext narrowed result to original type
+  NewCost +=
+      TTI.getCastInstrCost(Instruction::ZExt, OldType, NewType,
+                           TTI.getCastContextHint(I), TTI::TCK_RecipThroughput);
----------------
arsenm wrote:

Can we move this whole thing back to the generic code?

https://github.com/llvm/llvm-project/pull/133896


More information about the llvm-commits mailing list