[llvm] [vectorcombine] Pull sext/zext through reduce.or/and/xor (PR #99548)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 18 12:04:12 PDT 2024


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/99548

>From 044926bef56b6999082f09295b295a0a2fb7f8fa Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 18 Jul 2024 10:49:45 -0700
Subject: [PATCH 1/2] [vectorcombine] Pull sext/zext through reduce.or/and/xor

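This extends the existing foldTruncFromReductions transform (renamed to
foldCastFromReductions) to handle sext and zext as well.  This is only legal
for the bitwise reductions (and/or/xor), not for the arithmetic ones
(add/mul): the high bits introduced by an extension (zeros for zext, copies
of the sign bit for sext) combine bitwise exactly like the bits they copy,
whereas add and mul carry into those high bits.  The same cost comparison
used for trunc decides whether the transform is profitable.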
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 40 ++++++++++++-------
 .../VectorCombine/RISCV/vecreduce-of-cast.ll  | 30 ++++++++++----
 2 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 3a49f95d3f117..de60d80aeffa1 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -117,7 +117,7 @@ class VectorCombine {
   bool foldShuffleOfShuffles(Instruction &I);
   bool foldShuffleToIdentity(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
-  bool foldTruncFromReductions(Instruction &I);
+  bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
 
   void replaceValue(Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
 
 /// Determine if its more efficient to fold:
 ///   reduce(trunc(x)) -> trunc(reduce(x)).
-bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+///   reduce(sext(x))  -> sext(reduce(x)).
+///   reduce(zext(x))  -> zext(reduce(x)).
+bool VectorCombine::foldCastFromReductions(Instruction &I) {
   auto *II = dyn_cast<IntrinsicInst>(&I);
   if (!II)
     return false;
 
+  bool TruncOnly = false;
   Intrinsic::ID IID = II->getIntrinsicID();
   switch (IID) {
   case Intrinsic::vector_reduce_add:
   case Intrinsic::vector_reduce_mul:
+    TruncOnly = true;
+    break;
   case Intrinsic::vector_reduce_and:
   case Intrinsic::vector_reduce_or:
   case Intrinsic::vector_reduce_xor:
@@ -2133,25 +2138,32 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
   unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
   Value *ReductionSrc = I.getOperand(0);
 
-  Value *TruncSrc;
-  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+  Value *Src;
+  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
+      (TruncOnly ||
+       !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
     return false;
 
-  auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+  // Note: Only trunc has a constexpr, neither sext or zext do.
+  auto CastOpc = Instruction::Trunc;
+  if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
+      CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
+
+  auto *SrcTy = cast<VectorType>(Src->getType());
   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
   Type *ResultTy = I.getType();
 
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   InstructionCost OldCost = TTI.getArithmeticReductionCost(
       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
-  if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
+  if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
     OldCost +=
-        TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
-                             TTI::CastContextHint::None, CostKind, Trunc);
+        TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+                             TTI::CastContextHint::None, CostKind, Cast);
   InstructionCost NewCost =
-      TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+      TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
                                      CostKind) +
-      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+      TTI.getCastInstrCost(CastOpc, ResultTy,
                            ReductionSrcTy->getScalarType(),
                            TTI::CastContextHint::None, CostKind);
 
@@ -2159,9 +2171,9 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
     return false;
 
   Value *NewReduction = Builder.CreateIntrinsic(
-      TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
-  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
-  replaceValue(I, *NewTruncation);
+      SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+  Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
+  replaceValue(I, *NewCast);
   return true;
 }
 
@@ -2559,7 +2571,7 @@ bool VectorCombine::run() {
       switch (Opcode) {
       case Instruction::Call:
         MadeChange |= foldShuffleFromReductions(I);
-        MadeChange |= foldTruncFromReductions(I);
+        MadeChange |= foldCastFromReductions(I);
         break;
       case Instruction::ICmp:
       case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
index 9b1aa19f85c21..f04bcc90e5c35 100644
--- a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
@@ -74,8 +74,8 @@ define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0)  {
 
 define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0)  {
 ; CHECK-LABEL: @reduce_or_sext_v8i8_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = sext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = sext i8 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = sext <8 x i8> %a0 to <8 x i32>
@@ -85,8 +85,8 @@ define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0)  {
 
 define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0)  {
 ; CHECK-LABEL: @reduce_or_sext_v8i16_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = sext i16 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = sext <8 x i16> %a0 to <8 x i32>
@@ -96,8 +96,8 @@ define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0)  {
 
 define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
 ; CHECK-LABEL: @reduce_or_zext_v8i8_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = zext i8 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = zext <8 x i8> %a0 to <8 x i32>
@@ -107,8 +107,8 @@ define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
 
 define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0)  {
 ; CHECK-LABEL: @reduce_or_zext_v8i16_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = zext i16 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = zext <8 x i16> %a0 to <8 x i32>
@@ -116,6 +116,20 @@ define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0)  {
   ret i32 %red
 }
 
+; Negative case - zext does not commute with add, so reducing in i8 is illegal.
+; TODO: We could narrow to i16 instead (8 x 255 still fits in i16).
+define i32 @reduce_add_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
+; CHECK-LABEL: @reduce_add_zext_v8i8_to_v8i32(
+; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %tr = zext <8 x i8> %a0 to <8 x i32>
+  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+  ret i32 %red
+}
+
+
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
 declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)

>From 8b372b677210959e4f9e5759519b759073101646 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 18 Jul 2024 12:03:27 -0700
Subject: [PATCH 2/2] clang-format, and fix the scalar cast cost

Besides the clang-format changes:
- Cost the new scalar cast from SrcTy's element type.  It was costed from
  ReductionSrcTy's element type, which is the same type as ResultTy, so the
  cost model was asked for a cast between identical types instead of the
  SrcTy -> ResultTy cast that is actually emitted.
- Use CastInst::getOpcode() directly instead of a C-style cast.
- Fix the grammar of the constexpr note.
---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index de60d80aeffa1..854bde1491a45 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2140,14 +2140,13 @@ bool VectorCombine::foldCastFromReductions(Instruction &I) {

   Value *Src;
   if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
-      (TruncOnly ||
-       !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
+      (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
     return false;

-  // Note: Only trunc has a constexpr, neither sext or zext do.
+  // Note: Only trunc has a constexpr form; neither sext nor zext does.
   auto CastOpc = Instruction::Trunc;
   if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
-      CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
+    CastOpc = Cast->getOpcode();

   auto *SrcTy = cast<VectorType>(Src->getType());
   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
@@ -2157,21 +2156,20 @@ bool VectorCombine::foldCastFromReductions(Instruction &I) {
   InstructionCost OldCost = TTI.getArithmeticReductionCost(
       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
   if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
-    OldCost +=
-        TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
-                             TTI::CastContextHint::None, CostKind, Cast);
+    OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+                                    TTI::CastContextHint::None, CostKind, Cast);
+  // New form: reduce over SrcTy, then cast that scalar element to ResultTy.
   InstructionCost NewCost =
       TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
                                      CostKind) +
-      TTI.getCastInstrCost(CastOpc, ResultTy,
-                           ReductionSrcTy->getScalarType(),
+      TTI.getCastInstrCost(CastOpc, ResultTy, SrcTy->getScalarType(),
                            TTI::CastContextHint::None, CostKind);

   if (OldCost <= NewCost || !NewCost.isValid())
     return false;

-  Value *NewReduction = Builder.CreateIntrinsic(
-      SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+  Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
+                                                II->getIntrinsicID(), {Src});
   Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
   replaceValue(I, *NewCast);
   return true;



More information about the llvm-commits mailing list