[llvm] Optimize usub.sat fix for #79690 (PR #151044)

Nimit Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 28 15:00:38 PDT 2025


https://github.com/nimit25 updated https://github.com/llvm/llvm-project/pull/151044

>From 82336b681fa6e917d4fe217a227614f481aa30ac Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:23:11 -0400
Subject: [PATCH 1/2] Optimize usub.sat fix for #79690

---
 .../InstCombine/InstCombineSelect.cpp         | 143 +++++++++++++++++-
 .../InstCombine/usub_sat_to_msb_mask.ll       | 126 +++++++++++++++
 2 files changed, 263 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index eb4332fbc0959..74544009e6872 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,6 +42,8 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <cassert>
+#include <cstddef>
+#include <iostream>
 #include <utility>
 
 #define DEBUG_TYPE "instcombine"
@@ -50,7 +52,6 @@
 using namespace llvm;
 using namespace PatternMatch;
 
-
 /// Replace a select operand based on an equality comparison with the identity
 /// constant of a binop.
 static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1713,7 +1714,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
   if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
     return nullptr;
 
-
   Value *SelVal0, *SelVal1; // We do not care which one is from where.
   match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
   // At least one of these values we are selecting between must be a constant
@@ -1993,6 +1993,135 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
   return BinOp;
 }
 
