[llvm] fd66195 - [VPlan] Manage compare predicates in VPRecipeWithIRFlags.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 2 13:45:52 PDT 2023


Author: Florian Hahn
Date: 2023-09-02T21:45:24+01:00
New Revision: fd6619577790f049ce72fa043351028d8c132d58

URL: https://github.com/llvm/llvm-project/commit/fd6619577790f049ce72fa043351028d8c132d58
DIFF: https://github.com/llvm/llvm-project/commit/fd6619577790f049ce72fa043351028d8c132d58.diff

LOG: [VPlan] Manage compare predicates in VPRecipeWithIRFlags.

Extend VPRecipeWithIRFlags to also manage predicates for compares. This
allows removing the custom ICmpULE opcode from VPInstruction, which was a
workaround for missing proper predicate handling.

This simplifies the code a bit while also allowing compares with any
predicates. It also fixes a case where the compare predicate wasn't
printed properly for VPReplicateRecipes.

Discussed/split off from D150398.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D158992

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
    llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index a0d496babad3cc..2df8fea85454ea 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -45,13 +45,17 @@ class VPBuilder {
   VPBasicBlock *BB = nullptr;
   VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator();
 
+  /// Insert \p VPI in BB at InsertPt if BB is set.
+  VPInstruction *tryInsertInstruction(VPInstruction *VPI) {
+    if (BB)
+      BB->insert(VPI, InsertPt);
+    return VPI;
+  }
+
   VPInstruction *createInstruction(unsigned Opcode,
                                    ArrayRef<VPValue *> Operands, DebugLoc DL,
                                    const Twine &Name = "") {
-    VPInstruction *Instr = new VPInstruction(Opcode, Operands, DL, Name);
-    if (BB)
-      BB->insert(Instr, InsertPt);
-    return Instr;
+    return tryInsertInstruction(new VPInstruction(Opcode, Operands, DL, Name));
   }
 
   VPInstruction *createInstruction(unsigned Opcode,
@@ -152,6 +156,12 @@ class VPBuilder {
                         Name);
   }
 
+  /// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A
+  /// and \p B.
+  /// TODO: add createFCmp when needed.
+  VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+                      DebugLoc DL = {}, const Twine &Name = "");
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 750f8ff22a22ab..03a5b08085b19c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7372,6 +7372,14 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
   }
 }
 
+VPValue *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+                               DebugLoc DL, const Twine &Name) {
+  assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
+         Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
+  return tryInsertInstruction(
+      new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+}
+
 // TODO: we could return a pair of values that specify the max VF and
 // min VF, to be used in `buildVPlans(MinVF, MaxVF)` instead of
 // `buildVPlans(VF, VF)`. We cannot do it because VPLAN at the moment
@@ -8079,7 +8087,7 @@ void VPRecipeBuilder::createHeaderMask(VPlan &Plan) {
                                      nullptr, "active.lane.mask");
   } else {
     VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
-    BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
+    BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
   }
   BlockMaskCache[Header] = BlockMask;
 }

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 0d45882adb3586..8a497e82e1181b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -814,6 +814,7 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
 /// Class to record LLVM IR flag for a recipe along with it.
 class VPRecipeWithIRFlags : public VPRecipeBase {
   enum class OperationType : unsigned char {
+    Cmp,
     OverflowingBinOp,
     PossiblyExactOp,
     GEPOp,
@@ -851,11 +852,12 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
   OperationType OpType;
 
   union {
+    CmpInst::Predicate CmpPredicate;
     WrapFlagsTy WrapFlags;
     ExactFlagsTy ExactFlags;
     GEPFlagsTy GEPFlags;
     FastMathFlagsTy FMFs;
-    unsigned char AllFlags;
+    unsigned AllFlags;
   };
 
 public:
@@ -869,7 +871,10 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I)
       : VPRecipeWithIRFlags(SC, Operands) {
-    if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
+    if (auto *Op = dyn_cast<CmpInst>(&I)) {
+      OpType = OperationType::Cmp;
+      CmpPredicate = Op->getPredicate();
+    } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
       OpType = OperationType::OverflowingBinOp;
       WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
     } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
@@ -884,6 +889,12 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
     }
   }
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      CmpInst::Predicate Pred)
+      : VPRecipeBase(SC, Operands), OpType(OperationType::Cmp),
+        CmpPredicate(Pred) {}
+
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
                       WrapFlagsTy WrapFlags)
@@ -922,6 +933,7 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
       FMFs.NoNaNs = false;
       FMFs.NoInfs = false;
       break;
