[llvm] 2ed2a3a - [Transforms][Utils] Add helpers to map between Reduction IntrinsicID and Arithmetic Instruction Opcode and MinMax IntrinsicID / RecurKind

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 16 03:20:47 PST 2024


Author: Simon Pilgrim
Date: 2024-02-16T11:20:34Z
New Revision: 2ed2a3ad90934efac12cbeb01cf73afebc01d963

URL: https://github.com/llvm/llvm-project/commit/2ed2a3ad90934efac12cbeb01cf73afebc01d963
DIFF: https://github.com/llvm/llvm-project/commit/2ed2a3ad90934efac12cbeb01cf73afebc01d963.diff

LOG: [Transforms][Utils] Add helpers to map between Reduction IntrinsicID and Arithmetic Instruction Opcode and MinMax IntrinsicID / RecurKind

Noticed on #81852

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/include/llvm/Transforms/Utils/LoopUtils.h
    llvm/lib/CodeGen/ExpandReductions.cpp
    llvm/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bb17298daba03a..61f6564e8cd79b 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -48,6 +48,7 @@
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/Utils/LoopUtils.h" // Reduction ID helpers.
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
@@ -2013,50 +2014,27 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                              CostKind);
     }
     case Intrinsic::vector_reduce_add:
-      return thisT()->getArithmeticReductionCost(Instruction::Add, VecOpTy,
-                                                 std::nullopt, CostKind);
     case Intrinsic::vector_reduce_mul:
-      return thisT()->getArithmeticReductionCost(Instruction::Mul, VecOpTy,
-                                                 std::nullopt, CostKind);
     case Intrinsic::vector_reduce_and:
-      return thisT()->getArithmeticReductionCost(Instruction::And, VecOpTy,
-                                                 std::nullopt, CostKind);
     case Intrinsic::vector_reduce_or:
-      return thisT()->getArithmeticReductionCost(Instruction::Or, VecOpTy,
-                                                 std::nullopt, CostKind);
     case Intrinsic::vector_reduce_xor:
-      return thisT()->getArithmeticReductionCost(Instruction::Xor, VecOpTy,
-                                                 std::nullopt, CostKind);
+      return thisT()->getArithmeticReductionCost(
+          getArithmeticReductionInstruction(IID), VecOpTy, std::nullopt,
+          CostKind); // Integer reductions pass no FMF.
     case Intrinsic::vector_reduce_fadd:
-      return thisT()->getArithmeticReductionCost(Instruction::FAdd, VecOpTy,
-                                                 FMF, CostKind);
     case Intrinsic::vector_reduce_fmul:
-      return thisT()->getArithmeticReductionCost(Instruction::FMul, VecOpTy,
-                                                 FMF, CostKind);
+      return thisT()->getArithmeticReductionCost(
+          getArithmeticReductionInstruction(IID), VecOpTy, FMF, CostKind);
     case Intrinsic::vector_reduce_smax:
