[llvm] [llvm] Optimize usub.sat fix for #79690 (PR #151044)

Nimit Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 24 19:48:33 PDT 2025


https://github.com/nimit25 updated https://github.com/llvm/llvm-project/pull/151044

>From 82336b681fa6e917d4fe217a227614f481aa30ac Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:23:11 -0400
Subject: [PATCH 1/7] Optimize usub.sat fix for #79690

---
 .../InstCombine/InstCombineSelect.cpp         | 143 +++++++++++++++++-
 .../InstCombine/usub_sat_to_msb_mask.ll       | 126 +++++++++++++++
 2 files changed, 263 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index eb4332fbc0959..74544009e6872 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,6 +42,8 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <cassert>
+#include <cstddef>
+#include <iostream>
 #include <utility>
 
 #define DEBUG_TYPE "instcombine"
@@ -50,7 +52,6 @@
 using namespace llvm;
 using namespace PatternMatch;
 
-
 /// Replace a select operand based on an equality comparison with the identity
 /// constant of a binop.
 static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1713,7 +1714,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
   if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
     return nullptr;
 
-
   Value *SelVal0, *SelVal1; // We do not care which one is from where.
   match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
   // At least one of these values we are selecting between must be a constant
@@ -1993,6 +1993,135 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
   return BinOp;
 }
 
