[llvm] 554fc9a - [InstCombine] `vector_reduce_smax(?ext(<n x i1>))` --> `?ext(vector_reduce_{and,or}(<n x i1>))` (PR51259)

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 2 14:33:29 PDT 2021


Author: Roman Lebedev
Date: 2021-08-03T00:29:06+03:00
New Revision: 554fc9ad0a24f6689c61d080c9451edd2ddc90b1

URL: https://github.com/llvm/llvm-project/commit/554fc9ad0a24f6689c61d080c9451edd2ddc90b1
DIFF: https://github.com/llvm/llvm-project/commit/554fc9ad0a24f6689c61d080c9451edd2ddc90b1.diff

LOG: [InstCombine] `vector_reduce_smax(?ext(<n x i1>))` --> `?ext(vector_reduce_{and,or}(<n x i1>))` (PR51259)

Alive2 agrees:
https://alive2.llvm.org/ce/z/3oqir9 (self)
https://alive2.llvm.org/ce/z/6cuI5m (zext)
https://alive2.llvm.org/ce/z/4FL8rD (sext)

We already handle `vector_reduce_{and,or}(<n x i1>)`,
so let's just combine into those already-handled patterns
and let the existing folds do the rest.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/reduction-smax-sext-zext-i1.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 4ef2503b4852..6df3f27700ba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2094,22 +2094,24 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     }
     LLVM_FALLTHROUGH;
   }
