[llvm] [instcombine] Pull sext/zext through reduce.or/and (PR #99394)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 17 14:39:24 PDT 2024
https://github.com/preames created https://github.com/llvm/llvm-project/pull/99394
This is directly analogous to the transform we perform on scalar or/and when both operands have the same extend (see foldCastedBitwiseLogic). Oddly, we already had this logic, but only applied it when the source type was an <.. x i1>.
We can likely do this for other arithmetic operations as well; we already do so for scalars. This patch covers the case I saw in a real workload, and working incrementally seemed reasonable.
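As a sketch of the rewrite (mirroring the reduce_or_zext_i8 test below), IR like

  %zext = zext <8 x i8> %x to <8 x i64>
  %res = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> %zext)

becomes

  %red = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %x)
  %res = zext i8 %red to i64

This is sound because or/and are bitwise: extending every lane and then reducing yields the same bits as reducing the narrow lanes and extending the scalar result once.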
From 4b87b4008c03187f91ebda040e7a20fabf76df67 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 17 Jul 2024 11:38:27 -0700
Subject: [PATCH] [instcombine] Pull sext/zext through reduce.or/and
This is directly analogous to the transform we perform on scalar or/and
when both operands have the same extend (see foldCastedBitwiseLogic).
Oddly, we already had this logic, but only applied it when the source
type was an <.. x i1>.
We can likely do this for other arithmetic operations as well; we
already do so for scalars. This patch covers the case I saw in a real
workload, and working incrementally seemed reasonable.
---
.../InstCombine/InstCombineCalls.cpp | 59 +++++++++++--------
.../InstCombine/reduction-and-sext-zext-i1.ll | 34 +++++++++--
.../InstCombine/reduction-or-sext-zext-i1.ll | 22 +++++++
.../PhaseOrdering/AArch64/quant_4x4.ll | 10 ++--
4 files changed, 88 insertions(+), 37 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 467b291f9a4c3..f656f64713625 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3340,13 +3340,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_and: {
- // Canonicalize logical or/and reductions:
- // Or reduction for i1 is represented as:
- // %val = bitcast <ReduxWidth x i1> to iReduxWidth
- // %res = cmp ne iReduxWidth %val, 0
- // And reduction for i1 is represented as:
- // %val = bitcast <ReduxWidth x i1> to iReduxWidth
- // %res = cmp eq iReduxWidth %val, 11111
Value *Arg = II->getArgOperand(0);
Value *Vect;
@@ -3356,24 +3349,40 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return II;
}
- if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
- if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
- if (FTy->getElementType() == Builder.getInt1Ty()) {
- Value *Res = Builder.CreateBitCast(
- Vect, Builder.getIntNTy(FTy->getNumElements()));
- if (IID == Intrinsic::vector_reduce_and) {
- Res = Builder.CreateICmpEQ(
- Res, ConstantInt::getAllOnesValue(Res->getType()));
- } else {
- assert(IID == Intrinsic::vector_reduce_or &&
- "Expected or reduction.");
- Res = Builder.CreateIsNotNull(Res);
- }
- if (Arg != Vect)
- Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
- II->getType());
- return replaceInstUsesWith(CI, Res);
- }
+ // reduce_or (z/sext V) -> z/sext (reduce_or V)
+ // reduce_and (z/sext V) -> z/sext (reduce_and V)
+ if (match(Arg, m_ZExtOrSExt(m_Value(Vect)))) {
+ Value *Res;
+ if (IID == Intrinsic::vector_reduce_and) {
+ Res = Builder.CreateAndReduce(Vect);
+ } else {
+ assert(IID == Intrinsic::vector_reduce_or && "Expected or reduction.");
+ Res = Builder.CreateOrReduce(Vect);
+ }
+ Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
+ II->getType());
+ return replaceInstUsesWith(CI, Res);
+ }
+
+ // Canonicalize logical or/and reductions:
+ // Or reduction for i1 is represented as:
+ // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+ // %res = cmp ne iReduxWidth %val, 0
+ // And reduction for i1 is represented as:
+ // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+ // %res = cmp eq iReduxWidth %val, 11111
+ if (auto *FTy = dyn_cast<FixedVectorType>(Arg->getType());
+ FTy && FTy->getElementType() == Builder.getInt1Ty()) {
+ Value *Res =
+ Builder.CreateBitCast(Arg, Builder.getIntNTy(FTy->getNumElements()));
+ if (IID == Intrinsic::vector_reduce_and) {
+ Res = Builder.CreateICmpEQ(
+ Res, ConstantInt::getAllOnesValue(Res->getType()));
+ } else {
+ assert(IID == Intrinsic::vector_reduce_or && "Expected or reduction.");
+ Res = Builder.CreateIsNotNull(Res);
+ }
+ return replaceInstUsesWith(CI, Res);
}
[[fallthrough]];
}
diff --git a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
index 4af9177bbfaaa..7a09f62e8c28d 100644
--- a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
@@ -37,6 +37,28 @@ define i64 @reduce_and_zext(<8 x i1> %x) {
ret i64 %res
}
+define i32 @reduce_and_sext_i8(<4 x i8> %x) {
+; CHECK-LABEL: @reduce_and_sext_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.and.v4i8(<4 x i8> [[X:%.*]])
+; CHECK-NEXT: [[RES:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT: ret i32 [[RES]]
+;
+ %sext = sext <4 x i8> %x to <4 x i32>
+ %res = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %sext)
+ ret i32 %res
+}
+
+define i64 @reduce_and_zext_i8(<8 x i8> %x) {
+; CHECK-LABEL: @reduce_and_zext_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.and.v8i8(<8 x i8> [[X:%.*]])
+; CHECK-NEXT: [[RES:%.*]] = zext i8 [[TMP1]] to i64
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %zext = zext <8 x i8> %x to <8 x i64>
+ %res = call i64 @llvm.vector.reduce.and.v8i64(<8 x i64> %zext)
+ ret i64 %res
+}
+
define i16 @reduce_and_sext_same(<16 x i1> %x) {
; CHECK-LABEL: @reduce_and_sext_same(
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
@@ -118,9 +140,9 @@ define i1 @reduce_and_pointer_cast_wide(ptr %arg, ptr %arg1) {
; CHECK-NEXT: bb:
; CHECK-NEXT: [[LHS:%.*]] = load <8 x i16>, ptr [[ARG1:%.*]], align 16
; CHECK-NEXT: [[RHS:%.*]] = load <8 x i16>, ptr [[ARG:%.*]], align 16
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT: [[ALL_EQ:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT: [[TMP0:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[TMP0]] to i8
+; CHECK-NEXT: [[ALL_EQ:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: ret i1 [[ALL_EQ]]
;
bb:
@@ -153,9 +175,9 @@ define i1 @reduce_and_pointer_cast_ne_wide(ptr %arg, ptr %arg1) {
; CHECK-NEXT: bb:
; CHECK-NEXT: [[LHS:%.*]] = load <8 x i16>, ptr [[ARG1:%.*]], align 16
; CHECK-NEXT: [[RHS:%.*]] = load <8 x i16>, ptr [[ARG:%.*]], align 16
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT: [[ALL_EQ:%.*]] = icmp ne i8 [[TMP0]], 0
+; CHECK-NEXT: [[TMP0:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[TMP0]] to i8
+; CHECK-NEXT: [[ALL_EQ:%.*]] = icmp ne i8 [[TMP1]], 0
; CHECK-NEXT: ret i1 [[ALL_EQ]]
;
bb:
diff --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
index 48c139663b62d..c8fc9210a269f 100644
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -37,6 +37,28 @@ define i64 @reduce_or_zext(<8 x i1> %x) {
ret i64 %res
}
+define i32 @reduce_or_sext_i8(<4 x i8> %x) {
+; CHECK-LABEL: @reduce_or_sext_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v4i8(<4 x i8> [[X:%.*]])
+; CHECK-NEXT: [[RES:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT: ret i32 [[RES]]
+;
+ %sext = sext <4 x i8> %x to <4 x i32>
+ %res = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %sext)
+ ret i32 %res
+}
+
+define i64 @reduce_or_zext_i8(<8 x i8> %x) {
+; CHECK-LABEL: @reduce_or_zext_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[X:%.*]])
+; CHECK-NEXT: [[RES:%.*]] = zext i8 [[TMP1]] to i64
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %zext = zext <8 x i8> %x to <8 x i64>
+ %res = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> %zext)
+ ret i64 %res
+}
+
define i16 @reduce_or_sext_same(<16 x i1> %x) {
; CHECK-LABEL: @reduce_or_sext_same(
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll
index c133852f66937..d8adfe274e8cf 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll
@@ -62,12 +62,11 @@ define i32 @quant_4x4(ptr noundef %dct, ptr noundef %mf, ptr noundef %bias) {
; CHECK-NEXT: store <8 x i16> [[PREDPHI]], ptr [[DCT]], align 2, !alias.scope [[META0]], !noalias [[META3]]
; CHECK-NEXT: store <8 x i16> [[PREDPHI34]], ptr [[TMP0]], align 2, !alias.scope [[META0]], !noalias [[META3]]
; CHECK-NEXT: [[BIN_RDX35:%.*]] = or <8 x i16> [[PREDPHI34]], [[PREDPHI]]
-; CHECK-NEXT: [[BIN_RDX:%.*]] = sext <8 x i16> [[BIN_RDX35]] to <8 x i32>
-; CHECK-NEXT: [[TMP29:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[BIN_RDX]])
+; CHECK-NEXT: [[TMP29:%.*]] = tail call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[BIN_RDX35]])
; CHECK-NEXT: br label [[FOR_COND_CLEANUP:%.*]]
; CHECK: for.cond.cleanup:
-; CHECK-NEXT: [[OR_LCSSA:%.*]] = phi i32 [ [[TMP29]], [[VECTOR_BODY]] ], [ [[OR_15:%.*]], [[IF_END_15:%.*]] ]
-; CHECK-NEXT: [[TOBOOL:%.*]] = icmp ne i32 [[OR_LCSSA]], 0
+; CHECK-NEXT: [[OR_LCSSA_IN:%.*]] = phi i16 [ [[TMP29]], [[VECTOR_BODY]] ], [ [[OR_1551:%.*]], [[IF_END_15:%.*]] ]
+; CHECK-NEXT: [[TOBOOL:%.*]] = icmp ne i16 [[OR_LCSSA_IN]], 0
; CHECK-NEXT: [[LNOT_EXT:%.*]] = zext i1 [[TOBOOL]] to i32
; CHECK-NEXT: ret i32 [[LNOT_EXT]]
; CHECK: for.body:
@@ -514,8 +513,7 @@ define i32 @quant_4x4(ptr noundef %dct, ptr noundef %mf, ptr noundef %bias) {
; CHECK: if.end.15:
; CHECK-NEXT: [[STOREMERGE_15:%.*]] = phi i16 [ [[CONV28_15]], [[IF_ELSE_15]] ], [ [[CONV12_15]], [[IF_THEN_15]] ]
; CHECK-NEXT: store i16 [[STOREMERGE_15]], ptr [[ARRAYIDX_15]], align 2
-; CHECK-NEXT: [[OR_1551:%.*]] = or i16 [[OR_1450]], [[STOREMERGE_15]]
-; CHECK-NEXT: [[OR_15]] = sext i16 [[OR_1551]] to i32
+; CHECK-NEXT: [[OR_1551]] = or i16 [[OR_1450]], [[STOREMERGE_15]]
; CHECK-NEXT: br label [[FOR_COND_CLEANUP]]
;
entry: