[llvm] 0b2f253 - [LV] Separate AnyOf recurrence from getRecurrenceIdentity [NFC]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 09:46:48 PDT 2024


Author: Philip Reames
Date: 2024-09-03T09:46:30-07:00
New Revision: 0b2f2537a5b717539b200bd7fa31cbc24679e96f

URL: https://github.com/llvm/llvm-project/commit/0b2f2537a5b717539b200bd7fa31cbc24679e96f
DIFF: https://github.com/llvm/llvm-project/commit/0b2f2537a5b717539b200bd7fa31cbc24679e96f.diff

LOG: [LV] Separate AnyOf recurrence from getRecurrenceIdentity [NFC]

These recurrence types (IAnyOf and FAnyOf) don't have a meaningful
identity, and the routine was misused to return the start value
instead.  Of the three callers of this routine, only one actually
wants this behavior, so that caller now queries the start value
directly and getRecurrenceIdentity becomes a static function.  This
is a preparatory change for removing the routine entirely and
consolidating it with other copies of the same logic.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/IVDescriptors.h
    llvm/lib/Analysis/IVDescriptors.cpp
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h
index 379b114b79cdfa..e7e6c5c01ad4db 100644
--- a/llvm/include/llvm/Analysis/IVDescriptors.h
+++ b/llvm/include/llvm/Analysis/IVDescriptors.h
@@ -156,7 +156,7 @@ class RecurrenceDescriptor {
   static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
 
   /// Returns identity corresponding to the RecurrenceKind.
-  Value *getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) const;
+  static Value *getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF);
 
   /// Returns the opcode corresponding to the RecurrenceKind.
   static unsigned getOpcode(RecurKind Kind);

diff  --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index ba3619417114c7..917c5f53e0d08d 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -1035,7 +1035,7 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
 /// This function returns the identity element (or neutral element) for
 /// the operation K.
 Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
-                                                   FastMathFlags FMF) const {
+                                                   FastMathFlags FMF) {
   switch (K) {
   case RecurKind::Xor:
   case RecurKind::Add:
@@ -1071,8 +1071,7 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
     return ConstantFP::getInfinity(Tp, true /*Negative*/);
   case RecurKind::IAnyOf:
   case RecurKind::FAnyOf:
-    return getRecurrenceStartValue();
-    break;
+    llvm_unreachable("No meaningful identity for recurrence kind");
   default:
     llvm_unreachable("Unknown recurrence kind");
   }

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c91ea2d1663c87..0d3d0febfea1ba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1829,13 +1829,18 @@ void VPReductionRecipe::execute(VPTransformState &State) {
       Value *NewCond = State.get(Cond, Part, State.VF.isScalar());
       VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
       Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
-      Value *Iden = RdxDesc.getRecurrenceIdentity(Kind, ElementTy,
-                                                  RdxDesc.getFastMathFlags());
-      if (State.VF.isVector()) {
-        Iden = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden);
-      }
 
-      Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, Iden);
+      Value *Start;
+      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind))
+        Start = RdxDesc.getRecurrenceStartValue();
+      else
+        Start = RdxDesc.getRecurrenceIdentity(Kind, ElementTy,
+                                              RdxDesc.getFastMathFlags());
+      if (State.VF.isVector())
+        Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(),
+                                                Start);
+
+      Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, Start);
       NewVecOp = Select;
     }
     Value *NewRed;


        


More information about the llvm-commits mailing list