[llvm] [instcombine] Pull sext/zext through reduce.or/and (PR #99394)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 14:39:24 PDT 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/99394

This is directly analogous to the transform we perform on scalar or/and when both operands have the same extend (see foldCastedBitwiseLogic). Oddly, we already had this logic; we just only used it when the source type was <.. x i1>.

We can likely do this for other arithmetic operations as well, as we already do for the scalar case. This patch covers the case I saw in a real workload, and working incrementally seemed reasonable.

>From 4b87b4008c03187f91ebda040e7a20fabf76df67 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 17 Jul 2024 11:38:27 -0700
Subject: [PATCH] [instcombine] Pull sext/zext through reduce.or/and

This is directly analogous to the transform we perform on scalar or/and
when both operands have the same extend.  (See foldCastedBitwiseLogic.)
Oddly, we already had this logic, we just only used it when the source
type was an <.. x i1>.

We can likely do this for other arithmetic operations as well; we do
for scalar.  This patch covers the case I saw in a real workload, and
working incrementally seemed reasonable.
---
 .../InstCombine/InstCombineCalls.cpp          | 59 +++++++++++--------
 .../InstCombine/reduction-and-sext-zext-i1.ll | 34 +++++++++--
 .../InstCombine/reduction-or-sext-zext-i1.ll  | 22 +++++++
 .../PhaseOrdering/AArch64/quant_4x4.ll        | 10 ++--
 4 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 467b291f9a4c3..f656f64713625 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3340,13 +3340,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   }
   case Intrinsic::vector_reduce_or:
   case Intrinsic::vector_reduce_and: {
-    // Canonicalize logical or/and reductions:
-    // Or reduction for i1 is represented as:
-    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
-    // %res = cmp ne iReduxWidth %val, 0
-    // And reduction for i1 is represented as:
-    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
-    // %res = cmp eq iReduxWidth %val, 11111
     Value *Arg = II->getArgOperand(0);
     Value *Vect;
 
@@ -3356,24 +3349,40 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       return II;
     }
 
-    if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
-      if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
-        if (FTy->getElementType() == Builder.getInt1Ty()) {
-          Value *Res = Builder.CreateBitCast(
-              Vect, Builder.getIntNTy(FTy->getNumElements()));
-          if (IID == Intrinsic::vector_reduce_and) {
-            Res = Builder.CreateICmpEQ(
-                Res, ConstantInt::getAllOnesValue(Res->getType()));
-          } else {
-            assert(IID == Intrinsic::vector_reduce_or &&
-                   "Expected or reduction.");
-            Res = Builder.CreateIsNotNull(Res);
-          }
-          if (Arg != Vect)
-            Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
-                                     II->getType());
-          return replaceInstUsesWith(CI, Res);
-        }
+    // reduce_or (z/sext V) -> z/sext (reduce_or V)
+    // reduce_and (z/sext V) -> z/sext (reduce_and V)
+    if (match(Arg, m_ZExtOrSExt(m_Value(Vect)))) {
+      Value *Res;
+      if (IID == Intrinsic::vector_reduce_and) {
+        Res = Builder.CreateAndReduce(Vect);
+      } else {
+        assert(IID == Intrinsic::vector_reduce_or && "Expected or reduction.");
+        Res = Builder.CreateOrReduce(Vect);
+      }
+      Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
+                               II->getType());
+      return replaceInstUsesWith(CI, Res);
+    }
+
+    // Canonicalize logical or/and reductions:
+    // Or reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp ne iReduxWidth %val, 0
+    // And reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp eq iReduxWidth %val, 11111
+    if (auto *FTy = dyn_cast<FixedVectorType>(Arg->getType());
+        FTy && FTy->getElementType() == Builder.getInt1Ty()) {
+      Value *Res =
+          Builder.CreateBitCast(Arg, Builder.getIntNTy(FTy->getNumElements()));
+      if (IID == Intrinsic::vector_reduce_and) {
+        Res = Builder.CreateICmpEQ(
+            Res, ConstantInt::getAllOnesValue(Res->getType()));
+      } else {
+        assert(IID == Intrinsic::vector_reduce_or && "Expected or reduction.");
+        Res = Builder.CreateIsNotNull(Res);
+      }
+      return replaceInstUsesWith(CI, Res);
     }
     [[fallthrough]];
   }
diff --git a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
index 4af9177bbfaaa..7a09f62e8c28d 100644
--- a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
@@ -37,6 +37,28 @@ define i64 @reduce_and_zext(<8 x i1> %x) {
   ret i64 %res
 }
 
