[llvm] [VPlan] Use VPInstructionWithType for uniform casts. (PR #140623)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 14:07:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-risc-v

Author: Florian Hahn (fhahn)

<details>
<summary>Changes</summary>

Use VPInstructionWithType instead of VPReplicateRecipe for uniform
casts. This is a first step towards breaking up VPReplicateRecipe. Using
the general VPInstructionWithType has the additional benefit that a
number of simplifications can now be applied to these casts directly.

This patch also adds a new IsSingleScalar field to VPInstruction to
encode the fact that a recipe is known to always produce a single scalar.

Depends on https://github.com/llvm/llvm-project/pull/140621 (included in this PR)

---

Patch is 60.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140623.diff


18 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+14-18) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+8-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+1-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+86-126) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+63-44) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+26-17) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp (+4-4) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve2-histcnt-vplan.ll (+1-1) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/riscv-vector-reverse.ll (+4-4) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-call-intrinsics.ll (+9-9) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-cast-intrinsics.ll (+10-10) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics-fixed-order-recurrence.ll (+2-2) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics-reduction.ll (+2-2) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics.ll (+1-1) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-select-intrinsics.ll (+1-1) 
- (modified) llvm/test/Transforms/LoopVectorize/X86/constant-fold.ll (+1-2) 
- (modified) llvm/test/Transforms/LoopVectorize/as_cast.ll (+6-5) 
- (modified) llvm/test/Transforms/LoopVectorize/interleave-and-scalarize-only.ll (+6-8) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index bae53c600c18c..c751f053cb65a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -164,25 +164,19 @@ class VPBuilder {
                               DebugLoc DL, const Twine &Name = "") {
     return createInstruction(Opcode, Operands, DL, Name);
   }
-  VPInstruction *createNaryOp(unsigned Opcode,
-                              std::initializer_list<VPValue *> Operands,
-                              std::optional<FastMathFlags> FMFs = {},
-                              DebugLoc DL = {}, const Twine &Name = "") {
-    if (FMFs)
-      return tryInsertInstruction(
-          new VPInstruction(Opcode, Operands, *FMFs, DL, Name));
-    return createInstruction(Opcode, Operands, DL, Name);
+  VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                              const VPIRFlags &Flags, DebugLoc DL = {},
+                              const Twine &Name = "") {
+    return tryInsertInstruction(
+        new VPInstruction(Opcode, Operands, Flags, DL, Name));
   }