+/// Folds:
+///   %a_sub = call @llvm.usub.sat(x, IntConst1)
+///   %b_sub = call @llvm.usub.sat(y, IntConst2)
+///   %or = or %a_sub, %b_sub
+///   %cmp = icmp eq %or, 0
+///   %sel = select %cmp, 0, MostSignificantBit
+/// into:
+///   %a_sub' = usub.sat(x, IntConst1 - MostSignificantBit)
+///   %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit)
+///   %or = or %a_sub', %b_sub'
+///   %and = and %or, MostSignificantBit
+/// If the args are vectors
+///
+static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
+    SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
+  auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
+  if (!CI) {
+    return nullptr;
+  }
+
+  Value *CmpLHS = CI->getOperand(0);
+  Value *CmpRHS = CI->getOperand(1);
+  if (!match(CmpRHS, m_Zero())) {
+    return nullptr;
+  }
+  auto Pred = CI->getPredicate();
+  auto *TrueVal = SI.getTrueValue();
+  auto *FalseVal = SI.getFalseValue();
+
+  if (Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
+  Value *A, *B;
+  ConstantInt *IntConst1, *IntConst2, *PossibleMSBInt;
+
+  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(A), m_ConstantInt(IntConst1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(B), m_ConstantInt(IntConst2)))) &&
+      match(TrueVal, m_Zero()) &&
+      match(FalseVal, m_ConstantInt(PossibleMSBInt))) {
+    auto *Ty = A->getType();
+    unsigned BW = Ty->getIntegerBitWidth();
+    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+
+    if (PossibleMSBInt->getValue() != MostSignificantBit)
+      return nullptr;
+    // Ensure IntConst1 and IntConst2 are >= MostSignificantBit
+    if (IntConst1->getValue().ult(MostSignificantBit) ||
+        IntConst2->getValue().ult(MostSignificantBit))
+      return nullptr;
+
+    // Rewrite:
+    Value *NewA = Builder.CreateBinaryIntrinsic(
+        Intrinsic::usub_sat, A,
+        ConstantInt::get(Ty, IntConst1->getValue() - MostSignificantBit + 1));
+    Value *NewB = Builder.CreateBinaryIntrinsic(
+        Intrinsic::usub_sat, B,
+        ConstantInt::get(Ty, IntConst2->getValue() - MostSignificantBit + 1));
+    Value *Or = Builder.CreateOr(NewA, NewB);
+    Value *And =
+        Builder.CreateAnd(Or, ConstantInt::get(Ty, MostSignificantBit));
+    return cast<Instruction>(And);
+  }
+  Constant *Const1, *Const2, *PossibleMSB;
+  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
+                                                          m_Constant(Const1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(B), m_Constant(Const2)))) &&
+      match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))) {
+    auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
+    auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
+    auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
+    if (!VecTy1 || !VecTy2 || !VecTyMSB) {
+      return nullptr;
+    }
+
+    unsigned NumElements = VecTy1->getNumElements();
+
+    if (NumElements != VecTy2->getNumElements() ||
+        NumElements != VecTyMSB->getNumElements() || NumElements == 0) {
+      return nullptr;
+    }
+    auto *SplatMSB =
+        dyn_cast<ConstantInt>(PossibleMSB->getAggregateElement(0u));
+    unsigned BW = SplatMSB->getValue().getBitWidth();
+    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+    if (!SplatMSB || SplatMSB->getValue() != MostSignificantBit) {
+      return nullptr;
+    }
+    for (unsigned int i = 1; i < NumElements; ++i) {
+      auto *Element =
+          dyn_cast<ConstantInt>(PossibleMSB->getAggregateElement(i));
+      if (!Element || Element->getValue() != SplatMSB->getValue()) {
+        return nullptr;
+      }
+    }
+    SmallVector<Constant *, 16> Arg1, Arg2;
+    for (unsigned int i = 0; i < NumElements; ++i) {
+      auto *E1 = dyn_cast<ConstantInt>(Const1->getAggregateElement(i));
+      auto *E2 = dyn_cast<ConstantInt>(Const2->getAggregateElement(i));
+      if (!E1 || !E2) {
+        return nullptr;
+      }
+      if (E1->getValue().ult(SplatMSB->getValue()) ||
+          E2->getValue().ult(SplatMSB->getValue())) {
+        return nullptr;
+      }
+      Arg1.emplace_back(
+          ConstantInt::get(A->getType()->getScalarType(),
+                           E1->getValue() - MostSignificantBit + 1));
+      Arg2.emplace_back(
+          ConstantInt::get(B->getType()->getScalarType(),
+                           E2->getValue() - MostSignificantBit + 1));
+    }
+    Constant *ConstVec1 = ConstantVector::get(Arg1);
+    Constant *ConstVec2 = ConstantVector::get(Arg2);
+    Value *NewA =
+        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, ConstVec1);
+    Value *NewB =
+        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, ConstVec2);
+    Value *Or = Builder.CreateOr(NewA, NewB);
+    Value *And = Builder.CreateAnd(Or, PossibleMSB);
+    return cast<Instruction>(And);
+  }
+  return nullptr;
+}
+
 /// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
                                                       ICmpInst *ICI) {
@@ -2009,6 +2138,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Instruction *NewSel =
           tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
     return NewSel;
+  if (Instruction *Folded =
+          foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
+    return replaceInstUsesWith(SI, Folded);
 
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
   bool Changed = false;
@@ -4200,10 +4332,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       bool IsCastNeeded = LHS->getType() != SelType;
       Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
       Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
-      if (IsCastNeeded ||
-          (LHS->getType()->isFPOrFPVectorTy() &&
-           ((CmpLHS != LHS && CmpLHS != RHS) ||
-            (CmpRHS != LHS && CmpRHS != RHS)))) {
+      if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
+                           ((CmpLHS != LHS && CmpLHS != RHS) ||
+                            (CmpRHS != LHS && CmpRHS != RHS)))) {
         CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
 
         Value *Cmp;
diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
new file mode 100644
index 0000000000000..ffa77f4b42138
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -0,0 +1,126 @@
+
+; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s
+
+declare i8 @llvm.usub.sat.i8(i8, i8)
+declare i16 @llvm.usub.sat.i16(i16, i16)
+declare i32 @llvm.usub.sat.i32(i32, i32)
+declare i64 @llvm.usub.sat.i64(i64, i64)
+
+define i8 @test_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_i8(
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96)
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112)
+; CHECK-NEXT: or i8
+; CHECK-NEXT: and i8
+; CHECK-NEXT: ret i8
+
+  %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
+  %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
+  %or = or i8 %a_sub, %b_sub
+  %cmp = icmp eq i8 %or, 0
+  %res = select i1 %cmp, i8 0, i8 128
+  ret i8 %res
+}
+
+define i16 @test_i16(i16 %a, i16 %b) {
+; CHECK-LABEL: @test_i16(
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642)
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656)
+; CHECK-NEXT: or i16
+; CHECK-NEXT: and i16
+; CHECK-NEXT: ret i16
+
+  %a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409)
+  %b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423)
+  %or = or i16 %a_sub, %b_sub
+  %cmp = icmp eq i16 %or, 0
+  %res = select i1 %cmp, i16 0, i16 32768
+  ret i16 %res
+}
+
+define i32 @test_i32(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_i32(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: and i32
+; CHECK-NEXT: ret i32
+
+  %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871)
+  %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887)
+  %or = or i32 %a_sub, %b_sub
+  %cmp = icmp eq i32 %or, 0
+  %res = select i1 %cmp, i32 0, i32 2147483648
+  ret i32 %res
+}
+
+define i64 @test_i64(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_i64(
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224)
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240)
+; CHECK-NEXT: or i64
+; CHECK-NEXT: and i64
+; CHECK-NEXT: ret i64
+
+  %a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031)
+  %b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047)
+  %or = or i64 %a_sub, %b_sub
+  %cmp = icmp eq i64 %or, 0
+  %res = select i1 %cmp, i64 0, i64 9223372036854775808
+  ret i64 %res
+}
+
+define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
+; CHECK-LABEL: @no_fold_due_to_small_K(
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK: or i32
+; CHECK: icmp eq i32
+; CHECK: select
+; CHECK: ret i32
+
+  %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+  %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+  %or = or i32 %a_sub, %b_sub
+  %cmp = icmp eq i32 %or, 0
+  %res = select i1 %cmp, i32 0, i32 2147483648
+  ret i32 %res
+}
+
+define i32 @commuted_test_neg(i32 %a, i32 %b) {
+; CHECK-LABEL: @commuted_test_neg(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: icmp eq i32
+; CHECK-NEXT: select
+; CHECK-NEXT: ret i32
+
+  %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+  %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+  %or = or i32 %b_sub, %a_sub
+  %cmp = icmp eq i32 %or, 0
+  %res = select i1 %cmp, i32 0, i32 2147483648
+  ret i32 %res
+}
+define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @vector_test(
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224))
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240))
+; CHECK-NEXT: or <4 x i32>
+; CHECK-NEXT: and <4 x i32>
+; CHECK-NEXT: ret <4 x i32>
+
+
+  %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+              <4 x i32> %a,
+              <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+  %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+              <4 x i32> %b,
+              <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  %or = or <4 x i32> %a_sub, %b_sub
+  %cmp = icmp eq <4 x i32> %or, zeroinitializer
+  %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+                         <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  ret <4 x i32> %res
+}

