[llvm] [llvm] Optimize usub.sat fix for #79690 (PR #151044)
Nimit Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Sun Aug 24 16:19:15 PDT 2025
https://github.com/nimit25 updated https://github.com/llvm/llvm-project/pull/151044
>From 82336b681fa6e917d4fe217a227614f481aa30ac Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:23:11 -0400
Subject: [PATCH 1/5] Optimize usub.sat fix for #79690
---
.../InstCombine/InstCombineSelect.cpp | 143 +++++++++++++++++-
.../InstCombine/usub_sat_to_msb_mask.ll | 126 +++++++++++++++
2 files changed, 263 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index eb4332fbc0959..74544009e6872 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,6 +42,8 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
+#include <cstddef>
+#include <iostream>
#include <utility>
#define DEBUG_TYPE "instcombine"
@@ -50,7 +52,6 @@
using namespace llvm;
using namespace PatternMatch;
-
/// Replace a select operand based on an equality comparison with the identity
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1713,7 +1714,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
return nullptr;
-
Value *SelVal0, *SelVal1; // We do not care which one is from where.
match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
// At least one of these values we are selecting between must be a constant
@@ -1993,6 +1993,135 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
return BinOp;
}
+/// Folds:
+///   %a_sub = call @llvm.usub.sat(x, IntConst1)
+///   %b_sub = call @llvm.usub.sat(y, IntConst2)
+///   %or = or %a_sub, %b_sub
+///   %cmp = icmp eq %or, 0
+///   %sel = select %cmp, 0, MostSignificantBit
+/// into:
+///   %a_sub' = usub.sat(x, IntConst1 - MostSignificantBit + 1)
+///   %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit + 1)
+///   %or = or %a_sub', %b_sub'
+///   %and = and %or, MostSignificantBit
+/// Requires IntConst1, IntConst2 >= MostSignificantBit. Under that condition
+/// usub.sat(x, IntConst1 - MostSignificantBit + 1) has its top bit set exactly
+/// when x > IntConst1, i.e. exactly when usub.sat(x, IntConst1) is non-zero
+/// (likewise for y). Fixed-length vectors are handled element-wise; the
+/// select's non-zero arm must then be a MostSignificantBit splat.
+static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
+    SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
+  auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
+  if (!CI) {
+    return nullptr;
+  }
+
+  Value *CmpLHS = CI->getOperand(0);
+  Value *CmpRHS = CI->getOperand(1);
+  if (!match(CmpRHS, m_Zero())) {
+    return nullptr;
+  }
+  auto Pred = CI->getPredicate();
+  auto *TrueVal = SI.getTrueValue();
+  auto *FalseVal = SI.getFalseValue();
+
+  if (Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // The `and` built below replaces the select, so their types must agree.
+  // This also keeps every APInt compared below at a single bit width
+  // (APInt comparisons assert on mismatched widths).
+  if (SI.getType() != CmpLHS->getType())
+    return nullptr;
+
+  // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
+  Value *A, *B;
+  ConstantInt *IntConst1, *IntConst2, *PossibleMSBInt;
+
+  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(A), m_ConstantInt(IntConst1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(B), m_ConstantInt(IntConst2)))) &&
+      match(TrueVal, m_Zero()) &&
+      match(FalseVal, m_ConstantInt(PossibleMSBInt))) {
+    auto *Ty = A->getType();
+    // Scalar size also works if a splat vector is represented as ConstantInt.
+    unsigned BW = Ty->getScalarSizeInBits();
+    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+
+    if (PossibleMSBInt->getValue() != MostSignificantBit)
+      return nullptr;
+    // Ensure IntConst1 and IntConst2 are >= MostSignificantBit
+    if (IntConst1->getValue().ult(MostSignificantBit) ||
+        IntConst2->getValue().ult(MostSignificantBit))
+      return nullptr;
+
+    // Rewrite:
+    Value *NewA = Builder.CreateBinaryIntrinsic(
+        Intrinsic::usub_sat, A,
+        ConstantInt::get(Ty, IntConst1->getValue() - MostSignificantBit + 1));
+    Value *NewB = Builder.CreateBinaryIntrinsic(
+        Intrinsic::usub_sat, B,
+        ConstantInt::get(Ty, IntConst2->getValue() - MostSignificantBit + 1));
+    Value *Or = Builder.CreateOr(NewA, NewB);
+    Value *And =
+        Builder.CreateAnd(Or, ConstantInt::get(Ty, MostSignificantBit));
+    return cast<Instruction>(And);
+  }
+  Constant *Const1, *Const2, *PossibleMSB;
+  if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
+                                                          m_Constant(Const1)),
+                         m_Intrinsic<Intrinsic::usub_sat>(
+                             m_Value(B), m_Constant(Const2)))) &&
+      match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))) {
+    auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
+    auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
+    auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
+    if (!VecTy1 || !VecTy2 || !VecTyMSB) {
+      return nullptr;
+    }
+
+    unsigned NumElements = VecTy1->getNumElements();
+
+    if (NumElements != VecTy2->getNumElements() ||
+        NumElements != VecTyMSB->getNumElements() || NumElements == 0) {
+      return nullptr;
+    }
+    // getAggregateElement may return null, and element 0 may be undef/poison
+    // or another non-ConstantInt: check before dereferencing.
+    auto *SplatMSB =
+        dyn_cast_or_null<ConstantInt>(PossibleMSB->getAggregateElement(0u));
+    if (!SplatMSB)
+      return nullptr;
+    unsigned BW = SplatMSB->getValue().getBitWidth();
+    APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+    if (SplatMSB->getValue() != MostSignificantBit) {
+      return nullptr;
+    }
+    for (unsigned int i = 1; i < NumElements; ++i) {
+      auto *Element =
+          dyn_cast_or_null<ConstantInt>(PossibleMSB->getAggregateElement(i));
+      if (!Element || Element->getValue() != SplatMSB->getValue()) {
+        return nullptr;
+      }
+    }
+    SmallVector<Constant *, 16> Arg1, Arg2;
+    for (unsigned int i = 0; i < NumElements; ++i) {
+      auto *E1 = dyn_cast_or_null<ConstantInt>(Const1->getAggregateElement(i));
+      auto *E2 = dyn_cast_or_null<ConstantInt>(Const2->getAggregateElement(i));
+      if (!E1 || !E2) {
+        return nullptr;
+      }
+      if (E1->getValue().ult(SplatMSB->getValue()) ||
+          E2->getValue().ult(SplatMSB->getValue())) {
+        return nullptr;
+      }
+      Arg1.emplace_back(
+          ConstantInt::get(A->getType()->getScalarType(),
+                           E1->getValue() - MostSignificantBit + 1));
+      Arg2.emplace_back(
+          ConstantInt::get(B->getType()->getScalarType(),
+                           E2->getValue() - MostSignificantBit + 1));
+    }
+    Constant *ConstVec1 = ConstantVector::get(Arg1);
+    Constant *ConstVec2 = ConstantVector::get(Arg2);
+    Value *NewA =
+        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, ConstVec1);
+    Value *NewB =
+        Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, ConstVec2);
+    Value *Or = Builder.CreateOr(NewA, NewB);
+    Value *And = Builder.CreateAnd(Or, PossibleMSB);
+    return cast<Instruction>(And);
+  }
+  return nullptr;
+}
+
/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
@@ -2009,6 +2138,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *NewSel =
tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
return NewSel;
+ if (Instruction *Folded =
+ foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
+ return replaceInstUsesWith(SI, Folded);
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
bool Changed = false;
@@ -4200,10 +4332,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
bool IsCastNeeded = LHS->getType() != SelType;
Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
- if (IsCastNeeded ||
- (LHS->getType()->isFPOrFPVectorTy() &&
- ((CmpLHS != LHS && CmpLHS != RHS) ||
- (CmpRHS != LHS && CmpRHS != RHS)))) {
+ if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
+ ((CmpLHS != LHS && CmpLHS != RHS) ||
+ (CmpRHS != LHS && CmpRHS != RHS)))) {
CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
Value *Cmp;
diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
new file mode 100644
index 0000000000000..ffa77f4b42138
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -0,0 +1,126 @@
+
+; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s
+
+declare i8 @llvm.usub.sat.i8(i8, i8)
+declare i16 @llvm.usub.sat.i16(i16, i16)
+declare i32 @llvm.usub.sat.i32(i32, i32)
+declare i64 @llvm.usub.sat.i64(i64, i64)
+
+define i8 @test_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_i8(
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96)
+; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112)
+; CHECK-NEXT: or i8
+; CHECK-NEXT: and i8
+; CHECK-NEXT: ret i8
+
+ %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
+ %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
+ %or = or i8 %a_sub, %b_sub
+ %cmp = icmp eq i8 %or, 0
+ %res = select i1 %cmp, i8 0, i8 128
+ ret i8 %res
+}
+
+define i16 @test_i16(i16 %a, i16 %b) {
+; CHECK-LABEL: @test_i16(
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642)
+; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656)
+; CHECK-NEXT: or i16
+; CHECK-NEXT: and i16
+; CHECK-NEXT: ret i16
+
+ %a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409)
+ %b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423)
+ %or = or i16 %a_sub, %b_sub
+ %cmp = icmp eq i16 %or, 0
+ %res = select i1 %cmp, i16 0, i16 32768
+ ret i16 %res
+}
+
+define i32 @test_i32(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_i32(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: and i32
+; CHECK-NEXT: ret i32
+
+ %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871)
+ %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887)
+ %or = or i32 %a_sub, %b_sub
+ %cmp = icmp eq i32 %or, 0
+ %res = select i1 %cmp, i32 0, i32 2147483648
+ ret i32 %res
+}
+
+define i64 @test_i64(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_i64(
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224)
+; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240)
+; CHECK-NEXT: or i64
+; CHECK-NEXT: and i64
+; CHECK-NEXT: ret i64
+
+ %a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031)
+ %b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047)
+ %or = or i64 %a_sub, %b_sub
+ %cmp = icmp eq i64 %or, 0
+ %res = select i1 %cmp, i64 0, i64 9223372036854775808
+ ret i64 %res
+}
+
+define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
+; CHECK-LABEL: @no_fold_due_to_small_K(
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK: or i32
+; CHECK: icmp eq i32
+; CHECK: select
+; CHECK: ret i32
+
+ %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
+ %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+ %or = or i32 %a_sub, %b_sub
+ %cmp = icmp eq i32 %or, 0
+ %res = select i1 %cmp, i32 0, i32 2147483648
+ ret i32 %res
+}
+
+define i32 @commuted_test_neg(i32 %a, i32 %b) {
+; CHECK-LABEL: @commuted_test_neg(
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+; CHECK-NEXT: or i32
+; CHECK-NEXT: icmp eq i32
+; CHECK-NEXT: select
+; CHECK-NEXT: ret i32
+
+ %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
+ %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
+ %or = or i32 %b_sub, %a_sub
+ %cmp = icmp eq i32 %or, 0
+ %res = select i1 %cmp, i32 0, i32 2147483648
+ ret i32 %res
+}
+define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @vector_test(
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224))
+; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240))
+; CHECK-NEXT: or <4 x i32>
+; CHECK-NEXT: and <4 x i32>
+; CHECK-NEXT: ret <4 x i32>
+
+
+ %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %a,
+ <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+ %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %b,
+ <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+ %or = or <4 x i32> %a_sub, %b_sub
+ %cmp = icmp eq <4 x i32> %or, zeroinitializer
+ %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+ <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+ ret <4 x i32> %res
+}
>From dbfd5989029598038bb9b3bef4cbba31619aa764 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Mon, 28 Jul 2025 17:56:10 -0400
Subject: [PATCH 2/5] Drop unused includes and revert unrelated formatting changes
---
.../Transforms/InstCombine/InstCombineSelect.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 74544009e6872..7e4eaa9745917 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -42,8 +42,6 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
-#include <cstddef>
-#include <iostream>
#include <utility>
#define DEBUG_TYPE "instcombine"
@@ -52,6 +50,7 @@
using namespace llvm;
using namespace PatternMatch;
+
/// Replace a select operand based on an equality comparison with the identity
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1714,6 +1713,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
return nullptr;
+
Value *SelVal0, *SelVal1; // We do not care which one is from where.
match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
// At least one of these values we are selecting between must be a constant
@@ -2004,8 +2004,7 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
/// %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit)
/// %or = or %a_sub', %b_sub'
/// %and = and %or, MostSignificantBit
-/// If the args are vectors
-///
+/// Likewise, for vector arguments as well.
static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
@@ -4332,9 +4331,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
bool IsCastNeeded = LHS->getType() != SelType;
Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
- if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() &&
- ((CmpLHS != LHS && CmpLHS != RHS) ||
- (CmpRHS != LHS && CmpRHS != RHS)))) {
+ if (IsCastNeeded ||
+ (LHS->getType()->isFPOrFPVectorTy() &&
+ ((CmpLHS != LHS && CmpLHS != RHS) ||
+ (CmpRHS != LHS && CmpRHS != RHS)))) {
CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
Value *Cmp;
>From c572adbf917ce9a07e060eb78b50f7f00e79bd77 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 17 Aug 2025 22:37:54 -0400
Subject: [PATCH 3/5] Also handle icmp ne in the usub.sat MSB fold; add tests
---
.../InstCombine/InstCombineSelect.cpp | 13 +-
.../InstCombine/usub_sat_to_msb_mask.ll | 178 +++++++++++++-----
2 files changed, 136 insertions(+), 55 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7e4eaa9745917..cd0eb59007b3b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -50,7 +50,6 @@
using namespace llvm;
using namespace PatternMatch;
-
/// Replace a select operand based on an equality comparison with the identity
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
@@ -1713,7 +1712,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
return nullptr;
-
Value *SelVal0, *SelVal1; // We do not care which one is from where.
match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
// At least one of these values we are selecting between must be a constant
@@ -2021,7 +2019,9 @@ static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
auto *TrueVal = SI.getTrueValue();
auto *FalseVal = SI.getFalseValue();
- if (Pred != ICmpInst::ICMP_EQ)
+  // For `ne` the arms are swapped relative to `eq`: MSB must be selected
+  // exactly when the `or` is non-zero, whichever predicate expresses that.
+  if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
return nullptr;
// Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
@@ -2032,8 +2032,10 @@ static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
m_Value(A), m_ConstantInt(IntConst1)),
m_Intrinsic<Intrinsic::usub_sat>(
m_Value(B), m_ConstantInt(IntConst2)))) &&
- match(TrueVal, m_Zero()) &&
- match(FalseVal, m_ConstantInt(PossibleMSBInt))) {
+      ((Pred == ICmpInst::ICMP_EQ && match(TrueVal, m_Zero()) &&
+        match(FalseVal, m_ConstantInt(PossibleMSBInt))) ||
+       (Pred == ICmpInst::ICMP_NE && match(FalseVal, m_Zero()) &&
+        match(TrueVal, m_ConstantInt(PossibleMSBInt))))) {
auto *Ty = A->getType();
unsigned BW = Ty->getIntegerBitWidth();
APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
@@ -2062,7 +2064,10 @@ static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
m_Constant(Const1)),
m_Intrinsic<Intrinsic::usub_sat>(
m_Value(B), m_Constant(Const2)))) &&
- match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))) {
+      ((Pred == ICmpInst::ICMP_EQ && match(TrueVal, m_Zero()) &&
+        match(FalseVal, m_Constant(PossibleMSB))) ||
+       (Pred == ICmpInst::ICMP_NE && match(FalseVal, m_Zero()) &&
+        match(TrueVal, m_Constant(PossibleMSB))))) {
auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
index ffa77f4b42138..f6402c24315c5 100644
--- a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
+++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s
@@ -7,12 +8,14 @@ declare i32 @llvm.usub.sat.i32(i32, i32)
declare i64 @llvm.usub.sat.i64(i64, i64)
define i8 @test_i8(i8 %a, i8 %b) {
-; CHECK-LABEL: @test_i8(
-; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96)
-; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112)
-; CHECK-NEXT: or i8
-; CHECK-NEXT: and i8
-; CHECK-NEXT: ret i8
+; CHECK-LABEL: define i8 @test_i8(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A]], i8 96)
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[B]], i8 112)
+; CHECK-NEXT: [[TMP3:%.*]] = or i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and i8 [[TMP3]], -128
+; CHECK-NEXT: ret i8 [[RES]]
+;
%a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
%b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
@@ -22,13 +25,33 @@ define i8 @test_i8(i8 %a, i8 %b) {
ret i8 %res
}
+define i8 @test_i8_ne(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @test_i8_ne(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A]], i8 96)
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[B]], i8 112)
+; CHECK-NEXT: [[TMP3:%.*]] = or i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and i8 [[TMP3]], -128
+; CHECK-NEXT: ret i8 [[RES]]
+;
+
+ %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223)
+ %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239)
+ %or = or i8 %a_sub, %b_sub
+ %cmp = icmp ne i8 %or, 0
+ %res = select i1 %cmp, i8 128, i8 0
+ ret i8 %res
+}
+
define i16 @test_i16(i16 %a, i16 %b) {
-; CHECK-LABEL: @test_i16(
-; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642)
-; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656)
-; CHECK-NEXT: or i16
-; CHECK-NEXT: and i16
-; CHECK-NEXT: ret i16
+; CHECK-LABEL: define i16 @test_i16(
+; CHECK-SAME: i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[A]], i16 32642)
+; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[B]], i16 32656)
+; CHECK-NEXT: [[TMP3:%.*]] = or i16 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and i16 [[TMP3]], -32768
+; CHECK-NEXT: ret i16 [[RES]]
+;
%a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409)
%b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423)
@@ -39,12 +62,14 @@ define i16 @test_i16(i16 %a, i16 %b) {
}
define i32 @test_i32(i32 %a, i32 %b) {
-; CHECK-LABEL: @test_i32(
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224)
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240)
-; CHECK-NEXT: or i32
-; CHECK-NEXT: and i32
-; CHECK-NEXT: ret i32
+; CHECK-LABEL: define i32 @test_i32(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 224)
+; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[B]], i32 240)
+; CHECK-NEXT: [[TMP3:%.*]] = or i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and i32 [[TMP3]], -2147483648
+; CHECK-NEXT: ret i32 [[RES]]
+;
%a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871)
%b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887)
@@ -55,12 +80,14 @@ define i32 @test_i32(i32 %a, i32 %b) {
}
define i64 @test_i64(i64 %a, i64 %b) {
-; CHECK-LABEL: @test_i64(
-; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224)
-; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240)
-; CHECK-NEXT: or i64
-; CHECK-NEXT: and i64
-; CHECK-NEXT: ret i64
+; CHECK-LABEL: define i64 @test_i64(
+; CHECK-SAME: i64 [[A:%.*]], i64 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[A]], i64 224)
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[B]], i64 240)
+; CHECK-NEXT: [[TMP3:%.*]] = or i64 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and i64 [[TMP3]], -9223372036854775808
+; CHECK-NEXT: ret i64 [[RES]]
+;
%a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031)
%b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047)
@@ -71,13 +98,15 @@ define i64 @test_i64(i64 %a, i64 %b) {
}
define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
-; CHECK-LABEL: @no_fold_due_to_small_K(
-; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
-; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
-; CHECK: or i32
-; CHECK: icmp eq i32
-; CHECK: select
-; CHECK: ret i32
+; CHECK-LABEL: define i32 @no_fold_due_to_small_K(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[A_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 100)
+; CHECK-NEXT: [[B_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[B]], i32 239)
+; CHECK-NEXT: [[OR:%.*]] = or i32 [[A_SUB]], [[B_SUB]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
+; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i32 0, i32 -2147483648
+; CHECK-NEXT: ret i32 [[RES]]
+;
%a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100)
%b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
@@ -88,13 +117,15 @@ define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) {
}
define i32 @commuted_test_neg(i32 %a, i32 %b) {
-; CHECK-LABEL: @commuted_test_neg(
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
-; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
-; CHECK-NEXT: or i32
-; CHECK-NEXT: icmp eq i32
-; CHECK-NEXT: select
-; CHECK-NEXT: ret i32
+; CHECK-LABEL: define i32 @commuted_test_neg(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[B_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[B]], i32 239)
+; CHECK-NEXT: [[A_SUB:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 223)
+; CHECK-NEXT: [[OR:%.*]] = or i32 [[B_SUB]], [[A_SUB]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
+; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i32 0, i32 -2147483648
+; CHECK-NEXT: ret i32 [[RES]]
+;
%b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239)
%a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223)
@@ -104,23 +135,72 @@ define i32 @commuted_test_neg(i32 %a, i32 %b) {
ret i32 %res
}
define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) {
-; CHECK-LABEL: @vector_test(
-; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224))
-; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240))
-; CHECK-NEXT: or <4 x i32>
-; CHECK-NEXT: and <4 x i32>
-; CHECK-NEXT: ret <4 x i32>
+; CHECK-LABEL: define <4 x i32> @vector_test(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[A]], <4 x i32> splat (i32 224))
+; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[B]], <4 x i32> splat (i32 240))
+; CHECK-NEXT: [[TMP3:%.*]] = or <4 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and <4 x i32> [[TMP3]], splat (i32 -2147483648)
+; CHECK-NEXT: ret <4 x i32> [[RES]]
+;
+
+
+ %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %a,
+ <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+ %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %b,
+ <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+ %or = or <4 x i32> %a_sub, %b_sub
+ %cmp = icmp eq <4 x i32> %or, zeroinitializer
+ %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+ <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @vector_negative_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define <4 x i32> @vector_negative_test(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT: [[A_SUB:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[A]], <4 x i32> <i32 -2147483425, i32 0, i32 -2147483425, i32 -2147483425>)
+; CHECK-NEXT: [[B_SUB:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[B]], <4 x i32> splat (i32 -2147483409))
+; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[A_SUB]], [[B_SUB]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <4 x i32> [[OR]], zeroinitializer
+; CHECK-NEXT: [[RES:%.*]] = select <4 x i1> [[CMP]], <4 x i32> zeroinitializer, <4 x i32> splat (i32 -2147483648)
+; CHECK-NEXT: ret <4 x i32> [[RES]]
+;
+ %a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %a,
+ <4 x i32> <i32 2147483871, i32 0, i32 2147483871, i32 2147483871>)
+ %b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
+ <4 x i32> %b,
+ <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+ %or = or <4 x i32> %a_sub, %b_sub
+ %cmp = icmp eq <4 x i32> %or, zeroinitializer
+ %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
+ <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @vector_ne_test(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define <4 x i32> @vector_ne_test(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[A]], <4 x i32> splat (i32 224))
+; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> [[B]], <4 x i32> splat (i32 240))
+; CHECK-NEXT: [[TMP3:%.*]] = or <4 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[RES:%.*]] = and <4 x i32> [[TMP3]], splat (i32 -2147483648)
+; CHECK-NEXT: ret <4 x i32> [[RES]]
+;
%a_sub = call <4 x i32> @llvm.usub.sat.v4i32(
- <4 x i32> %a,
- <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
+ <4 x i32> %a,
+ <4 x i32> <i32 2147483871, i32 2147483871, i32 2147483871, i32 2147483871>)
%b_sub = call <4 x i32> @llvm.usub.sat.v4i32(
- <4 x i32> %b,
- <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
+ <4 x i32> %b,
+ <4 x i32> <i32 2147483887, i32 2147483887, i32 2147483887, i32 2147483887>)
%or = or <4 x i32> %a_sub, %b_sub
%cmp = icmp eq <4 x i32> %or, zeroinitializer
%res = select <4 x i1> %cmp, <4 x i32> zeroinitializer,
- <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
+ <4 x i32> <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
ret <4 x i32> %res
}
>From 14f028966b280ad8f62a0c0009e71c8f8fee9d8f Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 24 Aug 2025 03:36:04 -0400
Subject: [PATCH 4/5] Use m_APInt to handle scalars and splat vectors uniformly
---
.../InstCombine/InstCombineSelect.cpp | 150 ++++++------------
1 file changed, 49 insertions(+), 101 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index cd0eb59007b3b..d2525e2facef3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2006,120 +2006,81 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
- if (!CI) {
+  if (!CI)
return nullptr;
- }
Value *CmpLHS = CI->getOperand(0);
Value *CmpRHS = CI->getOperand(1);
- if (!match(CmpRHS, m_Zero())) {
+  if (!match(CmpRHS, m_Zero()))
return nullptr;
- }
- auto Pred = CI->getPredicate();
- auto *TrueVal = SI.getTrueValue();
- auto *FalseVal = SI.getFalseValue();
- if (Pred != ICmpInst::ICMP_EQ && Pred != llvm::ICmpInst::ICMP_NE)
+  auto Pred = CI->getPredicate();
+  if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
return nullptr;
- // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0
Value *A, *B;
- ConstantInt *IntConst1, *IntConst2, *PossibleMSBInt;
-
- if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(
- m_Value(A), m_ConstantInt(IntConst1)),
- m_Intrinsic<Intrinsic::usub_sat>(
- m_Value(B), m_ConstantInt(IntConst2)))) &&
- (match(TrueVal, m_Zero()) &&
- match(FalseVal, m_ConstantInt(PossibleMSBInt)) ||
- match(TrueVal, m_ConstantInt(PossibleMSBInt)) &&
- match(FalseVal, m_Zero()))) {
- auto *Ty = A->getType();
- unsigned BW = Ty->getIntegerBitWidth();
- APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
-
- if (PossibleMSBInt->getValue() != MostSignificantBit)
- return nullptr;
- // Ensure IntConst1 and IntConst2 are >= MostSignificantBit
- if (IntConst1->getValue().ult(MostSignificantBit) ||
- IntConst2->getValue().ult(MostSignificantBit))
- return nullptr;
+  const APInt *Constant1, *Constant2, *PossibleMSB;
+  // The rewrite re-creates both usub.sat calls and the `or`; require each to
+  // have a single use so the originals die and the instruction count drops.
+  if (!match(CmpLHS,
+             m_OneUse(m_Or(m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(
+                               m_Value(A), m_APInt(Constant1))),
+                           m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(
+                               m_Value(B), m_APInt(Constant2)))))))
+    return nullptr;
- // Rewrite:
- Value *NewA = Builder.CreateBinaryIntrinsic(
- Intrinsic::usub_sat, A,
- ConstantInt::get(Ty, IntConst1->getValue() - MostSignificantBit + 1));
- Value *NewB = Builder.CreateBinaryIntrinsic(
- Intrinsic::usub_sat, B,
- ConstantInt::get(Ty, IntConst2->getValue() - MostSignificantBit + 1));
- Value *Or = Builder.CreateOr(NewA, NewB);
- Value *And =
- Builder.CreateAnd(Or, ConstantInt::get(Ty, MostSignificantBit));
- return cast<Instruction>(And);
- }
- Constant *Const1, *Const2, *PossibleMSB;
- if (match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
- m_Constant(Const1)),
- m_Intrinsic<Intrinsic::usub_sat>(
- m_Value(B), m_Constant(Const2)))) &&
- (match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))
- || match(TrueVal, m_Constant(PossibleMSB) ) && match(FalseVal, m_Zero()))) {
- auto *VecTy1 = dyn_cast<FixedVectorType>(Const1->getType());
- auto *VecTy2 = dyn_cast<FixedVectorType>(Const2->getType());
- auto *VecTyMSB = dyn_cast<FixedVectorType>(PossibleMSB->getType());
- if (!VecTy1 || !VecTy2 || !VecTyMSB) {
- return nullptr;
- }
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  // The rewrite yields MSB exactly when the `or` is non-zero, so MSB must be
+  // the arm selected in that case: the false arm for `eq`, the true arm for
+  // `ne`. Accepting the opposite orientation would invert the result.
+  if (Pred == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+  if (!match(TrueVal, m_Zero()) || !match(FalseVal, m_APInt(PossibleMSB)))
+    return nullptr;
- unsigned NumElements = VecTy1->getNumElements();
+  auto *Ty = A->getType();
+  // The `and` replaces the select, so their types must agree; this also keeps
+  // every APInt compared below at one bit width.
+  if (SI.getType() != Ty)
+    return nullptr;
+  auto *VecTy = dyn_cast<VectorType>(Ty);
+  unsigned BW = PossibleMSB->getBitWidth();
+  APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
- if (NumElements != VecTy2->getNumElements() ||
- NumElements != VecTyMSB->getNumElements() || NumElements == 0) {
- return nullptr;
- }
- auto *SplatMSB =
- dyn_cast<ConstantInt>(PossibleMSB->getAggregateElement(0u));
- unsigned BW = SplatMSB->getValue().getBitWidth();
- APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
- if (!SplatMSB || SplatMSB->getValue() != MostSignificantBit) {
- return nullptr;
- }
- for (unsigned int i = 1; i < NumElements; ++i) {
- auto *Element =
- dyn_cast<ConstantInt>(PossibleMSB->getAggregateElement(i));
- if (!Element || Element->getValue() != SplatMSB->getValue()) {
- return nullptr;
- }
- }
- SmallVector<Constant *, 16> Arg1, Arg2;
- for (unsigned int i = 0; i < NumElements; ++i) {
- auto *E1 = dyn_cast<ConstantInt>(Const1->getAggregateElement(i));
- auto *E2 = dyn_cast<ConstantInt>(Const2->getAggregateElement(i));
- if (!E1 || !E2) {
- return nullptr;
- }
- if (E1->getValue().ult(SplatMSB->getValue()) ||
- E2->getValue().ult(SplatMSB->getValue())) {
- return nullptr;
- }
- Arg1.emplace_back(
- ConstantInt::get(A->getType()->getScalarType(),
- E1->getValue() - MostSignificantBit + 1));
- Arg2.emplace_back(
- ConstantInt::get(B->getType()->getScalarType(),
- E2->getValue() - MostSignificantBit + 1));
- }
- Constant *ConstVec1 = ConstantVector::get(Arg1);
- Constant *ConstVec2 = ConstantVector::get(Arg2);
- Value *NewA =
- Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, ConstVec1);
- Value *NewB =
- Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, ConstVec2);
- Value *Or = Builder.CreateOr(NewA, NewB);
- Value *And = Builder.CreateAnd(Or, PossibleMSB);
- return cast<Instruction>(And);
+  if (*PossibleMSB != MostSignificantBit ||
+      Constant1->ult(MostSignificantBit) || Constant2->ult(MostSignificantBit))
+    return nullptr;
+
+  // For C >= MSB, usub.sat(X, C - MSB + 1) has its top bit set exactly when
+  // X > C, i.e. exactly when usub.sat(X, C) is non-zero.
+  APInt AdjAP1 = *Constant1 - MostSignificantBit + 1;
+  APInt AdjAP2 = *Constant2 - MostSignificantBit + 1;
+
+  Constant *Adj1, *Adj2;
+  if (VecTy) {
+    Constant *Elt1 = ConstantInt::get(VecTy->getElementType(), AdjAP1);
+    Constant *Elt2 = ConstantInt::get(VecTy->getElementType(), AdjAP2);
+    Adj1 = ConstantVector::getSplat(VecTy->getElementCount(), Elt1);
+    Adj2 = ConstantVector::getSplat(VecTy->getElementCount(), Elt2);
+  } else {
+    Adj1 = ConstantInt::get(Ty, AdjAP1);
+    Adj2 = ConstantInt::get(Ty, AdjAP2);
}
- return nullptr;
+
+  Value *NewA = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, Adj1);
+  Value *NewB = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, Adj2);
+  Value *Or = Builder.CreateOr(NewA, NewB);
+  Constant *MSBConst;
+  if (VecTy) {
+    MSBConst = ConstantVector::getSplat(
+        VecTy->getElementCount(),
+        ConstantInt::get(VecTy->getScalarType(), *PossibleMSB));
+  } else {
+    MSBConst = ConstantInt::get(Ty->getScalarType(), *PossibleMSB);
+  }
+  Value *And = Builder.CreateAnd(Or, MSBConst);
+  return cast<Instruction>(And);
}
/// Visit a SelectInst that has an ICmpInst as its first operand.
>From fb6c7369c24e97ae0d2ac63bf0c74bf04b5590c7 Mon Sep 17 00:00:00 2001
From: Nimit Sachdeva <nimsach at amazon.com>
Date: Sun, 24 Aug 2025 18:39:08 -0400
Subject: [PATCH 5/5] apply suggestions from code review
---
.../InstCombine/InstCombineSelect.cpp | 69 +++++++------------
1 file changed, 23 insertions(+), 46 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index d2525e2facef3..b2720fc113cad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2005,69 +2005,45 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
/// Likewise, for vector arguments as well.
static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp(
SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) {
- auto *CI = dyn_cast<ICmpInst>(SI.getCondition());
- if (!CI)
- return nullptr;
-
- Value *CmpLHS = CI->getOperand(0);
- Value *CmpRHS = CI->getOperand(1);
- if (!match(CmpRHS, m_Zero()))
- return nullptr;
-
- auto Pred = CI->getPredicate();
- if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
- return nullptr;
-
+ CmpPredicate Pred;
Value *A, *B;
- const APInt *Constant1, *Constant2, *PossibleMSB;
- if (!match(CmpLHS, m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
- m_APInt(Constant1)),
- m_Intrinsic<Intrinsic::usub_sat>(
- m_Value(B), m_APInt(Constant2)))))
+ const APInt *Constant1, *Constant2;
+ if (!match(SI.getCondition(),
+ m_ICmp(Pred,
+ m_Or(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A),
+ m_APInt(Constant1)),
+ m_Intrinsic<Intrinsic::usub_sat>(m_Value(B),
+ m_APInt(Constant2))),
+ m_Zero())))
return nullptr;
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
- if (!((match(TrueVal, m_Zero()) && match(FalseVal, m_APInt(PossibleMSB))) ||
- (match(TrueVal, m_APInt(PossibleMSB)) && match(FalseVal, m_Zero()))))
+ if (!((Pred == ICmpInst::ICMP_EQ && match(TrueVal, m_Zero()) &&
+ match(FalseVal, m_SignMask())) ||
+ (Pred == ICmpInst::ICMP_NE && match(TrueVal, m_SignMask()) &&
+ match(FalseVal, m_Zero()))))
return nullptr;
auto *Ty = A->getType();
- auto *VecTy = dyn_cast<VectorType>(Ty);
- unsigned BW = PossibleMSB->getBitWidth();
- APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1);
+ // If C >= SignMask, usub.sat(X, C - SignMask + 1) has the sign bit iff X > C.
+ APInt MostSignificantBit = APInt::getSignMask(Constant1->getBitWidth());
- if (*PossibleMSB != MostSignificantBit ||
- Constant1->ult(MostSignificantBit) || Constant2->ult(MostSignificantBit))
+ // Both constants need the sign bit set (unsigned C >= SignMask) for that.
+ if (Constant1->isNonNegative() || Constant2->isNonNegative())
return nullptr;
APInt AdjAP1 = *Constant1 - MostSignificantBit + 1;
APInt AdjAP2 = *Constant2 - MostSignificantBit + 1;
- Constant *Adj1, *Adj2;
- if (VecTy) {
- Constant *Elt1 = ConstantInt::get(VecTy->getElementType(), AdjAP1);
- Constant *Elt2 = ConstantInt::get(VecTy->getElementType(), AdjAP2);
- Adj1 = ConstantVector::getSplat(VecTy->getElementCount(), Elt1);
- Adj2 = ConstantVector::getSplat(VecTy->getElementCount(), Elt2);
- } else {
- Adj1 = ConstantInt::get(Ty, AdjAP1);
- Adj2 = ConstantInt::get(Ty, AdjAP2);
- }
+ auto *Adj1 = ConstantInt::get(Ty, AdjAP1);
+ auto *Adj2 = ConstantInt::get(Ty, AdjAP2);
Value *NewA = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, Adj1);
Value *NewB = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, Adj2);
Value *Or = Builder.CreateOr(NewA, NewB);
- Constant *MSBConst;
- if (VecTy) {
- MSBConst = ConstantVector::getSplat(
- VecTy->getElementCount(),
- ConstantInt::get(VecTy->getScalarType(), *PossibleMSB));
- } else {
- MSBConst = ConstantInt::get(Ty->getScalarType(), *PossibleMSB);
- }
- Value *And = Builder.CreateAnd(Or, MSBConst);
- return cast<Instruction>(And);
+ Constant *MSBConst = ConstantInt::get(Ty, MostSignificantBit);
+ return BinaryOperator::CreateAnd(Or, MSBConst);
}
/// Visit a SelectInst that has an ICmpInst as its first operand.
@@ -2088,7 +2064,8 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
return NewSel;
if (Instruction *Folded =
foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder))
- return replaceInstUsesWith(SI, Folded);
+ return Folded;
+
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
bool Changed = false;
More information about the llvm-commits
mailing list