[llvm] Optimize usub.sat fix for #79690 (PR #151044)
Nimit Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 28 15:00:38 PDT 2025
https://github.com/nimit25 updated https://github.com/llvm/llvm-project/pull/151044
>From 82336b681fa6e917d4fe217a227614f481aa30ac Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:23:11 -0400
Subject: [PATCH 1/2] Optimize usub.sat fix for #79690
---
.../InstCombine/InstCombineSelect.cpp | 143 +++++++++++++++++-
.../InstCombine/usub_sat_to_msb_mask.ll | 126 +++++++++++++++
2 files changed, 263 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index eb4332fbc0959..74544009e6872 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,6 +42,8 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
+#include <cstddef>
+#include <iostream>
#include <utility>
#define DEBUG_TYPE "instcombine"
@@ -50,7 +52,6 @@
using namespace llvm;
using namespace PatternMatch;
-
/// Replace a select operand based on an equality comparison with the identity
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1713,7 +1714,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
return nullptr;
-
Value *SelVal0, *SelVal1; // We do not care which one is from where.
match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
// At least one of these values we are selecting between must be a constant
@@ -1993,6 +1993,135 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
return BinOp;
}
+/// Folds (for an integer or fixed-width integer-vector type whose sign mask,
+/// i.e. the value with only the most significant bit set, is called MSB):
+///   %a_sub = call @llvm.usub.sat(x, IntConst1)
+///   %b_sub = call @llvm.usub.sat(y, IntConst2)
+///   %or = or %a_sub, %b_sub
+///   %cmp = icmp eq %or, 0
+///   %sel = select %cmp, 0, MSB
+/// into:
+///   %a_sub' = usub.sat(x, IntConst1 - MSB + 1)
+///   %b_sub' = usub.sat(y, IntConst2 - MSB + 1)
+///   %or' = or %a_sub', %b_sub'
+///   %and = and %or', MSB
+/// provided every lane of IntConst1 and IntConst2 is uge MSB.
+///
+/// Why this holds: %cmp is true iff x ule IntConst1 and y ule IntConst2. For a
+/// constant C uge MSB, usub.sat(x, C - MSB + 1) has its MSB set exactly when
+/// x ugt C. (For C == all-ones the adjusted constant is MSB, and the result
+/// never reaches MSB, matching "x ugt all-ones" being always false.) Hence
+/// the MSB of %or' is set exactly when %cmp is false.
+///
+/// The compare, the or and both usub.sat calls must each have a single use,
+/// so the rewrite never increases the instruction count.
+
+/// Returns \p C with every lane rewritten as (lane - SignMask + 1), or nullptr
+/// if \p C is not a ConstantInt / fixed-width vector of ConstantInts, or if any
+/// lane is ult \p SignMask (the fold above is only valid for such lanes).
+static Constant *getSignMaskAdjustedUSubSatConstant(Constant *C,
+                                                    const APInt &SignMask) {
+  // Scalar constant, or a vector splat that is represented as a ConstantInt.
+  if (auto *CI = dyn_cast<ConstantInt>(C)) {
+    if (CI->getValue().ult(SignMask))
+      return nullptr;
+    return ConstantInt::get(CI->getType(), CI->getValue() - SignMask + 1);
+  }
+
+  auto *VecTy = dyn_cast<FixedVectorType>(C->getType());
+  if (!VecTy)
+    return nullptr;
+
+  SmallVector<Constant *, 16> Lanes;
+  Lanes.reserve(VecTy->getNumElements());
+  for (unsigned I = 0, E = VecTy->getNumElements(); I != E; ++I) {
+    // getAggregateElement may return null (e.g. for constant expressions) or
+    // a poison/undef lane; both are rejected.
+    auto *Lane = dyn_cast_or_null<ConstantInt>(C->getAggregateElement(I));
+    if (!Lane || Lane->getValue().ult(SignMask))
+      return nullptr;
+    Lanes.push_back(
+        ConstantInt::get(Lane->getType(), Lane->getValue() - SignMask + 1));
+  }
+  return ConstantVector::get(Lanes);
+}
+
+static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
+    SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
+  // The compare is read from SI's condition; ICI is not consulted here.
+  auto *Cmp = dyn_cast<ICmpInst>(SI.getCondition());
+  if (!Cmp || !Cmp->hasOneUse() || Cmp->getPredicate() != ICmpInst::ICMP_EQ ||
+      !match(Cmp->getOperand(1), m_Zero()) ||
+      !match(SI.getTrueValue(), m_Zero()))
+    return nullptr;
+
+  // Match: or (usub.sat A, C1), (usub.sat B, C2). Both or operands use the
+  // same sub-pattern, so the or's operand order does not matter.
+  Value *A, *B;
+  Constant *C1, *C2;
+  if (!match(Cmp->getOperand(0),
+             m_OneUse(m_Or(m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(
+                               m_Value(A), m_Constant(C1))),
+                           m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(
+                               m_Value(B), m_Constant(C2)))))))
+    return nullptr;
+
+  // The select must produce the same type as the usub.sat values; otherwise
+  // the replacement below would have the wrong type.
+  Type *Ty = A->getType();
+  if (SI.getType() != Ty)
+    return nullptr;
+
+  APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits());
+
+  // The false arm must be MSB: a scalar, or a splat with no poison lanes.
+  if (!match(SI.getFalseValue(), m_SpecificInt(SignMask)))
+    return nullptr;
+
+  Constant *NewC1 = getSignMaskAdjustedUSubSatConstant(C1, SignMask);
+  Constant *NewC2 = getSignMaskAdjustedUSubSatConstant(C2, SignMask);
+  if (!NewC1 || !NewC2)
+    return nullptr;
+
+  Value *NewA = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, NewC1);
+  Value *NewB = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, NewC2);
+  Value *Or = Builder.CreateOr(NewA, NewB);
+  Value *And = Builder.CreateAnd(Or, ConstantInt::get(Ty, SignMask));
+  // If A and B are both constants, the builder may fold the whole chain to a
+  // Constant. There is then no Instruction to hand back, so bail out instead
+  // of asserting in cast<>.
+  return dyn_cast<Instruction>(And);
+}
+
/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
@@ -2009,6 +2138,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *NewSel =
tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
return NewSel;
+ if (Instruction *Folded =
+ foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
+ return replaceInstUsesWith(SI, Folded);
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
bool Changed = false;
@@ -4200,10 +4332,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
bool IsCastNeeded = LHS->getType() != SelType;
Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
- if (IsCastNeeded ||
- (LHS->getType()->isFPOrFPVectorTy() &&
- ((CmpLHS != LHS && CmpLHS != RHS) ||
- (CmpRHS != LHS && CmpRHS != RHS)))) {
+ if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
+ ((CmpLHS != LHS && CmpLHS != RHS) ||
+ (CmpRHS != LHS && CmpRHS != RHS)))) {
CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
Value *Cmp;
diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
new file mode 100644
index 0000000000000..ffa77f4b42138
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -0,0 +1,126 @@
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; Tests folding select(icmp eq (or (usub.sat a, C1), (usub.sat b, C2)), 0),
+; 0, SignMask) into a sign-mask 'and' of adjusted usub.sat calls.
+
+declare i8 @llvm.usub.sat.i8(i8, i8)
+declare i16 @llvm.usub.sat.i16(i16, i16)
+declare i32 @llvm.usub.sat.i32(i32, i32)
+declare i64 @llvm.usub.sat.i64(i64, i64)
+declare <4 x i32> @llvm.usub.sat.v4i32(<4 x i32>, <4 x i32>)
+
+define i8 @test_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_i8(
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96)
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112)
+; CHECK-NEXT: or i8
+; CHECK-NEXT: and i8
+; CHECK-NEXT: ret i8
+
+; Positive i8 case: 223 and 239 are both uge the sign mask 128, so the fold
+; applies with adjusted constants 223 - 128 + 1 = 96 and 239 - 128 + 1 = 112.
+ %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
+ %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
+ %or = or i8 %a_sub, %b_sub
+ %cmp = icmp eq i8 %or, 0
+ %res = select i1 %cmp, i8 0, i8 128
+ ret i8 %res
+}
+
+define i16 @test_i16(i16 %a, i16 %b) {
+; CHECK-LABEL: @test_i16(
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642)
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656)
+; CHECK-NEXT: or i16
+; CHECK-NEXT: and i16
+; CHECK-NEXT: ret i16
+
+; Positive i16 case: 65409 and 65423 are uge the sign mask 32768; adjusted
+; constants are 65409 - 32768 + 1 = 32642 and 65423 - 32768 + 1 = 32656.
+ %a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409)
+ %b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423)
+ %or = or i16 %a_sub, %b_sub
+ %cmp = icmp eq i16 %or, 0
+ %res = select i1 %cmp, i16 0, i16 32768
+ ret i16 %res
+}
+
+define i32 @test_i32(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_i32(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: and i32
+; CHECK-NEXT: ret i32
+
+; Positive i32 case: the constants are 2^31 + 223 and 2^31 + 239, so the
+; adjusted constants (C - 2^31 + 1) are 224 and 240.
+ %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871)
+ %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887)
+ %or = or i32 %a_sub, %b_sub
+ %cmp = icmp eq i32 %or, 0
+ %res = select i1 %cmp, i32 0, i32 2147483648
+ ret i32 %res
+}
+
+define i64 @test_i64(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_i64(
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224)
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240)
+; CHECK-NEXT: or i64
+; CHECK-NEXT: and i64
+; CHECK-NEXT: ret i64
+
+; Positive i64 case: the constants are 2^63 + 223 and 2^63 + 239, so the
+; adjusted constants (C - 2^63 + 1) are 224 and 240.
+ %a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031)
+ %b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047)
+ %or = or i64 %a_sub, %b_sub
+ %cmp = icmp eq i64 %or, 0
+ %res = select i1 %cmp, i64 0, i64 9223372036854775808
+ ret i64 %res
+}
+
+define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
+; CHECK-LABEL: @no_fold_due_to_small_K(
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK: or i32
+; CHECK: icmp eq i32
+; CHECK: select
+; CHECK: ret i32
+
+; Negative case: 100 (and also 239) is ult the i32 sign mask 2^31. A single
+; constant below the sign mask is enough for the fold to be rejected.
+ %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+ %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+ %or = or i32 %a_sub, %b_sub
+ %cmp = icmp eq i32 %or, 0
+ %res = select i1 %cmp, i32 0, i32 2147483648
+ ret i32 %res
+}
+
+define i32 @commuted_test_neg(i32 %a, i32 %b) {
+; CHECK-LABEL: @commuted_test_neg(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: icmp eq i32
+; CHECK-NEXT: select
+; CHECK-NEXT: ret i32
+
+; NOTE(review): despite the name, this stays unfolded because 239 and 223 are
+; ult the i32 sign mask 2^31, not because the or operands are commuted. Both
+; or operands are matched by the same pattern, so operand order is irrelevant.
+; A commuted positive case (constants uge 2^31) would exercise commutation.
+ %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+ %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+ %or = or i32 %b_sub, %a_sub
+ %cmp = icmp eq i32 %or, 0
+ %res = select i1 %cmp, i32 0, i32 2147483648
+ ret i32 %res
+}
+define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @vector_test(
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224))
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240))
+; CHECK-NEXT: or <4 x i32>
+; CHECK-NEXT: and <4 x i32>
+; CHECK-NEXT: ret <4 x i32>
+
+; Positive vector case: every lane is 2^31 + 223 (resp. 2^31 + 239), giving
+; adjusted splats 224 and 240. The false arm -2147483648 has the same bit
+; pattern as the i32 sign mask 0x80000000.
+
+ %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %a,
+ <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+ %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %b,
+ <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+ %or = or <4 x i32> %a_sub, %b_sub
+ %cmp = icmp eq <4 x i32> %or, zeroinitializer
+ %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+ <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+ ret <4 x i32> %res
+}
>From dbfd5989029598038bb9b3bef4cbba31619aa764 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:56:10 -0400
Subject: [PATCH 2/2] Remove unused includes and revert unrelated formatting changes
---
.../Transforms/InstCombine/InstCombineSelect.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 74544009e6872..7e4eaa9745917 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,8 +42,6 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
-#include <cstddef>
-#include <iostream>
#include <utility>
#define DEBUG_TYPE "instcombine"
@@ -52,6 +50,7 @@
using namespace llvm;
using namespace PatternMatch;
+
/// Replace a select operand based on an equality comparison with the identity
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1714,6 +1713,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
return nullptr;
+
Value *SelVal0, *SelVal1; // We do not care which one is from where.
match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
// At least one of these values we are selecting between must be a constant
@@ -2004,8 +2004,7 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
/// %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit)
/// %or = or %a_sub', %b_sub'
/// %and = and %or, MostSignificantBit
-/// If the args are vectors
-///
+/// Likewise, for vector arguments as well.
static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
@@ -4332,9 +4331,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
bool IsCastNeeded = LHS->getType() != SelType;
Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
- if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
- ((CmpLHS != LHS && CmpLHS != RHS) ||
- (CmpRHS != LHS && CmpRHS != RHS)))) {
+ if (IsCastNeeded ||
+ (LHS->getType()->isFPOrFPVectorTy() &&
+ ((CmpLHS != LHS && CmpLHS != RHS) ||
+ (CmpRHS != LHS && CmpRHS != RHS)))) {
CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
Value *Cmp;
More information about the llvm-commits
mailing list