+/// Folds:
+///   %a_sub = call @llvm.usub.sat(x, IntConst1)
+///   %b_sub = call @llvm.usub.sat(y, IntConst2)
+///   %or = or %a_sub, %b_sub
+///   %cmp = icmp eq %or, 0
+///   %sel = select %cmp, 0, MostSignificantBit
+/// into:
+///   %a_sub' = usub.sat(x, IntConst1 - MostSignificantBit)
+///   %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit)
+///   %or = or %a_sub', %b_sub'
+///   %and = and %or, MostSignificantBit
+/// If the args are vectors
+///
+static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
+    SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
+  auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
+  // Only a select whose condition is `icmp eq X, 0` is handled.
+  if (!CI || CI->getPredicate() != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  Value *CmpLHS = CI->getOperand(0);
+  Value *CmpRHS = CI->getOperand(1);
+  // The result replaces the select, so the compared value must have the
+  // select's type; this also keeps the APInt comparisons below same-width.
+  if (!match(CmpRHS, m_Zero()) || CmpLHS->getType() != SI.getType())
+    return nullptr;
+  auto *TrueVal = SI.getTrueValue();
+  auto *FalseVal = SI.getFalseValue();
+
+  // usub.sat(X, C) != 0 iff X u> C. For C u>= MSB that is the same as the
+  // sign bit of usub.sat(X, C - MSB + 1) being set, so the compare+select
+  // becomes a sign-bit mask of the OR of the two rebased subtractions.
+  // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
+  Value *A, *B;
+  ConstantInt *IntConst1, *IntConst2, *PossibleMSBInt;
+
+  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(A), m_ConstantInt(IntConst1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(B), m_ConstantInt(IntConst2)))) &&
+      match(TrueVal, m_Zero()) &&
+      match(FalseVal, m_ConstantInt(PossibleMSBInt))) {
+    auto *Ty = A->getType();
+    unsigned BW = Ty->getScalarSizeInBits();
+    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+
+    if (PossibleMSBInt->getValue() != MostSignificantBit)
+      return nullptr;
+    // Ensure IntConst1 and IntConst2 are >= MostSignificantBit
+    if (IntConst1->getValue().ult(MostSignificantBit) ||
+        IntConst2->getValue().ult(MostSignificantBit))
+      return nullptr;
+
+    // Rewrite with each constant rebased to C - MostSignificantBit + 1.
+    Value *NewA = Builder.CreateBinaryIntrinsic(
+        Intrinsic::usub_sat, A,
+        ConstantInt::get(Ty, IntConst1->getValue() - MostSignificantBit + 1));
+    Value *NewB = Builder.CreateBinaryIntrinsic(
+        Intrinsic::usub_sat, B,
+        ConstantInt::get(Ty, IntConst2->getValue() - MostSignificantBit + 1));
+    Value *Or = Builder.CreateOr(NewA, NewB);
+    Value *And =
+        Builder.CreateAnd(Or, ConstantInt::get(Ty, MostSignificantBit));
+    return cast<Instruction>(And);
+  }
+  // Vector form: the constants are inspected lane by lane.
+  Constant *Const1, *Const2, *PossibleMSB;
+  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
+                                                          m_Constant(Const1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(B), m_Constant(Const2)))) &&
+      match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))) {
+    auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
+    auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
+    auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
+    if (!VecTy1 || !VecTy2 || !VecTyMSB)
+      return nullptr;
+
+    unsigned NumElements = VecTy1->getNumElements();
+    if (NumElements != VecTy2->getNumElements() ||
+        NumElements != VecTyMSB->getNumElements() || NumElements == 0)
+      return nullptr;
+
+    // getAggregateElement() may return null (e.g. for a constant expression),
+    // so use the null-tolerant cast and test the result before using it.
+    auto *SplatMSB =
+        dyn_cast_or_null<ConstantInt>(PossibleMSB->getAggregateElement(0u));
+    if (!SplatMSB)
+      return nullptr;
+    unsigned BW = SplatMSB->getValue().getBitWidth();
+    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+    if (SplatMSB->getValue() != MostSignificantBit)
+      return nullptr;
+    for (unsigned int i = 1; i < NumElements; ++i) {
+      auto *Element =
+          dyn_cast_or_null<ConstantInt>(PossibleMSB->getAggregateElement(i));
+      if (!Element || Element->getValue() != MostSignificantBit)
+        return nullptr;
+    }
+    SmallVector<Constant *, 16> Arg1, Arg2;
+    for (unsigned int i = 0; i < NumElements; ++i) {
+      auto *E1 = dyn_cast_or_null<ConstantInt>(Const1->getAggregateElement(i));
+      auto *E2 = dyn_cast_or_null<ConstantInt>(Const2->getAggregateElement(i));
+      if (!E1 || !E2)
+        return nullptr;
+      // Every lane constant must be >= MostSignificantBit.
+      if (E1->getValue().ult(MostSignificantBit) ||
+          E2->getValue().ult(MostSignificantBit))
+        return nullptr;
+      Arg1.emplace_back(
+          ConstantInt::get(A->getType()->getScalarType(),
+                           E1->getValue() - MostSignificantBit + 1));
+      Arg2.emplace_back(
+          ConstantInt::get(B->getType()->getScalarType(),
+                           E2->getValue() - MostSignificantBit + 1));
+    }
+    Constant *ConstVec1 = ConstantVector::get(Arg1);
+    Constant *ConstVec2 = ConstantVector::get(Arg2);
+    Value *NewA =
+        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, ConstVec1);
+    Value *NewB =
+        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, ConstVec2);
+    Value *Or = Builder.CreateOr(NewA, NewB);
+    Value *And = Builder.CreateAnd(Or, PossibleMSB);
+    return cast<Instruction>(And);
+  }
+  return nullptr;
+}
+
 /// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
                                                       ICmpInst *ICI) {
@@ -2009,6 +2138,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Instruction *NewSel =
           tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
     return NewSel;
+  if (Instruction *Folded =
+          foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
+    return replaceInstUsesWith(SI, Folded);
 
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
   bool Changed = false;
@@ -4200,10 +4332,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       bool IsCastNeeded = LHS->getType() != SelType;
       Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
       Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
-      if (IsCastNeeded ||
-          (LHS->getType()->isFPOrFPVectorTy() &&
-           ((CmpLHS != LHS && CmpLHS != RHS) ||
-            (CmpRHS != LHS && CmpRHS != RHS)))) {
+      if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
+                           ((CmpLHS != LHS && CmpLHS != RHS) ||
+                            (CmpRHS != LHS && CmpRHS != RHS)))) {
         CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
 
         Value *Cmp;
diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
new file mode 100644
index 0000000000000..ffa77f4b42138
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -0,0 +1,126 @@
+
+; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s
+
+declare i8 @llvm.usub.sat.i8(i8, i8)
+declare i16 @llvm.usub.sat.i16(i16, i16)
+declare i32 @llvm.usub.sat.i32(i32, i32)
+declare i64 @llvm.usub.sat.i64(i64, i64)
+
+define i8 @test_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_i8(
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96)
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112)
+; CHECK-NEXT: or i8
+; CHECK-NEXT: and i8
+; CHECK-NEXT: ret i8
+; Positive case: 223 and 239 are >= 128 (the i8 sign bit); C - 128 + 1 gives 96 and 112.
+  %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
+  %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
+  %or = or i8 %a_sub, %b_sub
+  %cmp = icmp eq i8 %or, 0
+  %res = select i1 %cmp, i8 0, i8 128
+  ret i8 %res
+}
+
+define i16 @test_i16(i16 %a, i16 %b) {
+; CHECK-LABEL: @test_i16(
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642)
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656)
+; CHECK-NEXT: or i16
+; CHECK-NEXT: and i16
+; CHECK-NEXT: ret i16
+; Positive case: 65409 - 32768 + 1 = 32642 and 65423 - 32768 + 1 = 32656.
+  %a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409)
+  %b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423)
+  %or = or i16 %a_sub, %b_sub
+  %cmp = icmp eq i16 %or, 0
+  %res = select i1 %cmp, i16 0, i16 32768
+  ret i16 %res
+}
+
+define i32 @test_i32(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_i32(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: and i32
+; CHECK-NEXT: ret i32
+; Positive case: 2147483871 - 2147483648 + 1 = 224 and 2147483887 - 2147483648 + 1 = 240.
+  %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871)
+  %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887)
+  %or = or i32 %a_sub, %b_sub
+  %cmp = icmp eq i32 %or, 0
+  %res = select i1 %cmp, i32 0, i32 2147483648
+  ret i32 %res
+}
+
+define i64 @test_i64(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_i64(
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224)
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240)
+; CHECK-NEXT: or i64
+; CHECK-NEXT: and i64
+; CHECK-NEXT: ret i64
+; Positive case: the inputs are 2^63 + 223 and 2^63 + 239, so C - 2^63 + 1 gives 224 and 240.
+  %a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031)
+  %b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047)
+  %or = or i64 %a_sub, %b_sub
+  %cmp = icmp eq i64 %or, 0
+  %res = select i1 %cmp, i64 0, i64 9223372036854775808
+  ret i64 %res
+}
+
+define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
+; CHECK-LABEL: @no_fold_due_to_small_K(
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK: or i32
+; CHECK: icmp eq i32
+; CHECK: select
+; CHECK: ret i32
+; Negative case: 100 and 239 are below the i32 sign bit (2147483648), so the IR must stay as written.
+  %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+  %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+  %or = or i32 %a_sub, %b_sub
+  %cmp = icmp eq i32 %or, 0
+  %res = select i1 %cmp, i32 0, i32 2147483648
+  ret i32 %res
+}
+
+define i32 @commuted_test_neg(i32 %a, i32 %b) {
+; CHECK-LABEL: @commuted_test_neg(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: icmp eq i32
+; CHECK-NEXT: select
+; CHECK-NEXT: ret i32
+; Negative case with the or operands swapped: 239 and 223 are below the i32 sign bit, so nothing folds.
+  %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+  %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+  %or = or i32 %b_sub, %a_sub
+  %cmp = icmp eq i32 %or, 0
+  %res = select i1 %cmp, i32 0, i32 2147483648
+  ret i32 %res
+}
+define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @vector_test(
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224))
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240))
+; CHECK-NEXT: or <4 x i32>
+; CHECK-NEXT: and <4 x i32>
+; CHECK-NEXT: ret <4 x i32>
+; Vector case: each lane of the splats 2147483871 / 2147483887 is rebased to 224 / 240; -2147483648 is the i32 sign bit.
+
+  %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+              <4 x i32> %a,
+              <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+  %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+              <4 x i32> %b,
+              <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+  %or = or <4 x i32> %a_sub, %b_sub
+  %cmp = icmp eq <4 x i32> %or, zeroinitializer
+  %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+                         <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+  ret <4 x i32> %res
+}