>From dbfd5989029598038bb9b3bef4cbba31619aa764 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:56:10 -0400
Subject: [PATCH 2/7] refactorization

---
 .../Transforms/InstCombine/InstCombineSelect.cpp   | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 74544009e6872..7e4eaa9745917 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,8 +42,6 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <cassert>
-#include <cstddef>
-#include <iostream>
 #include <utility>
 
 #define DEBUG_TYPE "instcombine"
@@ -52,6 +50,7 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+
 /// Replace a select operand based on an equality comparison with the identity
 /// constant of a binop.
 static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1714,6 +1713,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
   if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
     return nullptr;
 
+
   Value *SelVal0, *SelVal1; // We do not care which one is from where.
   match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
   // At least one of these values we are selecting between must be a constant
@@ -2004,8 +2004,7 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
 ///   %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit)
 ///   %or = or %a_sub', %b_sub'
 ///   %and = and %or, MostSignificantBit
-/// If the args are vectors
-///
+/// Likewise, for vector arguments as well.
 static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
     SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
   auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
@@ -4332,9 +4331,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       bool IsCastNeeded = LHS->getType() != SelType;
       Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
       Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
-      if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
-                           ((CmpLHS != LHS && CmpLHS != RHS) ||
-                            (CmpRHS != LHS && CmpRHS != RHS)))) {
+      if (IsCastNeeded ||
+          (LHS->getType()->isFPOrFPVectorTy() &&
+           ((CmpLHS != LHS && CmpLHS != RHS) ||
+            (CmpRHS != LHS && CmpRHS != RHS)))) {
         CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
 
         Value *Cmp;

>From c572adbf917ce9a07e060eb78b50f7f00e79bd77 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 17 Aug 2025 22:37:54 -0400
Subject: [PATCH 3/7] Add more tests and change the condition for negation

---
 .../InstCombine/InstCombineSelect.cpp         |  13 +-
 .../InstCombine/usub_sat_to_msb_mask.ll       | 178 +++++++++++++-----
 2 files changed, 136 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7e4eaa9745917..cd0eb59007b3b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -50,7 +50,6 @@
 using namespace llvm;
 using namespace PatternMatch;
 
-
 /// Replace a select operand based on an equality comparison with the identity
 /// constant of a binop.
 static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1713,7 +1712,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
   if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
     return nullptr;
 
-
   Value *SelVal0, *SelVal1; // We do not care which one is from where.
   match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
   // At least one of these values we are selecting between must be a constant
@@ -2021,7 +2019,7 @@ static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
   auto *TrueVal = SI.getTrueValue();
   auto *FalseVal = SI.getFalseValue();
 
-  if (Pred != ICmpInst::ICMP_EQ)
+  if (Pred != ICmpInst::ICMP_EQ && Pred != llvm::ICmpInst::ICMP_NE)
     return nullptr;
 
   // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
@@ -2032,8 +2030,10 @@ static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
                              m_Value(A), m_ConstantInt(IntConst1)),
                          m_Intrinsic<Intrinsic::usub_sat>(
                              m_Value(B), m_ConstantInt(IntConst2)))) &&
-      match(TrueVal, m_Zero()) &&
-      match(FalseVal, m_ConstantInt(PossibleMSBInt))) {
+      (match(TrueVal, m_Zero()) &&
+           match(FalseVal, m_ConstantInt(PossibleMSBInt)) ||
+       match(TrueVal, m_ConstantInt(PossibleMSBInt)) &&
+           match(FalseVal, m_Zero()))) {
     auto *Ty = A->getType();
     unsigned BW = Ty->getIntegerBitWidth();
     APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
@@ -2062,7 +2062,8 @@ static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
                                                           m_Constant(Const1)),
                          m_Intrinsic<Intrinsic::usub_sat>(
                              m_Value(B), m_Constant(Const2)))) &&
