[llvm] 4ba3326 - [InstCombine] `vector_reduce_{or,and}(?ext(<n x i1>))` --> `?ext(vector_reduce_{or,and}(<n x i1>))` (PR51259)
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 2 14:54:59 PDT 2021
Author: Roman Lebedev
Date: 2021-08-03T00:54:35+03:00
New Revision: 4ba3326f17ddabc1f427508a927a987d812ac543
URL: https://github.com/llvm/llvm-project/commit/4ba3326f17ddabc1f427508a927a987d812ac543
DIFF: https://github.com/llvm/llvm-project/commit/4ba3326f17ddabc1f427508a927a987d812ac543.diff
LOG: [InstCombine] `vector_reduce_{or,and}(?ext(<n x i1>))` --> `?ext(vector_reduce_{or,and}(<n x i1>))` (PR51259)
This allows the expansion logic to actually trigger when the argument
was zero- or sign-extended from a vector of i1 elements, like the rest of the reductions expect.
Alive2 agrees:
https://alive2.llvm.org/ce/z/wcfews (or zext)
https://alive2.llvm.org/ce/z/FCXNFx (or sext)
https://alive2.llvm.org/ce/z/f26zUY (and zext)
https://alive2.llvm.org/ce/z/jprViN (and sext)
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6df3f27700ba..b3ade5593e9e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1973,21 +1973,26 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// %val = bitcast <ReduxWidth x i1> to iReduxWidth
// %res = cmp eq iReduxWidth %val, 11111
Value *Arg = II->getArgOperand(0);
- Type *RetTy = II->getType();
- if (RetTy == Builder.getInt1Ty())
- if (auto *FVTy = dyn_cast<FixedVectorType>(Arg->getType())) {
- Value *Res = Builder.CreateBitCast(
- Arg, Builder.getIntNTy(FVTy->getNumElements()));
- if (IID == Intrinsic::vector_reduce_and) {
- Res = Builder.CreateICmpEQ(
- Res, ConstantInt::getAllOnesValue(Res->getType()));
- } else {
- assert(IID == Intrinsic::vector_reduce_or &&
- "Expected or reduction.");
- Res = Builder.CreateIsNotNull(Res);
+ Value *Vect;
+ if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
+ if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
+ if (FTy->getElementType() == Builder.getInt1Ty()) {
+ Value *Res = Builder.CreateBitCast(
+ Vect, Builder.getIntNTy(FTy->getNumElements()));
+ if (IID == Intrinsic::vector_reduce_and) {
+ Res = Builder.CreateICmpEQ(
+ Res, ConstantInt::getAllOnesValue(Res->getType()));
+ } else {
+ assert(IID == Intrinsic::vector_reduce_or &&
+ "Expected or reduction.");
+ Res = Builder.CreateIsNotNull(Res);
+ }
+ if (Arg != Vect)
+ Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
+ II->getType());
+ return replaceInstUsesWith(CI, Res);
}
- return replaceInstUsesWith(CI, Res);
- }
+ }
LLVM_FALLTHROUGH;
}
case Intrinsic::vector_reduce_add: {
diff --git a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
index c8a77990a43e..cad992e262d8 100644
--- a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
@@ -13,9 +13,10 @@ define i1 @reduce_and_self(<8 x i1> %x) {
define i32 @reduce_and_sext(<4 x i1> %x) {
; CHECK-LABEL: @reduce_and_sext(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[SEXT]])
-; CHECK-NEXT: ret i32 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i4 [[TMP1]], -1
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i32
+; CHECK-NEXT: ret i32 [[TMP3]]
;
%sext = sext <4 x i1> %x to <4 x i32>
%res = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %sext)
@@ -24,9 +25,10 @@ define i32 @reduce_and_sext(<4 x i1> %x) {
define i64 @reduce_and_zext(<8 x i1> %x) {
; CHECK-LABEL: @reduce_and_zext(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.vector.reduce.and.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT: ret i64 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], -1
+; CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i64
+; CHECK-NEXT: ret i64 [[TMP3]]
;
%zext = zext <8 x i1> %x to <8 x i64>
%res = call i64 @llvm.vector.reduce.and.v8i64(<8 x i64> %zext)
@@ -35,9 +37,10 @@ define i64 @reduce_and_zext(<8 x i1> %x) {
define i16 @reduce_and_sext_same(<16 x i1> %x) {
; CHECK-LABEL: @reduce_and_sext_same(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i16>
-; CHECK-NEXT: [[RES:%.*]] = call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> [[SEXT]])
-; CHECK-NEXT: ret i16 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i16 [[TMP1]], -1
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i16
+; CHECK-NEXT: ret i16 [[TMP3]]
;
%sext = sext <16 x i1> %x to <16 x i16>
%res = call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> %sext)
@@ -46,9 +49,10 @@ define i16 @reduce_and_sext_same(<16 x i1> %x) {
define i8 @reduce_and_zext_long(<128 x i1> %x) {
; CHECK-LABEL: @reduce_and_zext_long(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.vector.reduce.and.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT: ret i8 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i128 [[TMP1]], -1
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT: ret i8 [[TMP3]]
;
%sext = sext <128 x i1> %x to <128 x i8>
%res = call i8 @llvm.vector.reduce.and.v128i8(<128 x i8> %sext)
@@ -58,11 +62,13 @@ define i8 @reduce_and_zext_long(<128 x i1> %x) {
@glob = external global i8, align 1
define i8 @reduce_and_zext_long_external_use(<128 x i1> %x) {
; CHECK-LABEL: @reduce_and_zext_long_external_use(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.vector.reduce.and.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <128 x i8> [[SEXT]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i128 [[TMP1]], -1
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <128 x i1> [[X]], i32 0
+; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[TMP4]] to i8
; CHECK-NEXT: store i8 [[EXT]], i8* @glob, align 1
-; CHECK-NEXT: ret i8 [[RES]]
+; CHECK-NEXT: ret i8 [[TMP3]]
;
%sext = sext <128 x i1> %x to <128 x i8>
%res = call i8 @llvm.vector.reduce.and.v128i8(<128 x i8> %sext)
@@ -74,11 +80,13 @@ define i8 @reduce_and_zext_long_external_use(<128 x i1> %x) {
@glob1 = external global i64, align 8
define i64 @reduce_and_zext_external_use(<8 x i1> %x) {
; CHECK-LABEL: @reduce_and_zext_external_use(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.vector.reduce.and.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <8 x i64> [[ZEXT]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], -1
+; CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i64
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <8 x i1> [[X]], i32 0
+; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[TMP4]] to i64
; CHECK-NEXT: store i64 [[EXT]], i64* @glob1, align 8
-; CHECK-NEXT: ret i64 [[RES]]
+; CHECK-NEXT: ret i64 [[TMP3]]
;
%zext = zext <8 x i1> %x to <8 x i64>
%res = call i64 @llvm.vector.reduce.and.v8i64(<8 x i64> %zext)
diff --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
index fd7c1723e701..a52bab9f25aa 100644
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -13,9 +13,10 @@ define i1 @reduce_or_self(<8 x i1> %x) {
define i32 @reduce_or_sext(<4 x i1> %x) {
; CHECK-LABEL: @reduce_or_sext(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[SEXT]])
-; CHECK-NEXT: ret i32 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i4 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i32
+; CHECK-NEXT: ret i32 [[TMP3]]
;
%sext = sext <4 x i1> %x to <4 x i32>
%res = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %sext)
@@ -24,9 +25,10 @@ define i32 @reduce_or_sext(<4 x i1> %x) {
define i64 @reduce_or_zext(<8 x i1> %x) {
; CHECK-LABEL: @reduce_or_zext(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT: ret i64 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i64
+; CHECK-NEXT: ret i64 [[TMP3]]
;
%zext = zext <8 x i1> %x to <8 x i64>
%res = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> %zext)
@@ -35,9 +37,10 @@ define i64 @reduce_or_zext(<8 x i1> %x) {
define i16 @reduce_or_sext_same(<16 x i1> %x) {
; CHECK-LABEL: @reduce_or_sext_same(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i16>
-; CHECK-NEXT: [[RES:%.*]] = call i16 @llvm.vector.reduce.or.v16i16(<16 x i16> [[SEXT]])
-; CHECK-NEXT: ret i16 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i16 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i16
+; CHECK-NEXT: ret i16 [[TMP3]]
;
%sext = sext <16 x i1> %x to <16 x i16>
%res = call i16 @llvm.vector.reduce.or.v16i16(<16 x i16> %sext)
@@ -46,9 +49,10 @@ define i16 @reduce_or_sext_same(<16 x i1> %x) {
define i8 @reduce_or_zext_long(<128 x i1> %x) {
; CHECK-LABEL: @reduce_or_zext_long(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.vector.reduce.or.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT: ret i8 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i128 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT: ret i8 [[TMP3]]
;
%sext = sext <128 x i1> %x to <128 x i8>
%res = call i8 @llvm.vector.reduce.or.v128i8(<128 x i8> %sext)
@@ -58,11 +62,13 @@ define i8 @reduce_or_zext_long(<128 x i1> %x) {
@glob = external global i8, align 1
define i8 @reduce_or_zext_long_external_use(<128 x i1> %x) {
; CHECK-LABEL: @reduce_or_zext_long_external_use(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.vector.reduce.or.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <128 x i8> [[SEXT]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i128 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <128 x i1> [[X]], i32 0
+; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[TMP4]] to i8
; CHECK-NEXT: store i8 [[EXT]], i8* @glob, align 1
-; CHECK-NEXT: ret i8 [[RES]]
+; CHECK-NEXT: ret i8 [[TMP3]]
;
%sext = sext <128 x i1> %x to <128 x i8>
%res = call i8 @llvm.vector.reduce.or.v128i8(<128 x i8> %sext)
@@ -74,11 +80,13 @@ define i8 @reduce_or_zext_long_external_use(<128 x i1> %x) {
@glob1 = external global i64, align 8
define i64 @reduce_or_zext_external_use(<8 x i1> %x) {
; CHECK-LABEL: @reduce_or_zext_external_use(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <8 x i64> [[ZEXT]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i64
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <8 x i1> [[X]], i32 0
+; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[TMP4]] to i64
; CHECK-NEXT: store i64 [[EXT]], i64* @glob1, align 8
-; CHECK-NEXT: ret i64 [[RES]]
+; CHECK-NEXT: ret i64 [[TMP3]]
;
%zext = zext <8 x i1> %x to <8 x i64>
%res = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> %zext)
More information about the llvm-commits mailing list