[PATCH] D118431: [InstCombine] Generalize and-reduce pattern to handle `ne` case as well as `eq`

Max Kazantsev via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 27 21:38:18 PST 2022


mkazantsev created this revision.
mkazantsev added reviewers: spatel, lebedev.ri, RKSimon, dmakogon.
Herald added a subscriber: hiraditya.
mkazantsev requested review of this revision.
Herald added a project: LLVM.
Herald added a subscriber: llvm-commits.

Following Sanjay's proposal from the discussion in D118317 <https://reviews.llvm.org/D118317>, this patch
generalizes the and-reduce handling, which previously only matched the `eq` outer predicate, to fold the following pattern

  icmp <pred> (bitcast(icmp ne (lhs, rhs)), 0)

into

  icmp <pred> (bitcast(lhs), bitcast(rhs))

for <pred> in {eq, ne}.


https://reviews.llvm.org/D118431

Files:
  llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
  llvm/test/Transforms/InstCombine/icmp-vec.ll
  llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll


Index: llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
===================================================================
--- llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -145,14 +145,12 @@
 define i1 @reduce_or_pointer_cast_ne(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_or_pointer_cast_ne(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*
Index: llvm/test/Transforms/InstCombine/icmp-vec.ll
===================================================================
--- llvm/test/Transforms/InstCombine/icmp-vec.ll
+++ llvm/test/Transforms/InstCombine/icmp-vec.ll
@@ -443,9 +443,9 @@
 
 define i1 @eq_cast_ne-1-legal-scalar(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @eq_cast_ne-1-legal-scalar(
-; CHECK-NEXT:    [[IC:%.*]] = icmp ne <2 x i8> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i2 [[TMP1]], 0
+; CHECK-NEXT:    [[X_SCALAR:%.*]] = bitcast <2 x i8> [[X:%.*]] to i16
+; CHECK-NEXT:    [[Y_SCALAR:%.*]] = bitcast <2 x i8> [[Y:%.*]] to i16
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i16 [[X_SCALAR]], [[Y_SCALAR]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ic = icmp eq <2 x i8> %x, %y
Index: llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5897,13 +5897,15 @@
   // Match lowering of @llvm.vector.reduce.and. Turn
   ///   %vec_ne = icmp ne <8 x i8> %lhs, %rhs
   ///   %scalar_ne = bitcast <8 x i1> %vec_ne to i8
-  ///   %all_eq = icmp eq i8 %scalar_ne, 0
+  ///   %res = icmp <pred> i8 %scalar_ne, 0
   ///
   /// into
   ///
   ///   %lhs.scalar = bitcast <8 x i8> %lhs to i64
   ///   %rhs.scalar = bitcast <8 x i8> %rhs to i64
-  ///   %all_eq = icmp eq i64 %lhs.scalar, %rhs.scalar
+  ///   %res = icmp <pred> i64 %lhs.scalar, %rhs.scalar
+  ///
+  /// for <pred> in {ne, eq}.
   if (!match(&I, m_ICmp(OuterPred,
                         m_OneUse(m_BitCast(m_OneUse(
                             m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
@@ -5918,12 +5920,11 @@
   if (!DL.isLegalInteger(NumBits))
     return nullptr;
 
-  // TODO: Generalize to isEquality and support other patterns.
-  if (OuterPred == ICmpInst::ICMP_EQ && InnerPred == ICmpInst::ICMP_NE) {
+  if (ICmpInst::isEquality(OuterPred) && InnerPred == ICmpInst::ICMP_NE) {
     auto *ScalarTy = Builder.getIntNTy(NumBits);
     LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
     RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
-    return ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, LHS, RHS,
+    return ICmpInst::Create(Instruction::ICmp, OuterPred, LHS, RHS,
                             I.getName());
   }
 


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D118431.403881.patch
Type: text/x-patch
Size: 3818 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20220128/30506e0c/attachment.bin>


More information about the llvm-commits mailing list