[llvm] [CGP]: Optimize mul.overflow. (PR #148343)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 21 07:59:38 PDT 2025
================
@@ -6389,6 +6395,303 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the mul_with_overflow intrinsic by checking whether the value ranges
+// of both operands fit within the legal type. If so, we can optimize the
+// multiplication algorithm. Ideally this would be done during type
+// legalization, but it requires reconstructing the IR, which is not doable
+// there, so we do it here.
+bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
+ ModifyDT &ModifiedDT) {
+ if (!TLI->shouldOptimizeMulOverflowIntrinsic(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))))
+ return false;
+
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ Type *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ IntegerType *LegalTy =
+ IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ // Skip the optimization if the type with HalfBitWidth is not legal for the
+ // target.
+ if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
+ TargetLowering::TypeLegal)
+ return false;
+
+ // Check the pattern we are interested in, where the intrinsic has at most 2
+ // uses, all of which are extractvalue instructions.
+ if (I->getNumUses() > 2)
+ return false;
+ ExtractValueInst *MulExtract = nullptr;
+ ExtractValueInst *OverflowExtract = nullptr;
+ for (User *U : I->users()) {
+ auto *Extract = dyn_cast<ExtractValueInst>(U);
+ if (!Extract)
+ return false;
+
+ unsigned Index = Extract->getIndices()[0];
+ if (Index == 0)
+ MulExtract = Extract;
+ else if (Index == 1)
+ OverflowExtract = Extract;
+ }
+
+ // Keep track of the instruction to stop reoptimizing it again.
+ InsertedInsts.insert(I);
+ // ----------------------------
+
+ // For the simple case where IR just checks the overflow flag, new blocks
+ // should be:
+ // entry:
+ // if signed:
+ // ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
+ // overflow_no
+ // else:
+ // (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
+ // overflow_no:
+ // overflow:
+
+ // otherwise, new blocks should be:
+ // entry:
+ // if signed:
+ // ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
+ // overflow_no
+ // else:
+ // (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ // New BBs:
+ std::string KeepBBName = I->getParent()->getName().str();
+ BasicBlock *OverflowEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ // Remove the 'br' instruction that is generated as a result of the split
+ // as we are going to append new instructions.
+ OverflowEntryBB->getTerminator()->eraseFromParent();
+ BasicBlock *NoOverflowBB =
+ BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
+ NoOverflowBB->moveAfter(OverflowEntryBB);
+ BasicBlock *OverflowBB =
+ BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
+ OverflowBB->moveAfter(NoOverflowBB);
+
+ // BB overflow.entry:
+ IRBuilder<> Builder(OverflowEntryBB);
+ // Get Lo and Hi of LHS & RHS:
+ Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+ Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ Value *IsAnyBitTrue;
+ if (IsSigned) {
+ Value *SignLoLHS =
+ Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ Value *SignLoRHS =
+ Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
+ Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
+ Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
+ IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
+ ConstantInt::getNullValue(Or->getType()));
+ } else {
+ Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
+ }
+ Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);
+
+ // BB overflow.no:
+ Builder.SetInsertPoint(NoOverflowBB);
+ Value *ExtLoLHS, *ExtLoRHS;
+ if (IsSigned) {
+ ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
+ ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
+ } else {
+ ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
+ ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
+ }
+
+ Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");
+
+ // In overflow.no BB: we are sure that the overflow flag is false.
+ // So, when we find this pattern:
+ // br (extractvalue (%mul, 1)), label %if.then, label %if.end
+ // then we can jump directly to %if.end as we're sure that there is no
+ // overflow. This is checking the simple case where the exiting br of I's BB
+ // is the branch we are interested in.
+ BasicBlock *NoOverflowBrBB = nullptr;
+ if (auto *Br = dyn_cast<BranchInst>(I->getParent()->getTerminator())) {
+ // Check that the Br is testing the overflow bit:
+ if (Br->isConditional()) {
+ auto *ExtInstr = dyn_cast<ExtractValueInst>(Br->getOperand(0));
+ if (ExtInstr && ExtInstr->getIndices()[0] == 1)
+ NoOverflowBrBB = Br->getSuccessor(1) /*if.end*/;
+ }
+ }
+ if (NoOverflowBrBB) {
+ // Duplicate instructions from I's BB to the NoOverflowBB:
----------------
davemgreen wrote:
Is there a limit on the number of instructions we should duplicate here? It could be quite a few. There are also certain instructions that report cannotDuplicate, which we need to watch out for.
https://github.com/llvm/llvm-project/pull/148343
More information about the llvm-commits mailing list