-      match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))) {
+      (match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))
+    || match(TrueVal, m_Constant(PossibleMSB) ) && match(FalseVal, m_Zero()))) {
     auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
     auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
     auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
index ffa77f4b42138..f6402c24315c5 100644
--- a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 
 ; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s
 
@@ -7,12 +8,14 @@ declare i32 @llvm.usub.sat.i32(i32, i32)
 declare i64 @llvm.usub.sat.i64(i64, i64)
 
 define i8 @test_i8(i8 %a, i8 %b) {
-; CHECK-LABEL: @test_i8(
-; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96)
-; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112)
-; CHECK-NEXT: or i8
-; CHECK-NEXT: and i8
-; CHECK-NEXT: ret i8
+; CHECK-LABEL: define i8 @test_i8(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A]], i8 96)
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[B]], i8 112)
+; CHECK-NEXT:    [[TMP3:%.*]] = or i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and i8 [[TMP3]], -128
+; CHECK-NEXT:    ret i8 [[RES]]
+;
 
   %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
   %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
@@ -22,13 +25,33 @@ define i8 @test_i8(i8 %a, i8 %b) {
   ret i8 %res
 }
 
+define i8 @test_i8_ne(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @test_i8_ne(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A]], i8 96)
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[B]], i8 112)
+; CHECK-NEXT:    [[TMP3:%.*]] = or i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and i8 [[TMP3]], -128
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+
+  %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
+  %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
+  %or = or i8 %a_sub, %b_sub
+  %cmp = icmp ne i8 %or, 0
+  %res = select i1 %cmp, i8 128, i8 0
+  ret i8 %res
+}
+
 define i16 @test_i16(i16 %a, i16 %b) {
-; CHECK-LABEL: @test_i16(
-; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642)
-; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656)
-; CHECK-NEXT: or i16
-; CHECK-NEXT: and i16
-; CHECK-NEXT: ret i16
+; CHECK-LABEL: define i16 @test_i16(
+; CHECK-SAME: i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[A]], i16 32642)
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[B]], i16 32656)
+; CHECK-NEXT:    [[TMP3:%.*]] = or i16 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and i16 [[TMP3]], -32768
+; CHECK-NEXT:    ret i16 [[RES]]
+;
 
   %a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409)
   %b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423)
@@ -39,12 +62,14 @@ define i16 @test_i16(i16 %a, i16 %b) {
 }
 
 define i32 @test_i32(i32 %a, i32 %b) {
-; CHECK-LABEL: @test_i32(
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224)
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240)
-; CHECK-NEXT: or i32
-; CHECK-NEXT: and i32
-; CHECK-NEXT: ret i32
+; CHECK-LABEL: define i32 @test_i32(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 224)
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[B]], i32 240)
+; CHECK-NEXT:    [[TMP3:%.*]] = or i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and i32 [[TMP3]], -2147483648
+; CHECK-NEXT:    ret i32 [[RES]]
+;
 
   %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871)
   %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887)
@@ -55,12 +80,14 @@ define i32 @test_i32(i32 %a, i32 %b) {
 }
 
 define i64 @test_i64(i64 %a, i64 %b) {
-; CHECK-LABEL: @test_i64(
-; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224)
-; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240)
-; CHECK-NEXT: or i64
-; CHECK-NEXT: and i64
-; CHECK-NEXT: ret i64
+; CHECK-LABEL: define i64 @test_i64(
+; CHECK-SAME: i64 [[A:%.*]], i64 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[A]], i64 224)
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[B]], i64 240)
+; CHECK-NEXT:    [[TMP3:%.*]] = or i64 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and i64 [[TMP3]], -9223372036854775808
+; CHECK-NEXT:    ret i64 [[RES]]
+;
 
   %a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031)
   %b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047)
@@ -71,13 +98,15 @@ define i64 @test_i64(i64 %a, i64 %b) {
 }
 
 define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
