[llvm] [CGP]: Optimize mul.overflow. (PR #148343)

Alexis Engelke via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 1 23:16:37 PDT 2025


================
@@ -6403,6 +6409,207 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
   return true;
 }
 
+// Rewrite a umul.with.overflow intrinsic whose operand type must be expanded
+// by type legalization but whose half-width type is legal. A runtime check is
+// emitted: if the high halves of both operands are zero, the product of the
+// two (zero-extended) low halves is used and the overflow flag is false;
+// otherwise the original intrinsic is executed. The rewrite needs new basic
+// blocks, so it is done on the IR here.
+//
+// \param I the umul.with.overflow call to rewrite.
+// \returns false in all cases (see the review note at the end of the body).
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+  // Enable this optimization only for aarch64.
+  if (!TLI->getTargetMachine().getTargetTriple().isAArch64())
+    return false;
+  // Only handle result types that legalization would expand.
+  if (TLI->getTypeAction(
+          I->getContext(),
+          TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+      TargetLowering::TypeExpandInteger)
+    return false;
+
+  LLVMContext &Ctx = I->getContext();
+  Value *LHS = I->getOperand(0);
+  Value *RHS = I->getOperand(1);
+  auto *Ty = LHS->getType();
+  unsigned VTBitWidth = Ty->getScalarSizeInBits();
+  unsigned VTHalfBitWidth = VTBitWidth / 2;
+  auto *LegalTy = IntegerType::getIntNTy(Ctx, VTHalfBitWidth);
+
+  // Skip the optimization if the type with HalfBitWidth is not legal for the
+  // target.
+  if (TLI->getTypeAction(Ctx, TLI->getValueType(*DL, LegalTy)) !=
+      TargetLowering::TypeLegal)
+    return false;
+
+  // The block holding I becomes the result block; everything before I is
+  // split off into a new predecessor block.
+  I->getParent()->setName("overflow.res");
+  auto *OverflowResBB = I->getParent();
+  auto *OverflowEntryBB =
+      I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+  BasicBlock *NoOverflowBB = BasicBlock::Create(
+      Ctx, "overflow.no", I->getFunction(), OverflowResBB);
+  BasicBlock *OverflowBB =
+      BasicBlock::Create(Ctx, "overflow", I->getFunction(), OverflowResBB);
+  // The new blocks are:
+  //  overflow.entry:
+  //    (lhs_hi ne 0) || (rhs_hi ne 0) ? overflow : overflow.no
+  //  overflow.no:   multiply the zero-extended low halves, flag = false
+  //  overflow:      the original intrinsic
+  //  overflow.res:  merge the two results
+  //------------------------------------------------------------------------------
+  // BB overflow.entry:
+  // get Lo and Hi of RHS & LHS:
+  IRBuilder<> Builder(OverflowEntryBB->getTerminator());
+  auto *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+  auto *ShrHiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+  auto *HiRHS = Builder.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+  auto *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+  auto *ShrHiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+  auto *HiLHS = Builder.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+  auto *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+                                   ConstantInt::getNullValue(LegalTy));
+  auto *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+                                   ConstantInt::getNullValue(LegalTy));
+  auto *Or = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
+  // Replace the unconditional branch left by the split with the check.
+  Builder.CreateCondBr(Or, OverflowBB, NoOverflowBB);
+  OverflowEntryBB->getTerminator()->eraseFromParent();
+
+  //------------------------------------------------------------------------------
+  // BB overflow.no:
+  Builder.SetInsertPoint(NoOverflowBB);
+  auto *ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
+  auto *ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
+  auto *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.no.overflow");
+  Builder.CreateBr(OverflowResBB);
+
+  //------------------------------------------------------------------------------
+  // BB overflow.res:
+  // Merge the product and the overflow flag with one scalar PHI each instead
+  // of a single PHI of the aggregate {Ty, i1} type.
+  Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
+  auto *MulPHI = Builder.CreatePHI(Ty, 2);
+  MulPHI->addIncoming(Mul, NoOverflowBB);
+  auto *FlagPHI = Builder.CreatePHI(IntegerType::getInt1Ty(Ctx), 2);
+  FlagPHI->addIncoming(ConstantInt::getFalse(Ctx), NoOverflowBB);
+
+  // Rebuild the {Ty, i1} result from the two PHIs so that existing users of
+  // the intrinsic keep seeing a value of the same type.
+  Value *StructValRes = PoisonValue::get(I->getType());
+  StructValRes = Builder.CreateInsertValue(StructValRes, MulPHI, {0});
+  StructValRes = Builder.CreateInsertValue(StructValRes, FlagPHI, {1});
+
+  // Before moving the mul.overflow intrinsic to the overflow BB, replace all
+  // its uses by the merged result. The extractvalues created below are added
+  // afterwards and therefore still refer to the intrinsic itself.
+  I->replaceAllUsesWith(StructValRes);
+
+  //------------------------------------------------------------------------------
+  // BB overflow:
+  I->removeFromParent();
+  I->insertInto(OverflowBB, OverflowBB->end());
+  Builder.SetInsertPoint(OverflowBB);
+  auto *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
+  auto *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
+  Builder.CreateBr(OverflowResBB);
+
+  MulPHI->addIncoming(MulOverflow, OverflowBB);
+  FlagPHI->addIncoming(OverflowFlag, OverflowBB);
+
+  // return false to stop reprocessing the function.
+  // NOTE(review): the IR and CFG have been modified at this point; confirm
+  // that the caller does not treat a false return as "nothing changed".
+  return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+  // Enable this optimization only for aarch64.
+  if (!TLI->getTargetMachine().getTargetTriple().isAArch64())
+    return false;
+  if (TLI->getTypeAction(
+          I->getContext(),
+          TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+      TargetLowering::TypeExpandInteger)
+    return false;
+
+  Value *LHS = I->getOperand(0);
+  Value *RHS = I->getOperand(1);
+  auto *Ty = LHS->getType();
+  unsigned VTBitWidth = Ty->getScalarSizeInBits();
+  unsigned VTHalfBitWidth = VTBitWidth / 2;
+  auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+  // Skip the optimization if the type with HalfBitWidth is not legal for the
+  // target.
+  if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
+      TargetLowering::TypeLegal)
+    return false;
+
+  I->getParent()->setName("overflow.res");
+  auto *OverflowResBB = I->getParent();
+  auto *OverflowoEntryBB =
+      I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+  BasicBlock *NoOverflowBB = BasicBlock::Create(
+      I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+  BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+                                              I->getFunction(), OverflowResBB);
+  // new blocks should be:
+  //  entry:
+  //    (lhs_hi ne sign(lhs_lo)) || (rhs_hi ne sign(rhs_lo)) ? overflow, overflow_no
+
+  //  overflow_no:
+  //  overflow:
+  //  overflow.res:
+
+  //------------------------------------------------------------------------------
+  // BB overflow.entry:
+  // get Lo and Hi of RHS & LHS:
+  IRBuilder<> Builder(OverflowoEntryBB->getTerminator());
+  auto *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
+  auto *SignLoRHS =
+      Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+  auto *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+  HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+  auto *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
+  auto *SignLoLHS =
+      Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+  auto *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+  HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+  auto *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+  auto *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+  auto *Or = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
+  Builder.CreateCondBr(Or, OverflowBB, NoOverflowBB);
+  OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+  //------------------------------------------------------------------------------
+  // BB overflow.no:
+  Builder.SetInsertPoint(NoOverflowBB);
+  auto *ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
+  auto *ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
+  auto *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.no.overflow");
+  StructType *STy = StructType::get(
+      I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+  Value *StructValNoOverflow = PoisonValue::get(STy);
+  StructValNoOverflow =
+      Builder.CreateInsertValue(StructValNoOverflow, Mul, {0});
+  StructValNoOverflow = Builder.CreateInsertValue(
+      StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+  Builder.CreateBr(OverflowResBB);
+
+  //------------------------------------------------------------------------------
+  // BB overflow.res:
+  Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
+  auto *PHINode = Builder.CreatePHI(STy, 2);
----------------
aengelke wrote:

We shouldn't create PHIs of aggregate types, create two separate PHIs instead.

https://github.com/llvm/llvm-project/pull/148343


More information about the llvm-commits mailing list