+    case OperationType::Cmp:
     case OperationType::Other:
       break;
     }
@@ -949,11 +961,18 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
       I->setHasAllowContract(FMFs.AllowContract);
       I->setHasApproxFunc(FMFs.ApproxFunc);
       break;
+    case OperationType::Cmp:
     case OperationType::Other:
       break;
     }
   }
 
+  CmpInst::Predicate getPredicate() const {
+    assert(OpType == OperationType::Cmp &&
+           "recipe doesn't have a compare predicate");
+    return CmpPredicate;
+  }
+
   bool isInBounds() const {
     assert(OpType == OperationType::GEPOp &&
            "recipe doesn't have inbounds flag");
@@ -996,7 +1015,6 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
         Instruction::OtherOpsEnd + 1, // Combines the incoming and previous
                                       // values of a first-order recurrence.
     Not,
-    ICmpULE,
     SLPLoad,
     SLPStore,
     ActiveLaneMask,
@@ -1042,6 +1060,9 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
                 DebugLoc DL = {}, const Twine &Name = "")
       : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
 
+  VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
+                VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
+
   VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
                 WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
       : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags),

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 7964d603b168a3..9302b9974b183b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -116,8 +116,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
     return false;
   case VPInstructionSC:
     switch (cast<VPInstruction>(this)->getOpcode()) {
+    case Instruction::ICmp:
     case VPInstruction::Not:
-    case VPInstruction::ICmpULE:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrement:
     case VPInstruction::CanonicalIVIncrementForPart:
@@ -246,6 +246,16 @@ FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
   return Res;
 }
 
+VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
+                             VPValue *A, VPValue *B, DebugLoc DL,
+                             const Twine &Name)
+    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
+                          Pred),
+      VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {
+  assert(Opcode == Instruction::ICmp &&
+         "only ICmp predicates supported at the moment");
+}
+
 VPInstruction::VPInstruction(unsigned Opcode,
                              std::initializer_list<VPValue *> Operands,
                              FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
@@ -271,10 +281,10 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
     Value *A = State.get(getOperand(0), Part);
     return Builder.CreateNot(A, Name);
   }
-  case VPInstruction::ICmpULE: {
+  case Instruction::ICmp: {
     Value *A = State.get(getOperand(0), Part);
     Value *B = State.get(getOperand(1), Part);
-    return Builder.CreateICmpULE(A, B, Name);
+    return Builder.CreateCmp(getPredicate(), A, B, Name);
   }
   case Instruction::Select: {
     Value *Cond = State.get(getOperand(0), Part);
@@ -444,9 +454,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::Not:
     O << "not";
     break;
-  case VPInstruction::ICmpULE:
-    O << "icmp ule";
-    break;
   case VPInstruction::SLPLoad:
     O << "combined load";
     break;
@@ -618,6 +625,9 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
   switch (OpType) {
+  case OperationType::Cmp:
+    O << " " << CmpInst::getPredicateName(getPredicate());
+    break;
   case OperationType::PossiblyExactOp:
     if (ExactFlags.IsExact)
       O << " exact";
@@ -741,8 +751,6 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
   const Instruction *UI = getUnderlyingInstr();
   O << " = " << UI->getOpcodeName();
   printFlags(O);
-  if (auto *Cmp = dyn_cast<CmpInst>(UI))
-    O << Cmp->getPredicate() << " ";
   printOperands(O, SlotTracker);
 }
 #endif

diff  --git a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
index 4d8190319b62e2..06061465939002 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
@@ -906,7 +906,7 @@ define void @update_multiple_users(ptr noalias %src, ptr noalias %dst, i1 %c) {
 ; CHECK-NEXT:   pred.store.if:
 ; CHECK-NEXT:     REPLICATE ir<%l1> = load ir<%src>
 ; CHECK-NEXT:     REPLICATE ir<%l2> = trunc ir<%l1>
-; CHECK-NEXT:     REPLICATE ir<%cmp> = icmp ir<%l1>, ir<0>
+; CHECK-NEXT:     REPLICATE ir<%cmp> = icmp eq ir<%l1>, ir<0>
 ; CHECK-NEXT:     REPLICATE ir<%sel> = select ir<%cmp>, ir<5>, ir<%l2>
 ; CHECK-NEXT:     REPLICATE store ir<%sel>, ir<%dst>
 ; CHECK-NEXT:   Successor(s): pred.store.continue


        


More information about the llvm-commits mailing list