-  case Intrinsic::vector_reduce_smin: {
-    if (IID == Intrinsic::vector_reduce_smin) {
-      // SMin reduction over the vector with (potentially-extended)
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_smax: {
+    if (IID == Intrinsic::vector_reduce_smin ||
+        IID == Intrinsic::vector_reduce_smax) {
+      // SMin/SMax reduction over the vector with (potentially-extended)
       // i1 element type is actually a (potentially-extended)
       // logical `and`/`or` reduction over the original non-extended value:
-      //   vector_reduce_smin(<n x i1>)
+      //   vector_reduce_s{min,max}(<n x i1>)
       //     -->
-      //   vector_reduce_or(<n x i1>)
+      //   vector_reduce_{or,and}(<n x i1>)
       // and
-      //   vector_reduce_smin(sext(<n x i1>))
+      //   vector_reduce_s{min,max}(sext(<n x i1>))
       //     -->
-      //   sext(vector_reduce_or(<n x i1>))
+      //   sext(vector_reduce_{or,and}(<n x i1>))
       // and
-      //   vector_reduce_smin(zext(<n x i1>))
+      //   vector_reduce_s{min,max}(zext(<n x i1>))
       //     -->
-      //   zext(vector_reduce_and(<n x i1>))
+      //   zext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
@@ -2118,7 +2120,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
             Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd;
             if (Arg != Vect)
               ExtOpc = cast<CastInst>(Arg)->getOpcode();
-            Value *Res = ExtOpc == Instruction::CastOps::ZExt
+            Value *Res = ((IID == Intrinsic::vector_reduce_smin) ==
+                          (ExtOpc == Instruction::CastOps::ZExt))
                              ? Builder.CreateAndReduce(Vect)
                              : Builder.CreateOrReduce(Vect);
             if (Arg != Vect)
@@ -2129,7 +2132,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     }
     LLVM_FALLTHROUGH;
   }
-  case Intrinsic::vector_reduce_smax:
   case Intrinsic::vector_reduce_fmax:
   case Intrinsic::vector_reduce_fmin:
   case Intrinsic::vector_reduce_fadd:

diff  --git a/llvm/test/Transforms/InstCombine/reduction-smax-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-smax-sext-zext-i1.ll
index bf26a3276d75..da027edf5876 100644
--- a/llvm/test/Transforms/InstCombine/reduction-smax-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-smax-sext-zext-i1.ll
@@ -3,8 +3,9 @@
 
 define i1 @reduce_smax_self(<8 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_self(
-; CHECK-NEXT:    [[RES:%.*]] = call i1 @llvm.vector.reduce.smax.v8i1(<8 x i1> [[X:%.*]])
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i8 [[TMP1]], -1
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %res = call i1 @llvm.vector.reduce.smax.v8i32(<8 x i1> %x)
   ret i1 %res
@@ -12,9 +13,10 @@ define i1 @reduce_smax_self(<8 x i1> %x) {
 
 define i32 @reduce_smax_sext(<4 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_sext(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.smax.v4i32(<4 x i32> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i4 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = sext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %sext = sext <4 x i1> %x to <4 x i32>
   %res = call i32 @llvm.vector.reduce.smax.v4i32(<4 x i32> %sext)
@@ -23,9 +25,10 @@ define i32 @reduce_smax_sext(<4 x i1> %x) {
 
 define i64 @reduce_smax_zext(<8 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_zext(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[RES:%.*]] = call i64 @llvm.vector.reduce.smax.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT:    ret i64 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i64
+; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
   %zext = zext <8 x i1> %x to <8 x i64>
   %res = call i64 @llvm.vector.reduce.smax.v8i64(<8 x i64> %zext)
@@ -34,9 +37,10 @@ define i64 @reduce_smax_zext(<8 x i1> %x) {
 
 define i16 @reduce_smax_sext_same(<16 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_sext_same(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i16>
-; CHECK-NEXT:    [[RES:%.*]] = call i16 @llvm.vector.reduce.smax.v16i16(<16 x i16> [[SEXT]])
-; CHECK-NEXT:    ret i16 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i16 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = sext i1 [[TMP2]] to i16
+; CHECK-NEXT:    ret i16 [[TMP3]]
 ;
   %sext = sext <16 x i1> %x to <16 x i16>
   %res = call i16 @llvm.vector.reduce.smax.v16i16(<16 x i16> %sext)
@@ -45,9 +49,10 @@ define i16 @reduce_smax_sext_same(<16 x i1> %x) {
 
 define i8 @reduce_smax_zext_long(<128 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_zext_long(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.vector.reduce.smax.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i8 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i128 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT:    ret i8 [[TMP3]]
 ;
   %sext = sext <128 x i1> %x to <128 x i8>
   %res = call i8 @llvm.vector.reduce.smax.v128i8(<128 x i8> %sext)
@@ -57,11 +62,13 @@ define i8 @reduce_smax_zext_long(<128 x i1> %x) {
 @glob = external global i8, align 1
 define i8 @reduce_smax_zext_long_external_use(<128 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_zext_long_external_use(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.vector.reduce.smax.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <128 x i8> [[SEXT]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i128 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <128 x i1> [[X]], i32 0
+; CHECK-NEXT:    [[EXT:%.*]] = sext i1 [[TMP4]] to i8
 ; CHECK-NEXT:    store i8 [[EXT]], i8* @glob, align 1
-; CHECK-NEXT:    ret i8 [[RES]]
+; CHECK-NEXT:    ret i8 [[TMP3]]
 ;
   %sext = sext <128 x i1> %x to <128 x i8>
   %res = call i8 @llvm.vector.reduce.smax.v128i8(<128 x i8> %sext)
@@ -73,11 +80,13 @@ define i8 @reduce_smax_zext_long_external_use(<128 x i1> %x) {
 @glob1 = external global i64, align 8
 define i64 @reduce_smax_zext_external_use(<8 x i1> %x) {
 ; CHECK-LABEL: @reduce_smax_zext_external_use(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[RES:%.*]] = call i64 @llvm.vector.reduce.smax.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <8 x i64> [[ZEXT]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <8 x i1> [[X]], i32 0
+; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[TMP4]] to i64
 ; CHECK-NEXT:    store i64 [[EXT]], i64* @glob1, align 8
-; CHECK-NEXT:    ret i64 [[RES]]
+; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
   %zext = zext <8 x i1> %x to <8 x i64>
   %res = call i64 @llvm.vector.reduce.smax.v8i64(<8 x i64> %zext)


        


More information about the llvm-commits mailing list