[llvm] [CGP]: Optimize mul.overflow. (PR #148343)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 21 07:59:38 PDT 2025


================
@@ -6389,6 +6395,303 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
   return true;
 }
 
+// Rewrite the mul_with_overflow intrinsic by checking if both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
+                                             ModifyDT &ModifiedDT) {
+  if (!TLI->shouldOptimizeMulOverflowIntrinsic(
+          I->getContext(),
+          TLI->getValueType(*DL, I->getType()->getContainedType(0))))
+    return false;
+
+  Value *LHS = I->getOperand(0);
+  Value *RHS = I->getOperand(1);
+  Type *Ty = LHS->getType();
+  unsigned VTBitWidth = Ty->getScalarSizeInBits();
+  unsigned VTHalfBitWidth = VTBitWidth / 2;
+  IntegerType *LegalTy =
+      IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+  // Skip the optimization if the type with HalfBitWidth is not legal for the
+  // target.
+  if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
+      TargetLowering::TypeLegal)
+    return false;
+
+  // Check the pattern we are interested in where there are maximum 2 uses
+  // of the intrinsic which are the extracts instructions.
+  if (I->getNumUses() > 2)
+    return false;
+  ExtractValueInst *MulExtract = nullptr;
+  ExtractValueInst *OverflowExtract = nullptr;
+  for (User *U : I->users()) {
+    auto *Extract = dyn_cast<ExtractValueInst>(U);
+    if (!Extract)
+      return false;
+
+    unsigned Index = Extract->getIndices()[0];
+    if (Index == 0)
+      MulExtract = Extract;
+    else if (Index == 1)
+      OverflowExtract = Extract;
+  }
+
+  // Keep track of the instruction to stop reoptimizing it again.
+  InsertedInsts.insert(I);
+  // ----------------------------
+
+  // For the simple case where IR just checks the overflow flag, new blocks
+  // should be:
+  //  entry:
+  //    if signed:
+  //      ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
+  //      overflow_no
+  //    else:
+  //      (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
+  //  overflow_no:
+  //  overflow:
+
+  // otherwise, new blocks should be:
+  //  entry:
+  //    if signed:
+  //      ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
+  //      overflow_no
+  //    else:
+  //      (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
+  //  overflow_no:
+  //  overflow:
+  //  overflow.res:
+
+  // New BBs:
+  std::string KeepBBName = I->getParent()->getName().str();
+  BasicBlock *OverflowEntryBB =
+      I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+  // Remove the 'br' instruction that is generated as a result of the split
+  // as we are going to append new instructions.
+  OverflowEntryBB->getTerminator()->eraseFromParent();
+  BasicBlock *NoOverflowBB =
+      BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
+  NoOverflowBB->moveAfter(OverflowEntryBB);
+  BasicBlock *OverflowBB =
+      BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
+  OverflowBB->moveAfter(NoOverflowBB);
+
+  // BB overflow.entry:
+  IRBuilder<> Builder(OverflowEntryBB);
+  // Get Lo and Hi of LHS & RHS:
+  Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
+  Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+  HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+  Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
+  Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+  HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+  Value *IsAnyBitTrue;
+  if (IsSigned) {
+    Value *SignLoLHS =
+        Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+    Value *SignLoRHS =
+        Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+    Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
+    Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
+    Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
+    IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
+                                     ConstantInt::getNullValue(Or->getType()));
+  } else {
+    Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+                                      ConstantInt::getNullValue(LegalTy));
+    Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+                                      ConstantInt::getNullValue(LegalTy));
+    IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
+  }
+  Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);
+
+  // BB overflow.no:
+  Builder.SetInsertPoint(NoOverflowBB);
+  Value *ExtLoLHS, *ExtLoRHS;
+  if (IsSigned) {
+    ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
+    ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
+  } else {
+    ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
+    ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
+  }
+
+  Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");
+
+  // In overflow.no BB: we are sure that the overflow flag is false.
+  // So, when we find this pattern:
+  // br (extractvalue (%mul, 1)), label %if.then, label %if.end
+  // then we can jump directly to %if.end as we're sure that there is no
+  // overflow. This is checking the simple case where the exiting br of I's BB
+  // is the branch we are interested in.
+  BasicBlock *NoOverflowBrBB = nullptr;
+  if (auto *Br = dyn_cast<BranchInst>(I->getParent()->getTerminator())) {
+    // Check that the Br is testing the overflow bit:
+    if (Br->isConditional()) {
+      auto *ExtInstr = dyn_cast<ExtractValueInst>(Br->getOperand(0));
+      if (ExtInstr && ExtInstr->getIndices()[0] == 1)
+        NoOverflowBrBB = Br->getSuccessor(1) /*if.end*/;
+    }
+  }
+  if (NoOverflowBrBB) {
+    // Duplicate instructions from I's BB to the NoOverflowBB:
+    ValueToValueMapTy VMap;
+    for (auto It = std::next(BasicBlock::iterator(I));
+         &*It != I->getParent()->getTerminator(); ++It) {
+      Instruction *OrigInst = &*It;
+      if (isa<DbgInfoIntrinsic>(OrigInst) || OrigInst == MulExtract ||
+          OrigInst == OverflowExtract)
+        continue;
+      Instruction *NewInst = nullptr;
+      NewInst = OrigInst->clone();
+      Builder.Insert(NewInst);
+      VMap[OrigInst] = NewInst;
+      RemapInstruction(NewInst, VMap, RF_IgnoreMissingLocals);
+    }
+    // Replace uses of MulExtract at the 'overflow.no' BB
+    if (MulExtract)
+      MulExtract->replaceUsesWithIf(Mul, [&](Use &U) {
+        return cast<Instruction>(U.getUser())->getParent() == NoOverflowBB;
+      });
+    if (OverflowExtract)
+      // Overflow flag is always false as we are sure it's not overflow.
+      OverflowExtract->replaceUsesWithIf(
+          ConstantInt::getFalse(I->getContext()), [&](Use &U) {
+            return cast<Instruction>(U.getUser())->getParent() == NoOverflowBB;
+          });
+    // BB overflow.no: jump directly to if.end BB
+    Builder.CreateBr(NoOverflowBrBB);
+
+    // Remove the original BB as it's divided into 'overflow.entry' and
+    // another BB where I exists.
+    BasicBlock *ToBeRemoveBB = I->getParent();
+    // BB overflow:
+    // Merge the original BB of I into the 'overflow' BB:
+    OverflowBB->splice(OverflowBB->end(), ToBeRemoveBB);
+
+    // Check if the Br BB has a PHI node and I->getParent() is one of
+    // its incoming BBs:
+    PHINode *PN = nullptr;
+    for (auto It = NoOverflowBrBB->begin(); It != NoOverflowBrBB->end(); ++It) {
+      if (!isa<PHINode>(&*It))
+        break;
+      PN = cast<PHINode>(&*It);
+      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+        if (PN->getIncomingBlock(i) == ToBeRemoveBB) {
+          // Replace the old block by the new 'overflow' BB:
+          PN->setIncomingBlock(i, OverflowBB);
+          Value *IncomingValue = PN->getIncomingValue(i);
+          // Check if the incoming value is a constant, duplicate it.
+          if (isa<Constant>(IncomingValue)) {
+            PN->addIncoming(IncomingValue, NoOverflowBB);
+            continue;
+          }
+          // Check if this instruction was cloned to the 'overflow.no' BB:
+          Instruction *ClonedInstr =
+              cast_or_null<Instruction>(VMap.lookup(IncomingValue));
+          if (ClonedInstr) {
+            PN->addIncoming(ClonedInstr, NoOverflowBB);
+            continue;
+          } else if (isa<Instruction>(IncomingValue)) {
+            if (cast<Instruction>(IncomingValue) == MulExtract) {
+              PN->addIncoming(Mul, NoOverflowBB);
+              continue;
+            }
+            if (cast<Instruction>(IncomingValue) == OverflowExtract) {
+              PN->addIncoming(ConstantInt::getFalse(I->getContext()),
+                              NoOverflowBB);
+              continue;
+            }
+          }
+          llvm_unreachable("Unexpected incoming value to PHI node");
+        }
+      }
+    }
+    if (!PN) {
----------------
davemgreen wrote:

Do we need to check that PN was the MulExtract or something? We could be checking multiple unrelated phis above.
And could OverflowExtract have other uses?

https://github.com/llvm/llvm-project/pull/148343


More information about the llvm-commits mailing list