[llvm] [CGP]: Optimize mul.overflow. (PR #148343)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 4 04:31:20 PDT 2025
================
@@ -6403,6 +6409,229 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking whether the values of
+// both operands fit within the legal type. If so, we can optimize the
+// multiplication. Ideally this would be done during type legalization, but
+// it requires reconstructing the IR, which is not doable there, so we do it
+// here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ // Enable this optimization only for aarch64.
+ if (!TLI->getTargetMachine().getTargetTriple().isAArch64())
+ return false;
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ // Skip the optimization if the type with HalfBitWidth is not legal for the
+ // target.
+ if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
+ TargetLowering::TypeLegal)
+ return false;
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // The new blocks should be:
+ // overflow.entry:
+ // (lhs_hi ne 0) || (rhs_hi ne 0) ? overflow, overflow.no
+
+ // overflow.no:
+ // overflow:
+ // overflow.res:
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+ IRBuilder<> Builder(OverflowoEntryBB->getTerminator());
+ auto *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = Builder.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = Builder.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ auto *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ auto *Or = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
+ Builder.CreateCondBr(Or, OverflowBB, NoOverflowBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow.no:
+ Builder.SetInsertPoint(NoOverflowBB);
+ auto *ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
+ auto *ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
+ auto *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.no.overflow");
+ Builder.CreateBr(OverflowResBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow.res:
+ Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
+ auto *PHINode1 = Builder.CreatePHI(Ty, 2);
+ PHINode1->addIncoming(Mul, NoOverflowBB);
+ auto *PHINode2 =
+ Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);
+ PHINode2->addIncoming(ConstantInt::getFalse(I->getContext()), NoOverflowBB);
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValOverflowRes = PoisonValue::get(STy);
+ StructValOverflowRes =
+ Builder.CreateInsertValue(StructValOverflowRes, PHINode1, {0});
+ StructValOverflowRes =
+ Builder.CreateInsertValue(StructValOverflowRes, PHINode2, {1});
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by StructValOverflowRes.
+ I->replaceAllUsesWith(StructValOverflowRes);
+ I->removeFromParent();
+
+ // BB overflow:
+ I->insertInto(OverflowBB, OverflowBB->end());
+ Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
+ auto *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
+ auto *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
+ Builder.CreateBr(OverflowResBB);
+
+ // Add the extracted values to the PHI nodes in the overflow.res block.
+ PHINode1->addIncoming(MulOverflow, OverflowBB);
+ PHINode2->addIncoming(OverflowFlag, OverflowBB);
+
+ // return false to stop reprocessing the function.
+ return false;
----------------
davemgreen wrote:
I think this might need to return true to set MadeChange (and maybe set ModifiedDT too). That needs a way to prevent re-expanding the mulo again and again, though, maybe by checking whether the block is already guarded by a check that the top bits are not zero?
https://github.com/llvm/llvm-project/pull/148343
More information about the llvm-commits
mailing list