-; CHECK-LABEL: @no_fold_due_to_small_K(
-; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
-; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
-; CHECK: or i32
-; CHECK: icmp eq i32
-; CHECK: select
-; CHECK: ret i32
+; CHECK-LABEL: define i32 @no_fold_due_to_small_K(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[A_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 100)
+; CHECK-NEXT:    [[B_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[B]], i32 239)
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[A_SUB]], [[B_SUB]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i32 0, i32 -2147483648
+; CHECK-NEXT:    ret i32 [[RES]]
+;
 
   %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
   %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
@@ -88,13 +117,15 @@ define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
 }
 
 define i32 @commuted_test_neg(i32 %a, i32 %b) {
-; CHECK-LABEL: @commuted_test_neg(
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
-; CHECK-NEXT: or i32
-; CHECK-NEXT: icmp eq i32
-; CHECK-NEXT: select
-; CHECK-NEXT: ret i32
+; CHECK-LABEL: define i32 @commuted_test_neg(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[B_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[B]], i32 239)
+; CHECK-NEXT:    [[A_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 223)
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[B_SUB]], [[A_SUB]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i32 0, i32 -2147483648
+; CHECK-NEXT:    ret i32 [[RES]]
+;
 
   %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
   %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
@@ -104,23 +135,72 @@ define i32 @commuted_test_neg(i32 %a, i32 %b) {
   ret i32 %res
 }
 define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
-; CHECK-LABEL: @vector_test(
-; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224))
-; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240))
-; CHECK-NEXT: or <4 x i32>
-; CHECK-NEXT: and <4 x i32>
-; CHECK-NEXT: ret <4 x i32>
+; CHECK-LABEL: define <4 x i32> @vector_test(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[A]], <4 x i32> splat (i32 224))
+; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[B]], <4 x i32> splat (i32 240))
+; CHECK-NEXT:    [[TMP3:%.*]] = or <4 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and <4 x i32> [[TMP3]], splat (i32 -2147483648)
+; CHECK-NEXT:    ret <4 x i32> [[RES]]
+;
+
+
+  %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+  <4 x i32> %a,
+  <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+  %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+  <4 x i32> %b,
+  <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  %or = or <4 x i32> %a_sub, %b_sub
+  %cmp = icmp eq <4 x i32> %or, zeroinitializer
+  %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+  <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @vector_negative_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define <4 x i32> @vector_negative_test(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT:    [[A_SUB:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[A]], <4 x i32> <i32 -2147483425, i32 0, i32 -2147483425, i32 -2147483425>)
+; CHECK-NEXT:    [[B_SUB:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[B]], <4 x i32> splat (i32 -2147483409))
+; CHECK-NEXT:    [[OR:%.*]] = or <4 x i32> [[A_SUB]], [[B_SUB]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <4 x i32> [[OR]], zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = select <4 x i1> [[CMP]], <4 x i32> zeroinitializer, <4 x i32> splat (i32 -2147483648)
+; CHECK-NEXT:    ret <4 x i32> [[RES]]
+;
+  %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+  <4 x i32> %a,
+  <4 x i32> <i32 2147483871, i32 0, i32 2147483871, i32 2147483871>)
+  %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+  <4 x i32> %b,
+  <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  %or = or <4 x i32> %a_sub, %b_sub
+  %cmp = icmp eq <4 x i32> %or, zeroinitializer
+  %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+  <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @vector_ne_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define <4 x i32> @vector_ne_test(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[A]], <4 x i32> splat (i32 224))
+; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[B]], <4 x i32> splat (i32 240))
+; CHECK-NEXT:    [[TMP3:%.*]] = or <4 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[RES:%.*]] = and <4 x i32> [[TMP3]], splat (i32 -2147483648)
+; CHECK-NEXT:    ret <4 x i32> [[RES]]
+;
 
 
   %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
-              <4 x i32> %a,
-              <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+  <4 x i32> %a,
+  <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
   %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
-              <4 x i32> %b,
-              <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  <4 x i32> %b,
+  <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
   %or = or <4 x i32> %a_sub, %b_sub
   %cmp = icmp eq <4 x i32> %or, zeroinitializer
   %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
-                         <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
   ret <4 x i32> %res
 }

>From 14f028966b280ad8f62a0c0009e71c8f8fee9d8f Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 24 Aug 2025 03:36:04 -0400
Subject: [PATCH 4/7] Change to APInt for scalar and splat vector

---
 .../InstCombine/InstCombineSelect.cpp         | 150 ++++++------------
 1 file changed, 49 insertions(+), 101 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index cd0eb59007b3b..d2525e2facef3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2006,120 +2006,68 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
 static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
     SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
   auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
-  if (!CI) {
+  if (!CI)
     return nullptr;
-  }
 
   Value *CmpLHS = CI->getOperand(0);
   Value *CmpRHS = CI->getOperand(1);
-  if (!match(CmpRHS, m_Zero())) {
+  if (!match(CmpRHS, m_Zero()))
     return nullptr;
-  }
-  auto Pred = CI->getPredicate();
-  auto *TrueVal = SI.getTrueValue();
-  auto *FalseVal = SI.getFalseValue();
 
-  if (Pred != ICmpInst::ICMP_EQ && Pred != llvm::ICmpInst::ICMP_NE)
+  auto Pred = CI->getPredicate();
+  if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
     return nullptr;
 
-  // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
   Value *A, *B;
-  ConstantInt *IntConst1, *IntConst2, *PossibleMSBInt;
-
-  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(
-                             m_Value(A), m_ConstantInt(IntConst1)),
-                         m_Intrinsic<Intrinsic::usub_sat>(
-                             m_Value(B), m_ConstantInt(IntConst2)))) &&
-      (match(TrueVal, m_Zero()) &&
-           match(FalseVal, m_ConstantInt(PossibleMSBInt)) ||
-       match(TrueVal, m_ConstantInt(PossibleMSBInt)) &&
-           match(FalseVal, m_Zero()))) {
-    auto *Ty = A->getType();
-    unsigned BW = Ty->getIntegerBitWidth();
-    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
-
-    if (PossibleMSBInt->getValue() != MostSignificantBit)
-      return nullptr;
-    // Ensure IntConst1 and IntConst2 are >= MostSignificantBit
-    if (IntConst1->getValue().ult(MostSignificantBit) ||
-        IntConst2->getValue().ult(MostSignificantBit))
-      return nullptr;
+  const APInt *Constant1, *Constant2, *PossibleMSB;
+  if (!match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
+                                                           m_APInt(Constant1)),
+                          m_Intrinsic<Intrinsic::usub_sat>(
+                              m_Value(B), m_APInt(Constant2)))))
+    return nullptr;
 
-    // Rewrite:
-    Value *NewA = Builder.CreateBinaryIntrinsic(
-        Intrinsic::usub_sat, A,
-        ConstantInt::get(Ty, IntConst1->getValue() - MostSignificantBit + 1));
-    Value *NewB = Builder.CreateBinaryIntrinsic(
-        Intrinsic::usub_sat, B,
-        ConstantInt::get(Ty, IntConst2->getValue() - MostSignificantBit + 1));
-    Value *Or = Builder.CreateOr(NewA, NewB);
-    Value *And =
-        Builder.CreateAnd(Or, ConstantInt::get(Ty, MostSignificantBit));
-    return cast<Instruction>(And);
-  }
-  Constant *Const1, *Const2, *PossibleMSB;
-  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
-                                                          m_Constant(Const1)),
-                         m_Intrinsic<Intrinsic::usub_sat>(
-                             m_Value(B), m_Constant(Const2)))) &&
-      (match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))
-    || match(TrueVal, m_Constant(PossibleMSB) ) && match(FalseVal, m_Zero()))) {
-    auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
-    auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
-    auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
-    if (!VecTy1 || !VecTy2 || !VecTyMSB) {
-      return nullptr;
-    }
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  if (!((match(TrueVal, m_Zero()) && match(FalseVal, m_APInt(PossibleMSB))) ||
+        (match(TrueVal, m_APInt(PossibleMSB)) && match(FalseVal, m_Zero()))))
+    return nullptr;
 
-    unsigned NumElements = VecTy1->getNumElements();
+  auto *Ty = A->getType();
+  auto *VecTy = dyn_cast<VectorType>(Ty);
+  unsigned BW = PossibleMSB->getBitWidth();
+  APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
 
-    if (NumElements != VecTy2->getNumElements() ||
-        NumElements != VecTyMSB->getNumElements() || NumElements == 0) {
-      return nullptr;
-    }
-    auto *SplatMSB =
-        dyn_cast<ConstantInt>(PossibleMSB->getAggregateElement(0u));
-    unsigned BW = SplatMSB->getValue().getBitWidth();
-    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
-    if (!SplatMSB || SplatMSB->getValue() != MostSignificantBit) {
-      return nullptr;
-    }
-    for (unsigned int i = 1; i < NumElements; ++i) {
-      auto *Element =
-          dyn_cast<ConstantInt>(PossibleMSB->getAggregateElement(i));
-      if (!Element || Element->getValue() != SplatMSB->getValue()) {
-        return nullptr;
-      }
-    }
-    SmallVector<Constant *, 16> Arg1, Arg2;
-    for (unsigned int i = 0; i < NumElements; ++i) {
-      auto *E1 = dyn_cast<ConstantInt>(Const1->getAggregateElement(i));
-      auto *E2 = dyn_cast<ConstantInt>(Const2->getAggregateElement(i));
-      if (!E1 || !E2) {
-        return nullptr;
-      }
-      if (E1->getValue().ult(SplatMSB->getValue()) ||
-          E2->getValue().ult(SplatMSB->getValue())) {
-        return nullptr;
-      }
-      Arg1.emplace_back(
-          ConstantInt::get(A->getType()->getScalarType(),
-                           E1->getValue() - MostSignificantBit + 1));
-      Arg2.emplace_back(
-          ConstantInt::get(B->getType()->getScalarType(),
-                           E2->getValue() - MostSignificantBit + 1));
-    }
-    Constant *ConstVec1 = ConstantVector::get(Arg1);
-    Constant *ConstVec2 = ConstantVector::get(Arg2);
-    Value *NewA =
-        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, ConstVec1);
-    Value *NewB =
-        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, ConstVec2);
-    Value *Or = Builder.CreateOr(NewA, NewB);
-    Value *And = Builder.CreateAnd(Or, PossibleMSB);
-    return cast<Instruction>(And);
+  if (*PossibleMSB != MostSignificantBit ||
+      Constant1->ult(MostSignificantBit) || Constant2->ult(MostSignificantBit))
+    return nullptr;
+
+  APInt AdjAP1 = *Constant1 - MostSignificantBit + 1;
+  APInt AdjAP2 = *Constant2 - MostSignificantBit + 1;
+
+  Constant *Adj1, *Adj2;
+  if (VecTy) {
+    Constant *Elt1 = ConstantInt::get(VecTy->getElementType(), AdjAP1);
+    Constant *Elt2 = ConstantInt::get(VecTy->getElementType(), AdjAP2);
+    Adj1 = ConstantVector::getSplat(VecTy->getElementCount(), Elt1);
+    Adj2 = ConstantVector::getSplat(VecTy->getElementCount(), Elt2);
+  } else {
+    Adj1 = ConstantInt::get(Ty, AdjAP1);
+    Adj2 = ConstantInt::get(Ty, AdjAP2);
   }
