[llvm] [instcombine] Extend logical reduction canonicalization to scalable vectors (PR #99366)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 11:25:10 PDT 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/99366

These transformations do not depend on the vector type having a fixed size, so enable them for scalable vectors too. Unlike the fixed-vector case, for scalable vectors this is purely a canonicalization: the bitcast-based lowering used for and/or/add reductions is not legal on a scalable vector type.

>From edc125082416eba8330cdf7bd661515cc9033d4e Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 17 Jul 2024 11:16:16 -0700
Subject: [PATCH] [instcombine] Extend logical reduction canonicalization to
 scalable vectors

These transformations do not depend on the vector type having a fixed
size, so enable them for scalable vectors too. Unlike the fixed-vector
case, for scalable vectors this is purely a canonicalization: the
bitcast-based lowering used for and/or/add reductions is not legal on a
scalable vector type.
---
 .../Transforms/InstCombine/InstCombineCalls.cpp  | 16 ++++++++--------
 .../InstCombine/vector-logical-reductions.ll     | 14 +++++++-------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 467b291f9a4c3..809be499ee0f9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3430,8 +3430,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
 
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
-        if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
-          if (FTy->getElementType() == Builder.getInt1Ty()) {
+        if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
+          if (VTy->getElementType() == Builder.getInt1Ty()) {
             Value *Res = Builder.CreateAddReduce(Vect);
             if (Arg != Vect)
               Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
@@ -3460,8 +3460,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
 
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
-        if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
-          if (FTy->getElementType() == Builder.getInt1Ty()) {
+        if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
+          if (VTy->getElementType() == Builder.getInt1Ty()) {
             Value *Res = Builder.CreateAndReduce(Vect);
             if (Res->getType() != II->getType())
               Res = Builder.CreateZExt(Res, II->getType());
@@ -3491,8 +3491,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
 
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
-        if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
-          if (FTy->getElementType() == Builder.getInt1Ty()) {
+        if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
+          if (VTy->getElementType() == Builder.getInt1Ty()) {
             Value *Res = IID == Intrinsic::vector_reduce_umin
                              ? Builder.CreateAndReduce(Vect)
                              : Builder.CreateOrReduce(Vect);
@@ -3533,8 +3533,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
 
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
-        if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
-          if (FTy->getElementType() == Builder.getInt1Ty()) {
+        if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
+          if (VTy->getElementType() == Builder.getInt1Ty()) {
             Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd;
             if (Arg != Vect)
               ExtOpc = cast<CastInst>(Arg)->getOpcode();
diff --git a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
index 74f4ed01085f8..52e6a0b009978 100644
--- a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
@@ -51,7 +51,7 @@ define i1 @reduction_logical_mul(<2 x i1> %x) {
 
 define i1 @reduction_logical_mul_nxv2i1(<vscale x 2 x i1> %x) {
 ; CHECK-LABEL: @reduction_logical_mul_nxv2i1(
-; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.mul.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %r = call i1 @llvm.vector.reduce.mul.nxv2i1(<vscale x 2 x i1> %x)
@@ -71,7 +71,7 @@ define i1 @reduction_logical_xor(<2 x i1> %x) {
 
 define i1 @reduction_logical_xor_nxv2i1(<vscale x 2 x i1> %x) {
 ; CHECK-LABEL: @reduction_logical_xor_nxv2i1(
-; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.add.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %r = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> %x)
@@ -90,7 +90,7 @@ define i1 @reduction_logical_smin(<2 x i1> %x) {
 
 define i1 @reduction_logical_smin_nxv2i1(<vscale x 2 x i1> %x) {
 ; CHECK-LABEL: @reduction_logical_smin_nxv2i1(
-; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.smin.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %r = call i1 @llvm.vector.reduce.smin.nxv2i1(<vscale x 2 x i1> %x)
@@ -109,7 +109,7 @@ define i1 @reduction_logical_smax(<2 x i1> %x) {
 
 define i1 @reduction_logical_smax_nxv2i1(<vscale x 2 x i1> %x) {
 ; CHECK-LABEL: @reduction_logical_smax_nxv2i1(
-; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.smax.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %r = call i1 @llvm.vector.reduce.smax.nxv2i1(<vscale x 2 x i1> %x)
@@ -128,7 +128,7 @@ define i1 @reduction_logical_umin(<2 x i1> %x) {
 
 define i1 @reduction_logical_umin_nxv2i1(<vscale x 2 x i1> %x) {
 ; CHECK-LABEL: @reduction_logical_umin_nxv2i1(
-; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.umin.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %r = call i1 @llvm.vector.reduce.umin.nxv2i1(<vscale x 2 x i1> %x)
@@ -147,7 +147,7 @@ define i1 @reduction_logical_umax(<2 x i1> %x) {
 
 define i1 @reduction_logical_umax_nxv2i1(<vscale x 2 x i1> %x) {
 ; CHECK-LABEL: @reduction_logical_umax_nxv2i1(
-; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.umax.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %r = call i1 @llvm.vector.reduce.umax.nxv2i1(<vscale x 2 x i1> %x)
@@ -199,7 +199,7 @@ define i1 @reduction_logical_and_reverse_v2i1(<2 x i1> %p) {
 
 define i1 @reduction_logical_xor_reverse_nxv2i1(<vscale x 2 x i1> %p) {
 ; CHECK-LABEL: @reduction_logical_xor_reverse_nxv2i1(
-; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.add.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
 ; CHECK-NEXT:    ret i1 [[RED]]
 ;
   %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)



More information about the llvm-commits mailing list