[llvm] ded35c0 - [vectorcombine] Pull sext/zext through reduce.or/and/xor (#99548)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 18 13:56:43 PDT 2024


Author: Philip Reames
Date: 2024-07-18T13:56:40-07:00
New Revision: ded35c0c3ad371287e80872d6bd104ce3f7d2864

URL: https://github.com/llvm/llvm-project/commit/ded35c0c3ad371287e80872d6bd104ce3f7d2864
DIFF: https://github.com/llvm/llvm-project/commit/ded35c0c3ad371287e80872d6bd104ce3f7d2864.diff

LOG: [vectorcombine] Pull sext/zext through reduce.or/and/xor (#99548)

This extends the existing foldTruncFromReductions transform to handle
sext and zext as well. Pulling the extension through the reduction is
only legal for the bitwise reductions (and/or/xor), not the arithmetic
ones (add, mul). The same costing decision drives whether we perform
the transform.
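
For illustration, the rewrite looks roughly like the following at the
IR level (a sketch modeled on the added tests; the value names are
invented):

    %ext = zext <8 x i8> %a0 to <8 x i32>
    %red = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %ext)

becomes

    %red.n = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %a0)
    %red = zext i8 %red.n to i32

This is sound for and/or/xor because they combine each bit position
independently: a column of extension bits (all zeros for zext, copies
of each lane's sign bit for sext) reduces to exactly the extension
bits of the narrow result. add and mul carry between bit positions,
so the equivalence breaks; e.g. for two i8 lanes holding 255 and 1,
extending first gives an i32 add-reduction of 256, while reducing
first wraps to 0 in i8 and then extends to 0.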

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp
    llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 3a49f95d3f117..444598520c981 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -117,7 +117,7 @@ class VectorCombine {
   bool foldShuffleOfShuffles(Instruction &I);
   bool foldShuffleToIdentity(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
-  bool foldTruncFromReductions(Instruction &I);
+  bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
 
   void replaceValue(Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
 
 /// Determine if it's more efficient to fold:
 ///   reduce(trunc(x)) -> trunc(reduce(x)).
-bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+///   reduce(sext(x))  -> sext(reduce(x)).
+///   reduce(zext(x))  -> zext(reduce(x)).
+bool VectorCombine::foldCastFromReductions(Instruction &I) {
   auto *II = dyn_cast<IntrinsicInst>(&I);
   if (!II)
     return false;
 
+  bool TruncOnly = false;
   Intrinsic::ID IID = II->getIntrinsicID();
   switch (IID) {
   case Intrinsic::vector_reduce_add:
   case Intrinsic::vector_reduce_mul:
+    TruncOnly = true;
+    break;
   case Intrinsic::vector_reduce_and:
   case Intrinsic::vector_reduce_or:
   case Intrinsic::vector_reduce_xor:
@@ -2133,35 +2138,37 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
   unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
   Value *ReductionSrc = I.getOperand(0);
 
-  Value *TruncSrc;
-  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+  Value *Src;
+  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
+      (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
     return false;
 
-  auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+  auto CastOpc =
+      (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
+
+  auto *SrcTy = cast<VectorType>(Src->getType());
   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
   Type *ResultTy = I.getType();
 
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   InstructionCost OldCost = TTI.getArithmeticReductionCost(
       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
-  if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
-    OldCost +=
-        TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
-                             TTI::CastContextHint::None, CostKind, Trunc);
+  OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+                                  TTI::CastContextHint::None, CostKind,
+                                  cast<CastInst>(ReductionSrc));
   InstructionCost NewCost =
-      TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+      TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
                                      CostKind) +
-      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
-                           ReductionSrcTy->getScalarType(),
+      TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
                            TTI::CastContextHint::None, CostKind);
 
   if (OldCost <= NewCost || !NewCost.isValid())
     return false;
 
-  Value *NewReduction = Builder.CreateIntrinsic(
-      TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
-  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
-  replaceValue(I, *NewTruncation);
+  Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
+                                                II->getIntrinsicID(), {Src});
+  Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
+  replaceValue(I, *NewCast);
   return true;
 }
 
@@ -2559,7 +2566,7 @@ bool VectorCombine::run() {
       switch (Opcode) {
       case Instruction::Call:
         MadeChange |= foldShuffleFromReductions(I);
-        MadeChange |= foldTruncFromReductions(I);
+        MadeChange |= foldCastFromReductions(I);
         break;
       case Instruction::ICmp:
       case Instruction::FCmp:

diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
index 9b1aa19f85c21..f04bcc90e5c35 100644
--- a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
@@ -74,8 +74,8 @@ define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0)  {
 
 define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0)  {
 ; CHECK-LABEL: @reduce_or_sext_v8i8_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = sext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = sext i8 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = sext <8 x i8> %a0 to <8 x i32>
@@ -85,8 +85,8 @@ define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0)  {
 
 define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0)  {
 ; CHECK-LABEL: @reduce_or_sext_v8i16_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = sext i16 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = sext <8 x i16> %a0 to <8 x i32>
@@ -96,8 +96,8 @@ define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0)  {
 
 define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
 ; CHECK-LABEL: @reduce_or_zext_v8i8_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = zext i8 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = zext <8 x i8> %a0 to <8 x i32>
@@ -107,8 +107,8 @@ define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
 
 define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0)  {
 ; CHECK-LABEL: @reduce_or_zext_v8i16_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = zext i16 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = zext <8 x i16> %a0 to <8 x i32>
@@ -116,6 +116,20 @@ define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0)  {
   ret i32 %red
 }
 
+; Negative case - narrowing the reduce (to i8) is illegal.
+; TODO: We could narrow to i16 instead.
+define i32 @reduce_add_trunc_v8i8_to_v8i32(<8 x i8> %a0)  {
+; CHECK-LABEL: @reduce_add_trunc_v8i8_to_v8i32(
+; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %tr = zext <8 x i8> %a0 to <8 x i32>
+  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+  ret i32 %red
+}
+
+
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
 declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)
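
To see the transform fire on the updated tests, an invocation along
these lines should work (a sketch; the authoritative RUN line sits at
the top of the test file, outside the hunks quoted above):

    opt -passes=vector-combine -S -mtriple=riscv64 -mattr=+v \
        llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll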


        