+
   VPInstruction *createNaryOp(unsigned Opcode,
                               std::initializer_list<VPValue *> Operands,
-                              Type *ResultTy,
-                              std::optional<FastMathFlags> FMFs = {},
+                              Type *ResultTy, const VPIRFlags &Flags = {},
                               DebugLoc DL = {}, const Twine &Name = "") {
-    if (FMFs)
-      return tryInsertInstruction(new VPInstructionWithType(
-          Opcode, Operands, ResultTy, *FMFs, DL, Name));
     return tryInsertInstruction(
-        new VPInstructionWithType(Opcode, Operands, ResultTy, DL, Name));
+        new VPInstructionWithType(Opcode, Operands, ResultTy, Flags, DL, Name));
   }
 
   VPInstruction *createOverflowingOp(unsigned Opcode,
@@ -236,18 +230,20 @@ class VPBuilder {
     assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
            Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
     return tryInsertInstruction(
-        new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+        new VPInstruction(Instruction::ICmp, {A, B}, Pred, DL, Name));
   }
 
   VPInstruction *createPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                               const Twine &Name = "") {
     return tryInsertInstruction(
-        new VPInstruction(Ptr, Offset, GEPNoWrapFlags::none(), DL, Name));
+        new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
+                          GEPNoWrapFlags::none(), DL, Name));
   }
   VPValue *createInBoundsPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                                 const Twine &Name = "") {
     return tryInsertInstruction(
-        new VPInstruction(Ptr, Offset, GEPNoWrapFlags::inBounds(), DL, Name));
+        new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
+                          GEPNoWrapFlags::inBounds(), DL, Name));
   }
 
   VPInstruction *createScalarPhi(ArrayRef<VPValue *> IncomingValues,
@@ -269,7 +265,7 @@ class VPBuilder {
   VPInstruction *createScalarCast(Instruction::CastOps Opcode, VPValue *Op,
                                   Type *ResultTy, DebugLoc DL) {
     return tryInsertInstruction(
-        new VPInstructionWithType(Opcode, Op, ResultTy, DL));
+        new VPInstructionWithType(Opcode, Op, ResultTy, {}, DL));
   }
 
   VPWidenCastRecipe *createWidenCast(Instruction::CastOps Opcode, VPValue *Op,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b2d7c44761f6d..58618c50573d3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8576,7 +8576,7 @@ VPRecipeBuilder::tryToWidenHistogram(const HistogramInfo *HI,
   return new VPHistogramRecipe(Opcode, HGramOps, HI->Store->getDebugLoc());
 }
 
-VPReplicateRecipe *
+VPSingleDefRecipe *
 VPRecipeBuilder::handleReplication(Instruction *I, ArrayRef<VPValue *> Operands,
                                    VFRange &Range) {
   bool IsUniform = LoopVectorizationPlanner::getDecisionAndClampRange(
@@ -8634,6 +8634,13 @@ VPRecipeBuilder::handleReplication(Instruction *I, ArrayRef<VPValue *> Operands,
   assert((Range.Start.isScalar() || !IsUniform || !IsPredicated ||
           (Range.Start.isScalable() && isa<IntrinsicInst>(I))) &&
          "Should not predicate a uniform recipe");
+  if (IsUniform && Instruction::isCast(I->getOpcode())) {
+    auto *Recipe = new VPInstructionWithType(I->getOpcode(), Operands,
+                                             I->getType(), VPIRFlags(*I),
+                                             I->getDebugLoc(), I->getName());
+    Recipe->setUnderlyingValue(I);
+    return Recipe;
+  }
   auto *Recipe = new VPReplicateRecipe(I, Operands, IsUniform, BlockInMask,
                                        VPIRMetadata(*I, LVer));
   return Recipe;
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index ae86181487261..959cb61889c7d 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -241,7 +241,7 @@ class VPRecipeBuilder {
   /// Build a VPReplicationRecipe for \p I using \p Operands. If it is
   /// predicated, add the mask as last operand. Range.End may be decreased to
   /// ensure same recipe behavior from \p Range.Start to \p Range.End.
-  VPReplicateRecipe *handleReplication(Instruction *I,
+  VPSingleDefRecipe *handleReplication(Instruction *I,
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e634de1e17c69..9e3cec123d28a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -577,8 +577,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
 #endif
 };
 
-/// Class to record LLVM IR flag for a recipe along with it.
-class VPRecipeWithIRFlags : public VPSingleDefRecipe {
+/// Class to record LLVM IR flags.
+class VPIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
     OverflowingBinOp,
@@ -637,23 +637,10 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     unsigned AllFlags;
   };
 
-protected:
-  void transferFlags(VPRecipeWithIRFlags &Other) {
-    OpType = Other.OpType;
-    AllFlags = Other.AllFlags;
-  }
-
 public:
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL) {
-    OpType = OperationType::Other;
-    AllFlags = 0;
-  }
+  VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      Instruction &I)
-      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()) {
+  VPIRFlags(Instruction &I) {
     if (auto *Op = dyn_cast<CmpInst>(&I)) {
       OpType = OperationType::Cmp;
       CmpPredicate = Op->getPredicate();
@@ -681,63 +668,27 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     }
   }
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      CmpInst::Predicate Pred, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::Cmp),
-        CmpPredicate(Pred) {}
+  VPIRFlags(CmpInst::Predicate Pred)
+      : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      WrapFlagsTy WrapFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL),
-        OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
+  VPIRFlags(WrapFlagsTy WrapFlags)
+      : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      FastMathFlags FMFs, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp),
-        FMFs(FMFs) {}
+  VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
-        DisjointFlags(DisjointFlags) {}
+  VPIRFlags(DisjointFlagsTy DisjointFlags)
+      : OpType(OperationType::DisjointOp), DisjointFlags(DisjointFlags) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
-        NonNegFlags(NonNegFlags) {}
+  VPIRFlags(NonNegFlagsTy NonNegFlags)
+      : OpType(OperationType::NonNegOp), NonNegFlags(NonNegFlags) {}
 
-protected:
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::GEPOp),
-        GEPFlags(GEPFlags) {}
+  VPIRFlags(GEPNoWrapFlags GEPFlags)
+      : OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
 
 public:
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
-           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
-
-  static inline bool classof(const VPValue *V) {
-    auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
-    return R && classof(R);
+  void transferFlags(VPIRFlags &Other) {
+    OpType = Other.OpType;
+    AllFlags = Other.AllFlags;
   }
 
   /// Drop all poison-generating flags.
@@ -851,11 +802,58 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     return DisjointFlags.IsDisjoint;
   }
 
+#if !defined(NDEBUG)
+  /// Returns true if the set flags are valid for \p Opcode.
+  bool flagsValidForOpcode(unsigned Opcode) const;
+#endif
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void printFlags(raw_ostream &O) const;
 #endif
 };
 
