[llvm] [CodeGenPrepare] Create USubWithOverflow_match (PR #160327)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 23 11:29:28 PDT 2025
https://github.com/AZero13 updated https://github.com/llvm/llvm-project/pull/160327
>From 22352ccf1c80d8bd6e074d7f434b7a1672f97aa3 Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Tue, 23 Sep 2025 11:14:42 -0400
Subject: [PATCH] [CodeGenPrepare] Create USubWithOverflow_match (NFC)
Add USubWithOverflow_match / m_USubWithOverflow to PatternMatch.h to make it consistent with UAddWithOverflow_match / m_UAddWithOverflow.
---
llvm/include/llvm/IR/PatternMatch.h | 75 +++++++++++++++++++
llvm/lib/CodeGen/CodeGenPrepare.cpp | 60 +++++++++++----
.../InstCombine/InstCombineCompares.cpp | 17 +++++
3 files changed, 137 insertions(+), 15 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 6168e24569f99..6da6eca8677f8 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2685,6 +2685,81 @@ m_UAddWithOverflow(const LHS_t &L, const RHS_t &R, const Sum_t &S) {
return UAddWithOverflow_match<LHS_t, RHS_t, Sum_t>(L, R, S);
}
+template <typename LHS_t, typename RHS_t, typename Diff_t>
+struct USubWithOverflow_match {
+ LHS_t L;
+ RHS_t R;
+ Diff_t S;
+ // L: minuend, R: subtrahend, S: the sub/add whose result is compared.
+ USubWithOverflow_match(const LHS_t &L, const RHS_t &R, const Diff_t &S)
+ : L(L), R(R), S(S) {}
+ // True if V is an icmp that holds exactly when L - R wraps (unsigned).
+ template <typename OpTy> bool match(OpTy *V) const {
+ Value *ICmpLHS = nullptr, *ICmpRHS = nullptr;
+ CmpPredicate Pred;
+ if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
+ return false;
+ // Matches the compared value when it is a plain "sub a, b".
+ Value *SubLHS = nullptr, *SubRHS = nullptr;
+ auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS));
+ // Matches "add a, C", which is how "sub a, -C" is canonicalized.
+ Value *AddLHS = nullptr, *AddRHS = nullptr;
+ auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS));
+ // (a - b) >u a, or (a + C) >u a with b taken as -C. Any constant C is
+ // accepted; for the add form R is matched against a ConstantInt for -C
+ // created here, not against an operand of the add.
+ if (Pred == ICmpInst::ICMP_UGT) {
+ if (SubExpr.match(ICmpLHS) && ICmpRHS == SubLHS)
+ return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
+ // NOTE(review): R is already -C here; confirm consumers don't negate again.
+ if (AddExpr.match(ICmpLHS)) {
+ const APInt *AddC = nullptr;
+ if (m_APInt(AddC).match(AddRHS) && ICmpRHS == AddLHS) {
+ APInt NegC = -(*AddC);
+ Constant *NegConst = ConstantInt::get(AddRHS->getType(), NegC);
+ return L.match(AddLHS) && R.match(NegConst) && S.match(ICmpLHS);
+ }
+ }
+ }
+
+ // Mirrored forms: a <u (a - b), or a <u (a + C) with b taken as -C.
+ if (Pred == ICmpInst::ICMP_ULT) {
+ if (SubExpr.match(ICmpRHS) && ICmpLHS == SubLHS)
+ return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
+
+ if (AddExpr.match(ICmpRHS)) {
+ const APInt *AddC = nullptr;
+ if (m_APInt(AddC).match(AddRHS) && ICmpLHS == AddLHS) {
+ APInt NegC = -(*AddC);
+ Constant *NegConst = ConstantInt::get(AddRHS->getType(), NegC);
+ return L.match(AddLHS) && R.match(NegConst) && S.match(ICmpRHS);
+ }
+ }
+ }
+
+ // usub(0, a) wraps iff a != 0, so (0 - a) != 0 is also an overflow check.
+ if (Pred == ICmpInst::ICMP_NE) {
+ // (0 - a) != 0
+ if (SubExpr.match(ICmpLHS) && m_Zero().match(ICmpRHS) &&
+ m_Zero().match(SubLHS))
+ return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
+
+ // 0 != (0 - a)
+ if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) &&
+ m_Zero().match(SubLHS))
+ return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
+ }
+
+ return false;
+ }
+};
+/// Match an icmp checking for unsigned overflow of L - R; S binds the sub/add.
+template <typename LHS_t, typename RHS_t, typename Diff_t>
+USubWithOverflow_match<LHS_t, RHS_t, Diff_t>
+m_USubWithOverflow(const LHS_t &L, const RHS_t &R, const Diff_t &S) {
+ return USubWithOverflow_match<LHS_t, RHS_t, Diff_t>(L, R, S);
+}
+
template <typename Opnd_t> struct Argument_match {
unsigned OpI;
Opnd_t Val;
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d290f202f3cca..cc596aed4cc85 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1695,19 +1695,23 @@ bool CodeGenPrepare::combineToUAddWithOverflow(CmpInst *Cmp,
return true;
}
-bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
- ModifyDT &ModifiedDT) {
- // We are not expecting non-canonical/degenerate code. Just bail out.
+static bool matchUSubWithOverflowConstantEdgeCases(CmpInst *Cmp,
+ BinaryOperator *&Sub) {
+ // A - B, A u> B --> usubo(A, B)
Value *A = Cmp->getOperand(0), *B = Cmp->getOperand(1);
+
+ // We are not expecting non-canonical/degenerate code. Just bail out.
if (isa<Constant>(A) && isa<Constant>(B))
return false;
- // Convert (A u> B) to (A u< B) to simplify pattern matching.
ICmpInst::Predicate Pred = Cmp->getPredicate();
+
+ // Normalize: convert (A u> B) -> (B u< A)
if (Pred == ICmpInst::ICMP_UGT) {
std::swap(A, B);
Pred = ICmpInst::ICMP_ULT;
}
+
// Convert special-case: (A == 0) is the same as (A u< 1).
if (Pred == ICmpInst::ICMP_EQ && match(B, m_ZeroInt())) {
B = ConstantInt::get(B->getType(), 1);
@@ -1718,19 +1722,22 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
std::swap(A, B);
Pred = ICmpInst::ICMP_ULT;
}
+
if (Pred != ICmpInst::ICMP_ULT)
return false;
- // Walk the users of a variable operand of a compare looking for a subtract or
- // add with that same operand. Also match the 2nd operand of the compare to
- // the add/sub, but that may be a negated constant operand of an add.
+ // Walk the users of the variable operand of the compare looking for a
+ // subtract or add with that same operand. Also match the 2nd operand of the
+ // compare to the add/sub, but that may be a negated constant operand of an
+ // add.
Value *CmpVariableOperand = isa<Constant>(A) ? B : A;
- BinaryOperator *Sub = nullptr;
+ Sub = nullptr;
+
for (User *U : CmpVariableOperand->users()) {
// A - B, A u< B --> usubo(A, B)
if (match(U, m_Sub(m_Specific(A), m_Specific(B)))) {
Sub = cast<BinaryOperator>(U);
- break;
+ return true;
}
// A + (-C), A u< C (canonicalized form of (sub A, C))
@@ -1738,19 +1745,42 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) &&
match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) {
Sub = cast<BinaryOperator>(U);
- break;
+ return true;
}
}
- if (!Sub)
- return false;
+ return false;
+}
+
+bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
+ ModifyDT &ModifiedDT) {
+ bool EdgeCase = false;
+ Value *A = nullptr, *B = nullptr;
+ BinaryOperator *Sub = nullptr;
+
+ // If the compare already matches the (sub, icmp) pattern use it directly.
+ if (!match(Cmp, m_USubWithOverflow(m_Value(A), m_Value(B), m_BinOp(Sub)))) {
+ // Otherwise try to recognize constant-edge-case forms like
+ // icmp ne (sub 0, B), 0 or
+ // icmp eq (sub A, 1), 0
+ if (!matchUSubWithOverflowConstantEdgeCases(Cmp, Sub))
+ return false;
+ // Set A/B from the discovered Sub and record that this was an edge-case
+ // match.
+ A = Sub->getOperand(0);
+ B = Sub->getOperand(1);
+ EdgeCase = true;
+ }
+
+ // Check target wants the overflow intrinsic formed. When matching an
+ // edge-case we allow forming the intrinsic with fewer uses.
if (!TLI->shouldFormOverflowOp(ISD::USUBO,
TLI->getValueType(*DL, Sub->getType()),
- Sub->hasNUsesOrMore(1)))
+ Sub->hasNUsesOrMore(EdgeCase ? 1 : 2)))
return false;
- if (!replaceMathCmpWithIntrinsic(Sub, Sub->getOperand(0), Sub->getOperand(1),
- Cmp, Intrinsic::usub_with_overflow))
+ if (!replaceMathCmpWithIntrinsic(Sub, A, B, Cmp,
+ Intrinsic::usub_with_overflow))
return false;
// Reset callers - do not crash by iterating over a dead instruction.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e4cb457499ef5..5c7aae5f91fab 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7829,6 +7829,23 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
}
}
+ Instruction *SubI = nullptr; // Sub/add whose unsigned overflow I checks.
+ if (match(&I, m_USubWithOverflow(m_Value(X), m_Value(Y),
+ m_Instruction(SubI))) &&
+ isa<IntegerType>(X->getType())) {
+ Value *Result;
+ Constant *Overflow;
+ // m_USubWithOverflow can also match an add of a constant (canonicalized sub)
+ // rather than an explicit "sub", so check the opcode of the matched op.
+ if (SubI->getOpcode() == Instruction::Sub &&
+ OptimizeOverflowCheck(Instruction::Sub, /*Signed*/ false, X, Y, *SubI,
+ Result, Overflow)) {
+ replaceInstUsesWith(*SubI, Result);
+ eraseInstFromFunction(*SubI);
+ return replaceInstUsesWith(I, Overflow);
+ }
+ }
+
// (zext X) * (zext Y) --> llvm.umul.with.overflow.
if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
match(Op1, m_APInt(C))) {
More information about the llvm-commits
mailing list