-      return thisT()->getMinMaxReductionCost(Intrinsic::smax, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_smin:
-      return thisT()->getMinMaxReductionCost(Intrinsic::smin, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_umax:
-      return thisT()->getMinMaxReductionCost(Intrinsic::umax, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_umin:
-      return thisT()->getMinMaxReductionCost(Intrinsic::umin, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_fmax:
-      return thisT()->getMinMaxReductionCost(Intrinsic::maxnum, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_fmin:
-      return thisT()->getMinMaxReductionCost(Intrinsic::minnum, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_fmaximum:
-      return thisT()->getMinMaxReductionCost(Intrinsic::maximum, VecOpTy,
-                                             ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_fminimum:
-      return thisT()->getMinMaxReductionCost(Intrinsic::minimum, VecOpTy,
-                                             ICA.getFlags(), CostKind);
+      return thisT()->getMinMaxReductionCost(getMinMaxReductionIntrinsicOp(IID),
+                                             VecOpTy, ICA.getFlags(), CostKind);
     case Intrinsic::abs: {
       // abs(X) = select(icmp(X,0),X,sub(0,X))
       Type *CondTy = RetTy->getWithNewBitWidth(1);

diff  --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 5a1385d01d8e44..187ace3a0cbedf 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -357,9 +357,22 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
                         SinkAndHoistLICMFlags &LICMFlags,
                         OptimizationRemarkEmitter *ORE = nullptr);
 
+/// Returns the arithmetic instruction opcode used when expanding a reduction.
+/// Integer min/max reductions map to ICmp and fmax/fmin reductions to FCmp;
+/// \p RdxID must be one of the supported vector_reduce_* intrinsics.
+unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
+
+/// Returns the min/max intrinsic used when expanding a min/max reduction.
+/// \p RdxID must be a min/max vector_reduce_* intrinsic.
+Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID);
+
 /// Returns the min/max intrinsic used when expanding a min/max reduction.
 Intrinsic::ID getMinMaxReductionIntrinsicOp(RecurKind RK);
 
+/// Returns the recurrence kind used when expanding a min/max reduction, or
+/// RecurKind::None if \p RdxID has no recurrence-kind mapping.
+RecurKind getMinMaxReductionRecurKind(Intrinsic::ID RdxID);
+
 /// Returns the comparison predicate used when expanding a min/max reduction.
 CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
 

diff  --git a/llvm/lib/CodeGen/ExpandReductions.cpp b/llvm/lib/CodeGen/ExpandReductions.cpp
index 79b6dc9154b3fc..0b1504e51b1bbb 100644
--- a/llvm/lib/CodeGen/ExpandReductions.cpp
+++ b/llvm/lib/CodeGen/ExpandReductions.cpp
@@ -26,54 +26,6 @@ using namespace llvm;
 
 namespace {
 
-unsigned getOpcode(Intrinsic::ID ID) {
-  switch (ID) {
-  case Intrinsic::vector_reduce_fadd:
-    return Instruction::FAdd;
-  case Intrinsic::vector_reduce_fmul:
-    return Instruction::FMul;
-  case Intrinsic::vector_reduce_add:
-    return Instruction::Add;
-  case Intrinsic::vector_reduce_mul:
-    return Instruction::Mul;
-  case Intrinsic::vector_reduce_and:
-    return Instruction::And;
-  case Intrinsic::vector_reduce_or:
-    return Instruction::Or;
-  case Intrinsic::vector_reduce_xor:
-    return Instruction::Xor;
-  case Intrinsic::vector_reduce_smax:
-  case Intrinsic::vector_reduce_smin:
-  case Intrinsic::vector_reduce_umax:
-  case Intrinsic::vector_reduce_umin:
-    return Instruction::ICmp;
-  case Intrinsic::vector_reduce_fmax:
-  case Intrinsic::vector_reduce_fmin:
-    return Instruction::FCmp;
-  default:
-    llvm_unreachable("Unexpected ID");
-  }
-}
-
-RecurKind getRK(Intrinsic::ID ID) {
-  switch (ID) {
-  case Intrinsic::vector_reduce_smax:
-    return RecurKind::SMax;
-  case Intrinsic::vector_reduce_smin:
-    return RecurKind::SMin;
-  case Intrinsic::vector_reduce_umax:
-    return RecurKind::UMax;
-  case Intrinsic::vector_reduce_umin:
-    return RecurKind::UMin;
-  case Intrinsic::vector_reduce_fmax:
-    return RecurKind::FMax;
-  case Intrinsic::vector_reduce_fmin:
-    return RecurKind::FMin;
-  default:
-    return RecurKind::None;
-  }
-}
-
 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
   bool Changed = false;
   SmallVector<IntrinsicInst *, 4> Worklist;
@@ -106,7 +58,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
     FastMathFlags FMF =
         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
     Intrinsic::ID ID = II->getIntrinsicID();
-    RecurKind RK = getRK(ID);
+    RecurKind RK = getMinMaxReductionRecurKind(ID); // May be RecurKind::None.
 
     Value *Rdx = nullptr;
     IRBuilder<> Builder(II);
@@ -120,16 +72,16 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
       // and it can't be handled by generating a shuffle sequence.
       Value *Acc = II->getArgOperand(0);
       Value *Vec = II->getArgOperand(1);
+      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
       if (!FMF.allowReassoc())
-        Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
+        Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
       else {
         if (!isPowerOf2_32(
                 cast<FixedVectorType>(Vec->getType())->getNumElements()))
           continue;
-
-        Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
-        Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
-                                  Acc, Rdx, "bin.rdx");
+        Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
+        Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
+                                  "bin.rdx");
       }
       break;
     }
@@ -159,8 +111,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
         }
         break;
       }
-
-      Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
+      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+      Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
       break;
     }
     case Intrinsic::vector_reduce_add:
@@ -174,8 +126,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
       if (!isPowerOf2_32(
               cast<FixedVectorType>(Vec->getType())->getNumElements()))
         continue;
-
-      Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
+      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+      Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
       break;
     }
     case Intrinsic::vector_reduce_fmax:
