[llvm] 3b194ca - Recommit "[InstCombine] Fold and-reduce idiom"

Max Kazantsev via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 28 20:39:37 PST 2022


Author: Max Kazantsev
Date: 2022-01-29T11:27:48+07:00
New Revision: 3b194ca7ab37a212ab54e698d22415753dc6a197

URL: https://github.com/llvm/llvm-project/commit/3b194ca7ab37a212ab54e698d22415753dc6a197
DIFF: https://github.com/llvm/llvm-project/commit/3b194ca7ab37a212ab54e698d22415753dc6a197.diff

LOG: Recommit "[InstCombine] Fold and-reduce idiom"

Checks of original vector types made more thorough.

Differential Revision: https://reviews.llvm.org/D118317

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/icmp-vec.ll
    llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
    llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fd58a44504b3c..677403a55fbb5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5882,6 +5882,54 @@ static Instruction *foldICmpInvariantGroup(ICmpInst &I) {
   return nullptr;
 }
 
+/// Folds patterns produced by lowering reduce idioms, such as
+/// llvm.vector.reduce.and, into instruction chains. When the compared vectors
+/// fit in a legal integer type, a single scalar equality comparison of their
+/// bits replaces the per-lane vector comparison.
+static Instruction *foldReductionIdiom(ICmpInst &I,
+                                       InstCombiner::BuilderTy &Builder,
+                                       const DataLayout &DL) {
+  if (I.getType()->isVectorTy())
+    return nullptr;
+  ICmpInst::Predicate OuterPred, InnerPred;
+  Value *LHS, *RHS;
+
+  // Match lowering of @llvm.vector.reduce.and. For <pred> in {eq, ne}, turn
+  //   %vec_ne = icmp ne <8 x i8> %lhs, %rhs
+  //   %scalar_ne = bitcast <8 x i1> %vec_ne to i8
+  //   %res = icmp <pred> i8 %scalar_ne, 0
+  //
+  // into
+  //
+  //   %lhs.scalar = bitcast <8 x i8> %lhs to i64
+  //   %rhs.scalar = bitcast <8 x i8> %rhs to i64
+  //   %res = icmp <pred> i64 %lhs.scalar, %rhs.scalar
+  if (!match(&I, m_ICmp(OuterPred,
+                        m_OneUse(m_BitCast(m_OneUse(
+                            m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
+                        m_Zero())))
+    return nullptr;
+  auto *LHSTy = dyn_cast<FixedVectorType>(LHS->getType());
+  if (!LHSTy || !LHSTy->getElementType()->isIntegerTy())
+    return nullptr;
+  unsigned NumBits =
+      LHSTy->getNumElements() * LHSTy->getElementType()->getIntegerBitWidth();
+  // TODO: Relax this to "not wider than max legal integer type"?
+  if (!DL.isLegalInteger(NumBits))
+    return nullptr;
+
+  // Some lane differs iff the scalars differ. TODO: Support other patterns.
+  if (ICmpInst::isEquality(OuterPred) && InnerPred == ICmpInst::ICMP_NE) {
+    auto *ScalarTy = Builder.getIntNTy(NumBits);
+    LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
+    RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
+    return ICmpInst::Create(Instruction::ICmp, OuterPred, LHS, RHS,
+                            I.getName());
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -6124,6 +6172,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldICmpInvariantGroup(I))
     return Res;
 
+  if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/icmp-vec.ll b/llvm/test/Transforms/InstCombine/icmp-vec.ll
index 3068c1b5dcffd..2888a826fdc7b 100644
--- a/llvm/test/Transforms/InstCombine/icmp-vec.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-vec.ll
@@ -404,9 +404,9 @@ define <vscale x 2 x i1> @icmp_logical_or_scalablevec(<vscale x 2 x i64> %x, <vs
 
 define i1 @eq_cast_eq-1(<2 x i4> %x, <2 x i4> %y) {
 ; CHECK-LABEL: @eq_cast_eq-1(
-; CHECK-NEXT:    [[IC:%.*]] = icmp ne <2 x i4> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i2 [[TMP1]], 0
+; CHECK-NEXT:    [[X_SCALAR:%.*]] = bitcast <2 x i4> [[X:%.*]] to i8
+; CHECK-NEXT:    [[Y_SCALAR:%.*]] = bitcast <2 x i4> [[Y:%.*]] to i8
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[X_SCALAR]], [[Y_SCALAR]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ic = icmp eq <2 x i4> %x, %y

diff  --git a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
index bed201d30773b..a38d26b1f8bf5 100644
--- a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
@@ -100,14 +100,12 @@ define i64 @reduce_and_zext_external_use(<8 x i1> %x) {
 define i1 @reduce_and_pointer_cast(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_and_pointer_cast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*
@@ -144,14 +142,12 @@ bb:
 define i1 @reduce_and_pointer_cast_ne(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_and_pointer_cast_ne(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*

diff  --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
index e6cfddb424888..35174d4321565 100644
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -100,13 +100,11 @@ define i64 @reduce_or_zext_external_use(<8 x i1> %x) {
 define i1 @reduce_or_pointer_cast(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_or_pointer_cast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
 ; CHECK-NEXT:    ret i1 [[DOTNOT]]
 ;
 bb:


        


More information about the llvm-commits mailing list