+define i32 @reduce_and_sext_i8(<4 x i8> %x) {
+; CHECK-LABEL: @reduce_and_sext_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.and.v4i8(<4 x i8> [[X:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %sext = sext <4 x i8> %x to <4 x i32>
+  %res = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %sext)
+  ret i32 %res
+}
+
+define i64 @reduce_and_zext_i8(<8 x i8> %x) {
+; CHECK-LABEL: @reduce_and_zext_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.and.v8i8(<8 x i8> [[X:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = zext i8 [[TMP1]] to i64
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %zext = zext <8 x i8> %x to <8 x i64>
+  %res = call i64 @llvm.vector.reduce.and.v8i64(<8 x i64> %zext)
+  ret i64 %res
+}
+
 define i16 @reduce_and_sext_same(<16 x i1> %x) {
 ; CHECK-LABEL: @reduce_and_sext_same(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
@@ -118,9 +140,9 @@ define i1 @reduce_and_pointer_cast_wide(ptr %arg, ptr %arg1) {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i16>, ptr [[ARG1:%.*]], align 16
 ; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i16>, ptr [[ARG:%.*]], align 16
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[ALL_EQ:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[TMP0]] to i8
+; CHECK-NEXT:    [[ALL_EQ:%.*]] = icmp eq i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[ALL_EQ]]
 ;
 bb:
@@ -153,9 +175,9 @@ define i1 @reduce_and_pointer_cast_ne_wide(ptr %arg, ptr %arg1) {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i16>, ptr [[ARG1:%.*]], align 16
 ; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i16>, ptr [[ARG:%.*]], align 16
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[ALL_EQ:%.*]] = icmp ne i8 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ne <8 x i16> [[LHS]], [[RHS]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[TMP0]] to i8
+; CHECK-NEXT:    [[ALL_EQ:%.*]] = icmp ne i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[ALL_EQ]]
 ;
 bb:
diff --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
index 48c139663b62d..c8fc9210a269f 100644
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -37,6 +37,28 @@ define i64 @reduce_or_zext(<8 x i1> %x) {
   ret i64 %res
 }
 
+define i32 @reduce_or_sext_i8(<4 x i8> %x) {
+; CHECK-LABEL: @reduce_or_sext_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v4i8(<4 x i8> [[X:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %sext = sext <4 x i8> %x to <4 x i32>
+  %res = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %sext)
+  ret i32 %res
+}
+
+define i64 @reduce_or_zext_i8(<8 x i8> %x) {
+; CHECK-LABEL: @reduce_or_zext_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[X:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = zext i8 [[TMP1]] to i64
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %zext = zext <8 x i8> %x to <8 x i64>
+  %res = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> %zext)
+  ret i64 %res
+}
+
 define i16 @reduce_or_sext_same(<16 x i1> %x) {
 ; CHECK-LABEL: @reduce_or_sext_same(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll
index c133852f66937..d8adfe274e8cf 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/quant_4x4.ll
@@ -62,12 +62,11 @@ define i32 @quant_4x4(ptr noundef %dct, ptr noundef %mf, ptr noundef %bias) {
 ; CHECK-NEXT:    store <8 x i16> [[PREDPHI]], ptr [[DCT]], align 2, !alias.scope [[META0]], !noalias [[META3]]
 ; CHECK-NEXT:    store <8 x i16> [[PREDPHI34]], ptr [[TMP0]], align 2, !alias.scope [[META0]], !noalias [[META3]]
 ; CHECK-NEXT:    [[BIN_RDX35:%.*]] = or <8 x i16> [[PREDPHI34]], [[PREDPHI]]
-; CHECK-NEXT:    [[BIN_RDX:%.*]] = sext <8 x i16> [[BIN_RDX35]] to <8 x i32>
-; CHECK-NEXT:    [[TMP29:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[BIN_RDX]])
+; CHECK-NEXT:    [[TMP29:%.*]] = tail call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[BIN_RDX35]])
 ; CHECK-NEXT:    br label [[FOR_COND_CLEANUP:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[OR_LCSSA:%.*]] = phi i32 [ [[TMP29]], [[VECTOR_BODY]] ], [ [[OR_15:%.*]], [[IF_END_15:%.*]] ]
-; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i32 [[OR_LCSSA]], 0
+; CHECK-NEXT:    [[OR_LCSSA_IN:%.*]] = phi i16 [ [[TMP29]], [[VECTOR_BODY]] ], [ [[OR_1551:%.*]], [[IF_END_15:%.*]] ]
+; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i16 [[OR_LCSSA_IN]], 0
 ; CHECK-NEXT:    [[LNOT_EXT:%.*]] = zext i1 [[TOBOOL]] to i32
 ; CHECK-NEXT:    ret i32 [[LNOT_EXT]]
 ; CHECK:       for.body:
@@ -514,8 +513,7 @@ define i32 @quant_4x4(ptr noundef %dct, ptr noundef %mf, ptr noundef %bias) {
 ; CHECK:       if.end.15:
 ; CHECK-NEXT:    [[STOREMERGE_15:%.*]] = phi i16 [ [[CONV28_15]], [[IF_ELSE_15]] ], [ [[CONV12_15]], [[IF_THEN_15]] ]
 ; CHECK-NEXT:    store i16 [[STOREMERGE_15]], ptr [[ARRAYIDX_15]], align 2
-; CHECK-NEXT:    [[OR_1551:%.*]] = or i16 [[OR_1450]], [[STOREMERGE_15]]
-; CHECK-NEXT:    [[OR_15]] = sext i16 [[OR_1551]] to i32
+; CHECK-NEXT:    [[OR_1551]] = or i16 [[OR_1450]], [[STOREMERGE_15]]
 ; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
 ;
 entry:



More information about the llvm-commits mailing list