-  return nullptr;
+
+  Value *NewA = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, Adj1);
+  Value *NewB = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, Adj2);
+  Value *Or = Builder.CreateOr(NewA, NewB);
+  Constant *MSBConst;
+  if (VecTy) {
+    MSBConst = ConstantVector::getSplat(
+        VecTy->getElementCount(),
+        ConstantInt::get(VecTy->getScalarType(), *PossibleMSB));
+  } else {
+    MSBConst = ConstantInt::get(Ty->getScalarType(), *PossibleMSB);
+  }
+  Value *And = Builder.CreateAnd(Or, MSBConst);
+  return cast<Instruction>(And);
 }
 
 /// Visit a SelectInst that has an ICmpInst as its first operand.

>From fb6c7369c24e97ae0d2ac63bf0c74bf04b5590c7 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 24 Aug 2025 18:39:08 -0400
Subject: [PATCH 5/7] apply suggestions from code review

---
 .../InstCombine/InstCombineSelect.cpp         | 69 +++++++------------
 1 file changed, 23 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index d2525e2facef3..b2720fc113cad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2005,69 +2005,45 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
 /// Likewise, for vector arguments as well.
 static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
     SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
-  auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
-  if (!CI)
-    return nullptr;
-
-  Value *CmpLHS = CI->getOperand(0);
-  Value *CmpRHS = CI->getOperand(1);
-  if (!match(CmpRHS, m_Zero()))
-    return nullptr;
-
-  auto Pred = CI->getPredicate();
-  if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
-    return nullptr;
-
+  CmpPredicate Pred;
   Value *A, *B;
-  const APInt *Constant1, *Constant2, *PossibleMSB;
-  if (!match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
-                                                           m_APInt(Constant1)),
-                          m_Intrinsic<Intrinsic::usub_sat>(
-                              m_Value(B), m_APInt(Constant2)))))
+  const APInt *Constant1, *Constant2;
+  if (!match(SI.getCondition(),
+             m_ICmp(Pred,
+                    m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
+                                                          m_APInt(Constant1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(m_Value(B),
+                                                          m_APInt(Constant2))),
+                    m_Zero())))
     return nullptr;
 
   Value *TrueVal = SI.getTrueValue();
   Value *FalseVal = SI.getFalseValue();
-  if (!((match(TrueVal, m_Zero()) && match(FalseVal, m_APInt(PossibleMSB))) ||
-        (match(TrueVal, m_APInt(PossibleMSB)) && match(FalseVal, m_Zero()))))
+  if (!((Pred == ICmpInst::ICMP_EQ && match(TrueVal, m_Zero()) &&
+         match(FalseVal, m_SignMask())) ||
+        (Pred == ICmpInst::ICMP_NE && match(TrueVal, m_SignMask()) &&
+         match(FalseVal, m_Zero()))))
     return nullptr;
 
   auto *Ty = A->getType();
-  auto *VecTy = dyn_cast<VectorType>(Ty);
-  unsigned BW = PossibleMSB->getBitWidth();
-  APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+  unsigned BW = Constant1->getBitWidth();
+  APInt MostSignificantBit = APInt::getSignMask(BW);
 
-  if (*PossibleMSB != MostSignificantBit ||
-      Constant1->ult(MostSignificantBit) || Constant2->ult(MostSignificantBit))
+  // Both constants must have the sign bit set, i.e. be u>= the sign mask.
+  if (Constant1->isNonNegative() || Constant2->isNonNegative())
     return nullptr;
 
   APInt AdjAP1 = *Constant1 - MostSignificantBit + 1;
   APInt AdjAP2 = *Constant2 - MostSignificantBit + 1;
 