>From dbfd5989029598038bb9b3bef4cbba31619aa764 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:56:10 -0400
Subject: [PATCH 2/2] refactorization

---
 .../Transforms/InstCombine/InstCombineSelect.cpp   | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 74544009e6872..7e4eaa9745917 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,8 +42,6 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <cassert>
-#include <cstddef>
-#include <iostream>
 #include <utility>
 
 #define DEBUG_TYPE "instcombine"
@@ -52,6 +50,7 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+
 /// Replace a select operand based on an equality comparison with the identity
 /// constant of a binop.
 static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1714,6 +1713,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
   if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
     return nullptr;
 
+
   Value *SelVal0, *SelVal1; // We do not care which one is from where.
   match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
   // At least one of these values we are selecting between must be a constant
@@ -2004,8 +2004,7 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
 ///   %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit)
 ///   %or = or %a_sub', %b_sub'
 ///   %and = and %or, MostSignificantBit
-/// If the args are vectors
-///
+/// Likewise, for vector arguments as well.
 static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
     SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
   auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
@@ -4332,9 +4331,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       bool IsCastNeeded = LHS->getType() != SelType;
       Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
       Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
-      if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
-                           ((CmpLHS != LHS && CmpLHS != RHS) ||
-                            (CmpRHS != LHS && CmpRHS != RHS)))) {
+      if (IsCastNeeded ||
+          (LHS->getType()->isFPOrFPVectorTy() &&
+           ((CmpLHS != LHS && CmpLHS != RHS) ||
+            (CmpRHS != LHS && CmpRHS != RHS)))) {
         CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
 
         Value *Cmp;



More information about the llvm-commits mailing list