[llvm] 70b3beb - [InstCombine] Generalize and-reduce pattern to handle `ne` case as well as `eq`
Max Kazantsev via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 30 21:14:19 PST 2022
Author: Max Kazantsev
Date: 2022-01-31T12:14:08+07:00
New Revision: 70b3beb0e22dd0eb33c6fcef019a24f9f1f09ef9
URL: https://github.com/llvm/llvm-project/commit/70b3beb0e22dd0eb33c6fcef019a24f9f1f09ef9
DIFF: https://github.com/llvm/llvm-project/commit/70b3beb0e22dd0eb33c6fcef019a24f9f1f09ef9.diff
LOG: [InstCombine] Generalize and-reduce pattern to handle `ne` case as well as `eq`
Following Sanjay's proposal from the discussion in D118317, this patch
generalizes and-reduce handling to fold the following pattern:
```
icmp ne (bitcast(icmp ne (lhs, rhs)), 0)
```
into
```
icmp ne (bitcast(lhs), bitcast(rhs))
```
https://alive2.llvm.org/ce/z/WDcuJ_
Differential Revision: https://reviews.llvm.org/D118431
Reviewed By: lebedev.ri
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/icmp-vec.ll
llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 677403a55fbb5..e45be5745fccd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5897,13 +5897,15 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
// Match lowering of @llvm.vector.reduce.and. Turn
/// %vec_ne = icmp ne <8 x i8> %lhs, %rhs
/// %scalar_ne = bitcast <8 x i1> %vec_ne to i8
- /// %all_eq = icmp eq i8 %scalar_ne, 0
+ /// %res = icmp <pred> i8 %scalar_ne, 0
///
/// into
///
/// %lhs.scalar = bitcast <8 x i8> %lhs to i64
/// %rhs.scalar = bitcast <8 x i8> %rhs to i64
- /// %all_eq = icmp eq i64 %lhs.scalar, %rhs.scalar
+ /// %res = icmp <pred> i64 %lhs.scalar, %rhs.scalar
+ ///
+ /// for <pred> in {ne, eq}.
if (!match(&I, m_ICmp(OuterPred,
m_OneUse(m_BitCast(m_OneUse(
m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
@@ -5918,12 +5920,11 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
if (!DL.isLegalInteger(NumBits))
return nullptr;
- // TODO: Generalize to isEquality and support other patterns.
- if (OuterPred == ICmpInst::ICMP_EQ && InnerPred == ICmpInst::ICMP_NE) {
+ if (ICmpInst::isEquality(OuterPred) && InnerPred == ICmpInst::ICMP_NE) {
auto *ScalarTy = Builder.getIntNTy(NumBits);
LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
- return ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, LHS, RHS,
+ return ICmpInst::Create(Instruction::ICmp, OuterPred, LHS, RHS,
I.getName());
}
diff --git a/llvm/test/Transforms/InstCombine/icmp-vec.ll b/llvm/test/Transforms/InstCombine/icmp-vec.ll
index 2888a826fdc7b..50c063220f6f6 100644
--- a/llvm/test/Transforms/InstCombine/icmp-vec.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-vec.ll
@@ -443,9 +443,9 @@ define i1 @eq_cast_ne-1(<2 x i7> %x, <2 x i7> %y) {
define i1 @eq_cast_ne-1-legal-scalar(<2 x i8> %x, <2 x i8> %y) {
; CHECK-LABEL: @eq_cast_ne-1-legal-scalar(
-; CHECK-NEXT: [[IC:%.*]] = icmp ne <2 x i8> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT: [[R:%.*]] = icmp ne i2 [[TMP1]], 0
+; CHECK-NEXT: [[X_SCALAR:%.*]] = bitcast <2 x i8> [[X:%.*]] to i16
+; CHECK-NEXT: [[Y_SCALAR:%.*]] = bitcast <2 x i8> [[Y:%.*]] to i16
+; CHECK-NEXT: [[R:%.*]] = icmp ne i16 [[X_SCALAR]], [[Y_SCALAR]]
; CHECK-NEXT: ret i1 [[R]]
;
%ic = icmp eq <2 x i8> %x, %y
diff --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
index 35174d4321565..cd1b10a76ade4 100644
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -145,14 +145,12 @@ bb:
define i1 @reduce_or_pointer_cast_ne(i8* %arg, i8* %arg1) {
; CHECK-LABEL: @reduce_or_pointer_cast_ne(
; CHECK-NEXT: bb:
-; CHECK-NEXT: [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT: [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT: [[TMP1:%.*]] = icmp ne i8 [[TMP0]], 0
-; CHECK-NEXT: ret i1 [[TMP1]]
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT: [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT: [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT: ret i1 [[TMP2]]
;
bb:
%ptr1 = bitcast i8* %arg1 to <8 x i8>*
More information about the llvm-commits mailing list