-  Constant *Adj1, *Adj2;
-  if (VecTy) {
-    Constant *Elt1 = ConstantInt::get(VecTy->getElementType(), AdjAP1);
-    Constant *Elt2 = ConstantInt::get(VecTy->getElementType(), AdjAP2);
-    Adj1 = ConstantVector::getSplat(VecTy->getElementCount(), Elt1);
-    Adj2 = ConstantVector::getSplat(VecTy->getElementCount(), Elt2);
-  } else {
-    Adj1 = ConstantInt::get(Ty, AdjAP1);
-    Adj2 = ConstantInt::get(Ty, AdjAP2);
-  }
+  auto *Adj1 = ConstantInt::get(Ty, AdjAP1);
+  auto *Adj2 = ConstantInt::get(Ty, AdjAP2);
 
   Value *NewA = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, Adj1);
   Value *NewB = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, Adj2);
   Value *Or = Builder.CreateOr(NewA, NewB);
-  Constant *MSBConst;
-  if (VecTy) {
-    MSBConst = ConstantVector::getSplat(
-        VecTy->getElementCount(),
-        ConstantInt::get(VecTy->getScalarType(), *PossibleMSB));
-  } else {
-    MSBConst = ConstantInt::get(Ty->getScalarType(), *PossibleMSB);
-  }
-  Value *And = Builder.CreateAnd(Or, MSBConst);
-  return cast<Instruction>(And);
+  Constant *MSBConst = ConstantInt::get(Ty, MostSignificantBit);
+  return BinaryOperator::CreateAnd(Or, MSBConst);
 }
 
 /// Visit a SelectInst that has an ICmpInst as its first operand.
@@ -2088,7 +2064,8 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
     return NewSel;
   if (Instruction *Folded =
           foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
-    return replaceInstUsesWith(SI, Folded);
+    return Folded;
+  ;
 
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
   bool Changed = false;

>From b3bad04cc8ce3d0439710d2c6cafc6d3aff02828 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 24 Aug 2025 19:21:59 -0400
Subject: [PATCH 6/7] test change

---
 .../InstCombine/usub_sat_to_msb_mask.ll       | 27 +++++++------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
index f6402c24315c5..d89e7e5f2beba 100644
--- a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -2,11 +2,6 @@
 
 ; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s
 
-declare i8 @llvm.usub.sat.i8(i8, i8)
-declare i16 @llvm.usub.sat.i16(i16, i16)
-declare i32 @llvm.usub.sat.i32(i32, i32)
-declare i64 @llvm.usub.sat.i64(i64, i64)
-
 define i8 @test_i8(i8 %a, i8 %b) {
 ; CHECK-LABEL: define i8 @test_i8(
 ; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
@@ -146,15 +141,14 @@ define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
 
 
   %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
-  <4 x i32> %a,
-  <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+  <4 x i32> %a, <4 x i32> splat (i32 2147483871))
   %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
-  <4 x i32> %b,
-  <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  <4 x i32> %b, <4 x i32> splat (i32 2147483887))
   %or = or <4 x i32> %a_sub, %b_sub
   %cmp = icmp eq <4 x i32> %or, zeroinitializer
-  %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
-  <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  %res = select <4 x i1> %cmp,
+  <4 x i32> zeroinitializer,
+  <4 x i32> splat (i32 -2147483648)
   ret <4 x i32> %res
 }
 
@@ -193,14 +187,13 @@ define <4 x i32> @vector_ne_test(<4 x i32> %a, <4 x i32> %b) {
 
 
   %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
-  <4 x i32> %a,
-  <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+  <4 x i32> %a, <4 x i32> splat (i32 2147483871))
   %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
-  <4 x i32> %b,
-  <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  <4 x i32> %b, <4 x i32> splat (i32 2147483887))
   %or = or <4 x i32> %a_sub, %b_sub
   %cmp = icmp eq <4 x i32> %or, zeroinitializer
-  %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
-  <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  %res = select <4 x i1> %cmp,
+  <4 x i32> zeroinitializer,
+  <4 x i32> splat (i32 -2147483648)
   ret <4 x i32> %res
 }

>From d246173b13ac64dced36c2327bb2dbe0a4a011ba Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 24 Aug 2025 22:48:08 -0400
Subject: [PATCH 7/7] formatting

---
 llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index b2720fc113cad..c8bde80f132d2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -50,6 +50,7 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+
 /// Replace a select operand based on an equality comparison with the identity
 /// constant of a binop.
 static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1712,6 +1713,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
   if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
     return nullptr;
 
+
   Value *SelVal0, *SelVal1; // We do not care which one is from where.
   match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
   // At least one of these values we are selecting between must be a constant
@@ -2065,7 +2067,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Instruction *Folded =
           foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
     return Folded;
-  ;
 
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
   bool Changed = false;



More information about the llvm-commits mailing list