[llvm] 1e80143 - [InstCombine] `xor` reduction w/ i1 elt type is a parity check
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 2 10:23:22 PDT 2021
Author: Roman Lebedev
Date: 2021-08-02T20:21:37+03:00
New Revision: 1e801439be26569c9ede6fd309a645b00adb656c
URL: https://github.com/llvm/llvm-project/commit/1e801439be26569c9ede6fd309a645b00adb656c
DIFF: https://github.com/llvm/llvm-project/commit/1e801439be26569c9ede6fd309a645b00adb656c.diff
LOG: [InstCombine] `xor` reduction w/ i1 elt type is a parity check
For an i1 element type, `xor` and `add` are interchangeable
(https://alive2.llvm.org/ce/z/e77hhQ), so we should treat it just like
an `add` reduction and transform both consistently:
https://alive2.llvm.org/ce/z/MjCm5W (self)
https://alive2.llvm.org/ce/z/kgqF4M (skipped zext)
https://alive2.llvm.org/ce/z/pgy3HP (skipped sext)
That said, let's emit IR similar to what we already produce for
`vector_reduce_add(<n x i1>)`.
See https://bugs.llvm.org/show_bug.cgi?id=51259
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 726bb545be12a..1ada23b979a0b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2017,8 +2017,35 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
LLVM_FALLTHROUGH;
}
+ case Intrinsic::vector_reduce_xor: {
+ if (IID == Intrinsic::vector_reduce_xor) {
+ // Convert vector_reduce_xor(zext(<n x i1>)) to
+ // (ZExtOrTrunc(ctpop(bitcast <n x i1> to iN) & 1)).
+ // Convert vector_reduce_xor(sext(<n x i1>)) to
+ // -(ZExtOrTrunc(ctpop(bitcast <n x i1> to iN) & 1)).
+ // Convert vector_reduce_xor(<n x i1>) to
+ // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN) & 1).
+ Value *Arg = II->getArgOperand(0);
+ Value *Vect;
+ if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
+ if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
+ if (FTy->getElementType() == Builder.getInt1Ty()) {
+ Value *V = Builder.CreateBitCast(
+ Vect, Builder.getIntNTy(FTy->getNumElements()));
+ Value *Res = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+ Res = Builder.CreateAnd(Res, ConstantInt::get(Res->getType(), 1));
+ if (Res->getType() != II->getType())
+ Res = Builder.CreateZExtOrTrunc(Res, II->getType());
+ if (Arg != Vect &&
+ cast<Instruction>(Arg)->getOpcode() == Instruction::SExt)
+ Res = Builder.CreateNeg(Res);
+ return replaceInstUsesWith(CI, Res);
+ }
+ }
+ }
+ LLVM_FALLTHROUGH;
+ }
case Intrinsic::vector_reduce_mul:
- case Intrinsic::vector_reduce_xor:
case Intrinsic::vector_reduce_umax:
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_smax:
diff --git a/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
index 67709f4e5080f..e4022355f5dac 100644
--- a/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
@@ -3,8 +3,11 @@
define i1 @reduce_xor_self(<8 x i1> %x) {
; CHECK-LABEL: @reduce_xor_self(
-; CHECK-NEXT: [[RES:%.*]] = call i1 @llvm.vector.reduce.xor.v8i1(<8 x i1> [[X:%.*]])
-; CHECK-NEXT: ret i1 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP1]]), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[TMP4:%.*]] = icmp ne i8 [[TMP3]], 0
+; CHECK-NEXT: ret i1 [[TMP4]]
;
%res = call i1 @llvm.vector.reduce.xor.v8i32(<8 x i1> %x)
ret i1 %res
@@ -12,9 +15,12 @@ define i1 @reduce_xor_self(<8 x i1> %x) {
define i32 @reduce_xor_sext(<4 x i1> %x) {
; CHECK-LABEL: @reduce_xor_sext(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[SEXT]])
-; CHECK-NEXT: ret i32 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT: [[TMP2:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP1]]), !range [[RNG1:![0-9]+]]
+; CHECK-NEXT: [[TMP3:%.*]] = and i4 [[TMP2]], 1
+; CHECK-NEXT: [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
+; CHECK-NEXT: [[TMP5:%.*]] = sub nsw i32 0, [[TMP4]]
+; CHECK-NEXT: ret i32 [[TMP5]]
;
%sext = sext <4 x i1> %x to <4 x i32>
%res = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> %sext)
@@ -23,9 +29,11 @@ define i32 @reduce_xor_sext(<4 x i1> %x) {
define i64 @reduce_xor_zext(<8 x i1> %x) {
; CHECK-LABEL: @reduce_xor_zext(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.vector.reduce.xor.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT: ret i64 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP1]]), !range [[RNG0]]
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[TMP4:%.*]] = zext i8 [[TMP3]] to i64
+; CHECK-NEXT: ret i64 [[TMP4]]
;
%zext = zext <8 x i1> %x to <8 x i64>
%res = call i64 @llvm.vector.reduce.xor.v8i64(<8 x i64> %zext)
@@ -34,9 +42,11 @@ define i64 @reduce_xor_zext(<8 x i1> %x) {
define i16 @reduce_xor_sext_same(<16 x i1> %x) {
; CHECK-LABEL: @reduce_xor_sext_same(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i16>
-; CHECK-NEXT: [[RES:%.*]] = call i16 @llvm.vector.reduce.xor.v16i16(<16 x i16> [[SEXT]])
-; CHECK-NEXT: ret i16 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[TMP1]]), !range [[RNG2:![0-9]+]]
+; CHECK-NEXT: [[TMP3:%.*]] = and i16 [[TMP2]], 1
+; CHECK-NEXT: [[TMP4:%.*]] = sub nsw i16 0, [[TMP3]]
+; CHECK-NEXT: ret i16 [[TMP4]]
;
%sext = sext <16 x i1> %x to <16 x i16>
%res = call i16 @llvm.vector.reduce.xor.v16i16(<16 x i16> %sext)
@@ -45,9 +55,12 @@ define i16 @reduce_xor_sext_same(<16 x i1> %x) {
define i8 @reduce_xor_zext_long(<128 x i1> %x) {
; CHECK-LABEL: @reduce_xor_zext_long(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.vector.reduce.xor.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT: ret i8 [[RES]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = call i128 @llvm.ctpop.i128(i128 [[TMP1]]), !range [[RNG3:![0-9]+]]
+; CHECK-NEXT: [[TMP3:%.*]] = trunc i128 [[TMP2]] to i8
+; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = sub nsw i8 0, [[TMP4]]
+; CHECK-NEXT: ret i8 [[TMP5]]
;
%sext = sext <128 x i1> %x to <128 x i8>
%res = call i8 @llvm.vector.reduce.xor.v128i8(<128 x i8> %sext)
@@ -57,11 +70,15 @@ define i8 @reduce_xor_zext_long(<128 x i1> %x) {
@glob = external global i8, align 1
define i8 @reduce_xor_zext_long_external_use(<128 x i1> %x) {
; CHECK-LABEL: @reduce_xor_zext_long_external_use(
-; CHECK-NEXT: [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.vector.reduce.xor.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <128 x i8> [[SEXT]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = call i128 @llvm.ctpop.i128(i128 [[TMP1]]), !range [[RNG3]]
+; CHECK-NEXT: [[TMP3:%.*]] = trunc i128 [[TMP2]] to i8
+; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = sub nsw i8 0, [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <128 x i1> [[X]], i32 0
+; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[TMP6]] to i8
; CHECK-NEXT: store i8 [[EXT]], i8* @glob, align 1
-; CHECK-NEXT: ret i8 [[RES]]
+; CHECK-NEXT: ret i8 [[TMP5]]
;
%sext = sext <128 x i1> %x to <128 x i8>
%res = call i8 @llvm.vector.reduce.xor.v128i8(<128 x i8> %sext)
@@ -73,11 +90,14 @@ define i8 @reduce_xor_zext_long_external_use(<128 x i1> %x) {
@glob1 = external global i64, align 8
define i64 @reduce_xor_zext_external_use(<8 x i1> %x) {
; CHECK-LABEL: @reduce_xor_zext_external_use(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.vector.reduce.xor.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <8 x i64> [[ZEXT]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP1]]), !range [[RNG0]]
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[TMP4:%.*]] = zext i8 [[TMP3]] to i64
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <8 x i1> [[X]], i32 0
+; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[TMP5]] to i64
; CHECK-NEXT: store i64 [[EXT]], i64* @glob1, align 8
-; CHECK-NEXT: ret i64 [[RES]]
+; CHECK-NEXT: ret i64 [[TMP4]]
;
%zext = zext <8 x i1> %x to <8 x i64>
%res = call i64 @llvm.vector.reduce.xor.v8i64(<8 x i64> %zext)
More information about the llvm-commits
mailing list