+class VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
+public:
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags() {}
+
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      Instruction &I)
+      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()), VPIRFlags(I) {}
+
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      const VPIRFlags &Flags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags(Flags) {}
+
+public:
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
+           R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  static inline bool classof(const VPValue *V) {
+    auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
+    return R && classof(R);
+  }
+};
+
 /// Helper to access the operand that contains the unroll part for this recipe
 /// after unrolling.
 template <unsigned PartOpIdx> class VPUnrollPartAccessor {
@@ -876,6 +874,9 @@ class VPInstruction : public VPRecipeWithIRFlags,
                       public VPUnrollPartAccessor<1> {
   friend class VPlanSlp;
 
+  /// True if the VPInstruction produces a single scalar value.
+  bool IsSingleScalar;
+
 public:
   /// VPlan opcodes, extending LLVM IR with idiomatics instructions.
   enum {
@@ -958,54 +959,21 @@ class VPInstruction : public VPRecipeWithIRFlags,
   /// value for lane \p Lane.
   Value *generatePerLane(VPTransformState &State, const VPLane &Lane);
 
-#if !defined(NDEBUG)
-  /// Return true if the VPInstruction is a floating point math operation, i.e.
-  /// has fast-math flags.
-  bool isFPMathOp() const;
-#endif
-
 public:
-  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
+  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
                 const Twine &Name = "")
       : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
         Opcode(Opcode), Name(Name.str()) {}
 
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
-
-  VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
-                VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
-        Opcode(Opcode), Name(Name.str()) {}
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
-                const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL),
-        Opcode(Opcode), Name(Name.str()) {
-    assert(Opcode == Instruction::Or && "only OR opcodes can be disjoint");
-  }
-
-  VPInstruction(VPValue *Ptr, VPValue *Offset, GEPNoWrapFlags Flags,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC,
-                            ArrayRef<VPValue *>({Ptr, Offset}), Flags, DL),
-        Opcode(VPInstruction::PtrAdd), Name(Name.str()) {}
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = "");
+  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                const VPIRFlags &Flags, DebugLoc DL = {},
+                const Twine &Name = "", bool IsSingleScalar = false);
 
   VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
 
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
-    auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name);
-    New->transferFlags(*this);
-    return New;
+    return new VPInstruction(Opcode, Operands, *this, getDebugLoc(), Name);
   }
 
   unsigned getOpcode() const { return Opcode; }
@@ -1082,13 +1050,10 @@ class VPInstructionWithType : public VPInstruction {
 
 public:
   VPInstructionWithType(unsigned Opcode, ArrayRef<VPValue *> Operands,
-                        Type *ResultTy, DebugLoc DL, const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, DL, Name), ResultTy(ResultTy) {}
-  VPInstructionWithType(unsigned Opcode,
-                        std::initializer_list<VPValue *> Operands,
-                        Type *ResultTy, FastMathFlags FMFs, DebugLoc DL = {},
+                        Type *ResultTy, const VPIRFlags &Flags, DebugLoc DL,
                         const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, FMFs, DL, Name), ResultTy(ResultTy) {}
+      : VPInstruction(Opcode, Operands, Flags, DL, Name, true),
+        ResultTy(ResultTy) {}
 
   static inline bool classof(const VPRecipeBase *R) {
     // VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1113,8 +1078,9 @@ class VPInstructionWithType : public VPInstruction {
 
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
-    auto *New = new VPInstructionWithType(
-        getOpcode(), Operands, getResultType(), getDebugLoc(), getName());
+    auto *New =
+        new VPInstructionWithType(getOpcode(), Operands, getResultType(), *this,
+                                  getDebugLoc(), getName());
     New->setUnderlyingValue(getUnderlyingValue());
     return New;
   }
@@ -1123,10 +1089,7 @@ class VPInstructionWithType : public VPInstruction {
 
   /// Return the cost of this VPInstruction.
   InstructionCost computeCost(ElementCount VF,
-                              VPCostContext &Ctx) const override {
-    // TODO: Compute accurate cost after retiring the legacy cost model.
-    return 0;
-  }
+                              VPCostContext &Ctx) const override;
 
   Type *getResultType() const { return ResultTy; }
 
@@ -1373,15 +1336,12 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
   }
 
   VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
-                    DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
-        Opcode(Opcode), ResultTy(ResultTy) {}
-
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
-                    bool IsNonNeg, DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
-                            DL),
-        Opcode(Opcode), ResultTy(ResultTy) {}
+                    const VPIRFlags &Flags = {}, DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, Flags, DL),
+        VPIRMetadata(), Opcode(Opcode), ResultTy(ResultTy) {
+    assert(flagsValidForOpcode(Opcode) &&
+           "Set flags not supported for the provided opcode");
+  }
 
   ~VPWidenCastRecipe() override = default;
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 14ed40f16683a..abdf4f80390f2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -368,7 +368,7 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
-FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
+FastMathFlags VPIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
   FastMathFlags Res;
@@ -406,23 +406,13 @@ template class VPUnrollPartAccessor<2>;
 template class VPUnrollPartAccessor<3>;
 }
 
-VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
-                             VPValue *A, VPValue *B, DebugLoc DL,
-                             const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
-                          Pred, DL),
-      Opcode(Opcode), Name(Name.str()) {
-  assert(Opcode == Instruction::ICmp &&
-         "...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/140623


More information about the llvm-commits mailing list