[llvm] Re apply 130577 narrow math for and operand (PR #133896)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 1 21:41:14 PDT 2025
================
@@ -1561,6 +1561,87 @@ void AMDGPUCodeGenPrepareImpl::expandDivRem64(BinaryOperator &I) const {
llvm_unreachable("not a division");
}
+/// Search for a narrower version of \p I's result type on which the target
+/// reports \p I's opcode as legal or custom.
+///
+/// Candidate widths are \p MaxBitsNeeded, 2 * \p MaxBitsNeeded,
+/// 4 * \p MaxBitsNeeded, ... and stay strictly below \p OrigBit. Each
+/// candidate type is built with Type::getWithNewBitWidth on I's type.
+///
+/// \param I             Instruction whose result type and opcode are queried.
+/// \param OrigBit       Current (scalar) bit width of I's type; the upper bound.
+/// \param MaxBitsNeeded First candidate width to try.
+/// \param TLI           Target lowering used for the legality query.
+/// \param DL            Data layout used to map IR types to EVTs.
+/// \returns the first legal narrowed type, or nullptr if none exists.
+static Type *findSmallestLegalBits(Instruction *I, unsigned OrigBit,
+                                   unsigned MaxBitsNeeded,
+                                   const TargetLowering *TLI,
+                                   const DataLayout &DL) {
+  // A zero starting width can never grow by doubling, so bail out instead of
+  // looping forever. Widths at or above the original offer no narrowing.
+  if (MaxBitsNeeded == 0 || MaxBitsNeeded >= OrigBit)
+    return nullptr;
+
+  // The ISD opcode depends only on the IR opcode, not on the candidate width,
+  // so compute it once outside the search loop.
+  const int ISDOpc = TLI->InstructionOpcodeToISD(I->getOpcode());
+  Type *OrigTy = I->getType();
+
+  for (unsigned Bits = MaxBitsNeeded; Bits < OrigBit; Bits *= 2) {
+    Type *NewType = OrigTy->getWithNewBitWidth(Bits);
+    // AllowUnknown = true: unknown IR types map to MVT::Other instead of
+    // asserting.
+    if (TLI->isOperationLegalOrCustom(ISDOpc,
+                                      TLI->getValueType(DL, NewType, true)))
+      return NewType;
+  }
+  return nullptr;
+}
+
+static bool tryNarrowMathIfNoOverflow(Instruction *I, const TargetLowering *TLI,
+ const TargetTransformInfo &TTI,
+ const DataLayout &DL) {
+ unsigned Opc = I->getOpcode();
+ Type *OldType = I->getType();
+
+ if (Opc != Instruction::Add && Opc != Instruction::Mul)
+ return false;
+
+ unsigned OrigBit = OldType->getScalarSizeInBits();
+ unsigned MaxBitsNeeded = OrigBit;
+
+ switch (Opc) {
+ case Instruction::Add:
+ MaxBitsNeeded = KnownBits::add(computeKnownBits(I->getOperand(0), DL),
+ computeKnownBits(I->getOperand(1), DL))
+ .countMaxActiveBits();
+ break;
+ case Instruction::Mul:
+ MaxBitsNeeded = KnownBits::mul(computeKnownBits(I->getOperand(0), DL),
+ computeKnownBits(I->getOperand(1), DL))
+ .countMaxActiveBits();
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode, only valid for Instruction::Add and "
+ "Instruction::Mul.");
+ }
+
+ MaxBitsNeeded = std::max<unsigned>(bit_ceil(MaxBitsNeeded), 8);
+ Type *NewType = findSmallestLegalBits(I, OrigBit, MaxBitsNeeded, TLI, DL);
+
+ if (!NewType)
+ return false;
+
+ // Old cost
+ InstructionCost OldCost =
+ TTI.getArithmeticInstrCost(Opc, OldType, TTI::TCK_RecipThroughput);
+ // New cost of new op
+ InstructionCost NewCost =
+ TTI.getArithmeticInstrCost(Opc, NewType, TTI::TCK_RecipThroughput);
+ // New cost of narrowing 2 operands (use trunc)
+ NewCost += 2 * TTI.getCastInstrCost(Instruction::Trunc, NewType, OldType,
+ TTI.getCastContextHint(I),
+ TTI::TCK_RecipThroughput);
+ // New cost of zext narrowed result to original type
+ NewCost +=
+ TTI.getCastInstrCost(Instruction::ZExt, OldType, NewType,
+ TTI.getCastContextHint(I), TTI::TCK_RecipThroughput);
----------------
arsenm wrote:
Can we move this whole thing back to the generic code?
https://github.com/llvm/llvm-project/pull/133896
More information about the llvm-commits
mailing list