@@ -187,8 +139,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
               cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
           !FMF.noNaNs())
         continue;
-
-      Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
+      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+      Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
       break;
     }
     }

diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 59485126b280ab..002bc90c9b5677 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -917,6 +917,62 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
   return true;
 }
 
+// Maps a vector_reduce_* intrinsic to the opcode used when expanding it.
+// Integer min/max reductions yield ICmp and fmax/fmin yield FCmp; any other
+// ID (including fmaximum/fminimum) is treated as unreachable.
+unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
+  switch (RdxID) {
+  case Intrinsic::vector_reduce_fadd:
+    return Instruction::FAdd;
+  case Intrinsic::vector_reduce_fmul:
+    return Instruction::FMul;
+  case Intrinsic::vector_reduce_add:
+    return Instruction::Add;
+  case Intrinsic::vector_reduce_mul:
+    return Instruction::Mul;
+  case Intrinsic::vector_reduce_and:
+    return Instruction::And;
+  case Intrinsic::vector_reduce_or:
+    return Instruction::Or;
+  case Intrinsic::vector_reduce_xor:
+    return Instruction::Xor;
+  case Intrinsic::vector_reduce_smax:
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_umax:
+  case Intrinsic::vector_reduce_umin:
+    return Instruction::ICmp;
+  case Intrinsic::vector_reduce_fmax:
+  case Intrinsic::vector_reduce_fmin:
+    return Instruction::FCmp;
+  default:
+    llvm_unreachable("Unexpected ID");
+  }
+}
+
+// Maps a min/max vector_reduce_* intrinsic to the scalar min/max intrinsic.
+Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
+  switch (RdxID) {
+  default:
+    llvm_unreachable("Unexpected min/max reduction intrinsic");
+  case Intrinsic::vector_reduce_umin:
+    return Intrinsic::umin;
+  case Intrinsic::vector_reduce_umax:
+    return Intrinsic::umax;
+  case Intrinsic::vector_reduce_smin:
+    return Intrinsic::smin;
+  case Intrinsic::vector_reduce_smax:
+    return Intrinsic::smax;
+  case Intrinsic::vector_reduce_fmin:
+    return Intrinsic::minnum;
+  case Intrinsic::vector_reduce_fmax:
+    return Intrinsic::maxnum;
+  case Intrinsic::vector_reduce_fminimum:
+    return Intrinsic::minimum;
+  case Intrinsic::vector_reduce_fmaximum:
+    return Intrinsic::maximum;
+  }
+}
+
 Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
   switch (RK) {
   default:
@@ -940,6 +996,27 @@ Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
   }
 }
 
+// Maps a min/max vector_reduce_* intrinsic to its RecurKind. Returns
+// RecurKind::None for any other ID (fmaximum/fminimum are not mapped here).
+RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) {
+  switch (RdxID) {
+  case Intrinsic::vector_reduce_smax:
+    return RecurKind::SMax;
+  case Intrinsic::vector_reduce_smin:
+    return RecurKind::SMin;
+  case Intrinsic::vector_reduce_umax:
+    return RecurKind::UMax;
+  case Intrinsic::vector_reduce_umin:
+    return RecurKind::UMin;
+  case Intrinsic::vector_reduce_fmax:
+    return RecurKind::FMax;
+  case Intrinsic::vector_reduce_fmin:
+    return RecurKind::FMin;
+  default:
+    return RecurKind::None;
+  }
+}
+
 CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
   switch (RK) {
   default:


        


More information about the llvm-commits mailing list