[llvm] 0242a6a - [InstCombine] Support splat vectors in some or of icmp folds
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 10 13:59:17 PST 2021
Author: Nikita Popov
Date: 2021-11-10T22:59:09+01:00
New Revision: 0242a6adf73a7cdee55e90989d1d075b607320e1
URL: https://github.com/llvm/llvm-project/commit/0242a6adf73a7cdee55e90989d1d075b607320e1
DIFF: https://github.com/llvm/llvm-project/commit/0242a6adf73a7cdee55e90989d1d075b607320e1.diff
LOG: [InstCombine] Support splat vectors in some or of icmp folds
Replace m_ConstantInt() with m_APInt() in order to support splat
constants in addition to scalar integers.
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
llvm/test/Transforms/InstCombine/or.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index fae19549f081..230046a20a22 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2340,8 +2340,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
- auto *LHSC = dyn_cast<ConstantInt>(LHS1);
- auto *RHSC = dyn_cast<ConstantInt>(RHS1);
+ const APInt *LHSC = nullptr, *RHSC = nullptr;
+ match(LHS1, m_APInt(LHSC));
+ match(RHS1, m_APInt(RHSC));
// Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
// --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
@@ -2355,40 +2356,41 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// This implies all values in the two ranges differ by exactly one bit.
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
- LHSC->getType() == RHSC->getType() &&
- LHSC->getValue() == (RHSC->getValue())) {
+ LHSC->getBitWidth() == RHSC->getBitWidth() && *LHSC == *RHSC) {
Value *AddOpnd;
- ConstantInt *LAddC, *RAddC;
- if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) &&
- match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) &&
- LAddC->getValue().ugt(LHSC->getValue()) &&
- RAddC->getValue().ugt(LHSC->getValue())) {
+ const APInt *LAddC, *RAddC;
+ if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddC))) &&
+ match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddC))) &&
+ LAddC->ugt(*LHSC) && RAddC->ugt(*LHSC)) {
- APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
+ APInt DiffC = *LAddC ^ *RAddC;
if (DiffC.isPowerOf2()) {
- ConstantInt *MaxAddC = nullptr;
- if (LAddC->getValue().ult(RAddC->getValue()))
+ const APInt *MaxAddC = nullptr;
+ if (LAddC->ult(*RAddC))
MaxAddC = RAddC;
else
MaxAddC = LAddC;
- APInt RRangeLow = -RAddC->getValue();
- APInt RRangeHigh = RRangeLow + LHSC->getValue();
- APInt LRangeLow = -LAddC->getValue();
- APInt LRangeHigh = LRangeLow + LHSC->getValue();
+ APInt RRangeLow = -*RAddC;
+ APInt RRangeHigh = RRangeLow + *LHSC;
+ APInt LRangeLow = -*LAddC;
+ APInt LRangeHigh = LRangeLow + *LHSC;
APInt LowRangeDiff = RRangeLow ^ LRangeLow;
APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
: RRangeLow - LRangeLow;
if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
- RangeDiff.ugt(LHSC->getValue())) {
- Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
+ RangeDiff.ugt(*LHSC)) {
+ Type *Ty = AddOpnd->getType();
+ Value *MaskC = ConstantInt::get(Ty, ~DiffC);
Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC);
- Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC);
- return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC);
+ Value *NewAdd = Builder.CreateAdd(NewAnd,
+ ConstantInt::get(Ty, *MaxAddC));
+ return Builder.CreateICmp(LHS->getPredicate(), NewAdd,
+ ConstantInt::get(Ty, *LHSC));
}
}
}
@@ -2480,8 +2482,7 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (!LHSC || !RHSC)
return nullptr;
- return foldAndOrOfICmpsUsingRanges(PredL, LHS0, LHSC->getValue(),
- PredR, RHS0, RHSC->getValue(),
+ return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC,
Builder, /* IsAnd */ false);
}
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index ce64662d191b..013183a390ac 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -121,13 +121,11 @@ define i1 @test18_logical(i32 %A) {
ret i1 %D
}
-; FIXME: Vectors should fold too.
define <2 x i1> @test18vec(<2 x i32> %A) {
; CHECK-LABEL: @test18vec(
-; CHECK-NEXT: [[B:%.*]] = icmp sgt <2 x i32> [[A:%.*]], <i32 99, i32 99>
-; CHECK-NEXT: [[C:%.*]] = icmp slt <2 x i32> [[A]], <i32 50, i32 50>
-; CHECK-NEXT: [[D:%.*]] = or <2 x i1> [[B]], [[C]]
-; CHECK-NEXT: ret <2 x i1> [[D]]
+; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i32> [[A:%.*]], <i32 -100, i32 -100>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 -50, i32 -50>
+; CHECK-NEXT: ret <2 x i1> [[TMP2]]
;
%B = icmp sge <2 x i32> %A, <i32 100, i32 100>
%C = icmp slt <2 x i32> %A, <i32 50, i32 50>
@@ -481,9 +479,9 @@ define i1 @test36_logical(i32 %x) {
define i1 @test37(i32 %x) {
; CHECK-LABEL: @test37(
-; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[X:%.*]], 7
-; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[ADD1]], 31
-; CHECK-NEXT: ret i1 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[X:%.*]], 7
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 31
+; CHECK-NEXT: ret i1 [[TMP2]]
;
%add1 = add i32 %x, 7
%cmp1 = icmp ult i32 %add1, 30
@@ -494,9 +492,9 @@ define i1 @test37(i32 %x) {
define i1 @test37_logical(i32 %x) {
; CHECK-LABEL: @test37_logical(
-; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[X:%.*]], 7
-; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[ADD1]], 31
-; CHECK-NEXT: ret i1 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[X:%.*]], 7
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 31
+; CHECK-NEXT: ret i1 [[TMP2]]
;
%add1 = add i32 %x, 7
%cmp1 = icmp ult i32 %add1, 30
@@ -507,11 +505,9 @@ define i1 @test37_logical(i32 %x) {
define <2 x i1> @test37_uniform(<2 x i32> %x) {
; CHECK-LABEL: @test37_uniform(
-; CHECK-NEXT: [[ADD1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
-; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i32> [[ADD1]], <i32 30, i32 30>
-; CHECK-NEXT: [[CMP2:%.*]] = icmp eq <2 x i32> [[X]], <i32 23, i32 23>
-; CHECK-NEXT: [[RET1:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
-; CHECK-NEXT: ret <2 x i1> [[RET1]]
+; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 31, i32 31>
+; CHECK-NEXT: ret <2 x i1> [[TMP2]]
;
%add1 = add <2 x i32> %x, <i32 7, i32 7>
%cmp1 = icmp ult <2 x i32> %add1, <i32 30, i32 30>
@@ -792,12 +788,10 @@ define i1 @test46_logical(i8 signext %c) {
define <2 x i1> @test46_uniform(<2 x i8> %c) {
; CHECK-LABEL: @test46_uniform(
-; CHECK-NEXT: [[C_OFF:%.*]] = add <2 x i8> [[C:%.*]], <i8 -97, i8 -97>
-; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[C_OFF]], <i8 26, i8 26>
-; CHECK-NEXT: [[C_OFF17:%.*]] = add <2 x i8> [[C]], <i8 -65, i8 -65>
-; CHECK-NEXT: [[CMP2:%.*]] = icmp ult <2 x i8> [[C_OFF17]], <i8 26, i8 26>
-; CHECK-NEXT: [[OR:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
-; CHECK-NEXT: ret <2 x i1> [[OR]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[C:%.*]], <i8 -33, i8 -33>
+; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i8> [[TMP1]], <i8 -65, i8 -65>
+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult <2 x i8> [[TMP2]], <i8 26, i8 26>
+; CHECK-NEXT: ret <2 x i1> [[TMP3]]
;
%c.off = add <2 x i8> %c, <i8 -97, i8 -97>
%cmp1 = icmp ult <2 x i8> %c.off, <i8 26, i8 26>
More information about the llvm-commits
mailing list