[llvm] Prototype vectorizing structs via multiple result VPWidenCallRecipe (PR #112402)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 15 09:53:20 PDT 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/112402

Prototype for https://github.com/llvm/llvm-project/pull/109833#issuecomment-2413611812

>From 75affb4c5ea35d68a1dcada4cebed1fe5b8ef04e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 15 Oct 2024 15:50:03 +0000
Subject: [PATCH 1/3] Initial changes cherry-picked from #109833

---
 llvm/include/llvm/Analysis/VectorUtils.h      | 15 +---
 llvm/include/llvm/IR/VectorUtils.h            | 53 ++++++++++++++
 llvm/lib/Analysis/VectorUtils.cpp             | 14 ++++
 llvm/lib/IR/CMakeLists.txt                    |  1 +
 llvm/lib/IR/VFABIDemangler.cpp                | 18 +++--
 llvm/lib/IR/VectorUtils.cpp                   | 69 +++++++++++++++++++
 .../Vectorize/LoopVectorizationLegality.cpp   |  4 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    | 46 +++++++------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  2 +-
 9 files changed, 180 insertions(+), 42 deletions(-)
 create mode 100644 llvm/include/llvm/IR/VectorUtils.h
 create mode 100644 llvm/lib/IR/VectorUtils.cpp

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index e2dd4976f39065..2a419560be3030 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/VFABIDemangler.h"
+#include "llvm/IR/VectorUtils.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
 namespace llvm {
@@ -127,18 +128,8 @@ namespace Intrinsic {
 typedef unsigned ID;
 }
 
-/// A helper function for converting Scalar types to vector types. If
-/// the incoming type is void, we return void. If the EC represents a
-/// scalar, we return the scalar type.
-inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
-  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
-    return Scalar;
-  return VectorType::get(Scalar, EC);
-}
-
-inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
-  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
-}
+/// Returns true if `Ty` can be widened by the loop vectorizer.
+bool canWidenType(Type *Ty);
 
 /// Identify if the intrinsic is trivially vectorizable.
 /// This method returns true if the intrinsic's argument types are all scalars
diff --git a/llvm/include/llvm/IR/VectorUtils.h b/llvm/include/llvm/IR/VectorUtils.h
new file mode 100644
index 00000000000000..e8e838d8287c42
--- /dev/null
+++ b/llvm/include/llvm/IR/VectorUtils.h
@@ -0,0 +1,60 @@
+//===----------- VectorUtils.h -  Vector type utility functions -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_VECTORUTILS_H
+#define LLVM_IR_VECTORUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/DerivedTypes.h"
+
+namespace llvm {
+
+/// A helper function for converting Scalar types to vector types. If
+/// the incoming type is void, we return void. If the EC represents a
+/// scalar, we return the scalar type.
+inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
+  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
+    return Scalar;
+  return VectorType::get(Scalar, EC);
+}
+
+/// Convenience overload of ToVectorTy taking a fixed-width VF.
+inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
+  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
+}
+
+/// A helper for converting to wider (vector) types. For scalar types, this is
+/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
+/// struct where each element type has been widened to a vector type. Note: Only
+/// unpacked literal struct types are supported.
+Type *ToWideTy(Type *Ty, ElementCount EC);
+
+/// A helper for converting wide types to narrow (non-vector) types. For vector
+/// types, this is equivalent to calling .getScalarType(). For struct types,
+/// this returns a new struct where each element type has been converted to a
+/// scalar type. Note: Only unpacked literal struct types are supported.
+Type *ToNarrowTy(Type *Ty);
+
+/// Returns the types contained in `Ty`. For struct types, it returns the
+/// elements, all other types are returned directly.
+SmallVector<Type *, 2> getContainedTypes(Type *Ty);
+
+/// Returns true if `Ty` is a vector type or a struct of vector types where all
+/// vector types share the same VF.
+bool isWideTy(Type *Ty);
+
+/// Returns the vectorization factor for a widened type. `Ty` must satisfy
+/// isWideTy() (checked by an assertion).
+inline ElementCount getWideTypeVF(Type *Ty) {
+  assert(isWideTy(Ty) && "expected widened type!");
+  return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
+}
+
+} // namespace llvm
+
+#endif // LLVM_IR_VECTORUTILS_H
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index dbffbb8a5f81d9..38b9da69ae2b76 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -39,6 +39,20 @@ static cl::opt<unsigned> MaxInterleaveGroupFactor(
     cl::desc("Maximum factor for an interleaved access group (default = 8)"),
     cl::init(8));
 
+/// Returns true if `Ty` can be widened by the loop vectorizer.
+bool llvm::canWidenType(Type *Ty) {
+  // For now, only allow widening non-packed literal structs where all
+  // element types are the same. This simplifies the cost model and
+  // conversion between scalar and wide types.
+  auto *StructTy = dyn_cast<StructType>(Ty);
+  bool IsWidenableStruct = StructTy && StructTy->isLiteral() &&
+                           !StructTy->isPacked() &&
+                           StructTy->containsHomogeneousTypes();
+  // Any other type (including unsupported structs) is checked as-is.
+  Type *ElTy = IsWidenableStruct ? StructTy->elements().front() : Ty;
+  return VectorType::isValidElementType(ElTy);
+}
+
 /// Return true if all of the intrinsic's arguments and return type are scalars
 /// for the scalar form of the intrinsic, and vectors for the vector form of the
 /// intrinsic (except operands that are marked as always being scalar by
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 544f4ea9223d0e..7eaf35e10ebc67 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -73,6 +73,7 @@ add_llvm_component_library(LLVMCore
   Value.cpp
   ValueSymbolTable.cpp
   VectorBuilder.cpp
+  VectorUtils.cpp
   Verifier.cpp
   VFABIDemangler.cpp
   RuntimeLibcalls.cpp
diff --git a/llvm/lib/IR/VFABIDemangler.cpp b/llvm/lib/IR/VFABIDemangler.cpp
index cdfb9fbfaa084d..6ccd77fd23793a 100644
--- a/llvm/lib/IR/VFABIDemangler.cpp
+++ b/llvm/lib/IR/VFABIDemangler.cpp
@@ -11,6 +11,7 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/VectorUtils.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <limits>
@@ -346,12 +347,15 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
   // Also check the return type if not void.
   Type *RetTy = Signature->getReturnType();
   if (!RetTy->isVoidTy()) {
-    std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
-    // If we have an unknown scalar element type we can't find a reasonable VF.
-    if (!ReturnEC)
-      return std::nullopt;
-    if (ElementCount::isKnownLT(*ReturnEC, MinEC))
-      MinEC = *ReturnEC;
+    for (Type *RetTy : getContainedTypes(RetTy)) {
+      std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
+      // If we have an unknown scalar element type we can't find a reasonable
+      // VF.
+      if (!ReturnEC)
+        return std::nullopt;
+      if (ElementCount::isKnownLT(*ReturnEC, MinEC))
+        MinEC = *ReturnEC;
+    }
   }
 
   // The SVE Vector function call ABI bases the VF on the widest element types
@@ -566,7 +570,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,
 
   auto *RetTy = ScalarFTy->getReturnType();
   if (!RetTy->isVoidTy())
-    RetTy = VectorType::get(RetTy, VF);
+    RetTy = ToWideTy(RetTy, VF);
   return FunctionType::get(RetTy, VecTypes, false);
 }
 
diff --git a/llvm/lib/IR/VectorUtils.cpp b/llvm/lib/IR/VectorUtils.cpp
new file mode 100644
index 00000000000000..c89a8eaf2ad1e0
--- /dev/null
+++ b/llvm/lib/IR/VectorUtils.cpp
@@ -0,0 +1,66 @@
+//===----------- VectorUtils.cpp - Vector type utility functions ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/VectorUtils.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+
+using namespace llvm;
+
+/// Widens `Ty` to `EC` lanes. Non-struct types are handed to ToVectorTy,
+/// while each member of an unpacked literal struct becomes a vector of `EC`
+/// elements. A scalar `EC` returns `Ty` untouched.
+Type *llvm::ToWideTy(Type *Ty, ElementCount EC) {
+  if (EC.isScalar())
+    return Ty;
+  if (auto *StructTy = dyn_cast<StructType>(Ty)) {
+    assert(StructTy->isLiteral() && !StructTy->isPacked() &&
+           "expected unpacked struct literal");
+    SmallVector<Type *, 4> WideElts;
+    for (Type *ElTy : StructTy->elements())
+      WideElts.push_back(VectorType::get(ElTy, EC));
+    return StructType::get(Ty->getContext(), WideElts);
+  }
+  return ToVectorTy(Ty, EC);
+}
+
+/// Counterpart to ToWideTy: a non-struct type is reduced to its scalar type,
+/// and an unpacked literal struct has each member reduced the same way.
+Type *llvm::ToNarrowTy(Type *Ty) {
+  auto *StructTy = dyn_cast<StructType>(Ty);
+  if (StructTy == nullptr)
+    return Ty->getScalarType();
+  assert(StructTy->isLiteral() && !StructTy->isPacked() &&
+         "expected unpacked struct literal");
+  auto NarrowElts = map_to_vector(
+      StructTy->elements(), [](Type *ElTy) { return ElTy->getScalarType(); });
+  return StructType::get(Ty->getContext(), NarrowElts);
+}
+
+/// Lists the types that make up `Ty`: a struct's element types, or just `Ty`
+/// for any other type.
+SmallVector<Type *, 2> llvm::getContainedTypes(Type *Ty) {
+  if (auto *StructTy = dyn_cast<StructType>(Ty))
+    return to_vector<2>(StructTy->elements());
+  return {Ty};
+}
+
+/// A type is wide when it is a vector, or a struct whose members are all
+/// vectors sharing one element count.
+bool llvm::isWideTy(Type *Ty) {
+  SmallVector<Type *, 2> Members = getContainedTypes(Ty);
+  if (Members.empty())
+    return false;
+  auto *FirstVecTy = dyn_cast<VectorType>(Members.front());
+  if (!FirstVecTy)
+    return false;
+  ElementCount VF = FirstVecTy->getElementCount();
+  return all_of(Members, [&](Type *MemberTy) {
+    auto *VecTy = dyn_cast<VectorType>(MemberTy);
+    return VecTy && VecTy->getElementCount() == VF;
+  });
+}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 43be72f0f34d45..cb6327640dbdbb 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -949,8 +949,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
       // Check that the instruction return type is vectorizable.
       // We can't vectorize casts from vector type to scalar type.
       // Also, we can't vectorize extractelement instructions.
-      if ((!VectorType::isValidElementType(I.getType()) &&
-           !I.getType()->isVoidTy()) ||
+      Type *InstTy = I.getType();
+      if (!(InstTy->isVoidTy() || canWidenType(InstTy)) ||
           (isa<CastInst>(I) &&
            !VectorType::isValidElementType(I.getOperand(0)->getType())) ||
           isa<ExtractElementInst>(I)) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 027ee21527d228..87566387e5d1d2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2862,10 +2862,10 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI,
   return ScalarCallCost;
 }
 
-static Type *maybeVectorizeType(Type *Elt, ElementCount VF) {
-  if (VF.isScalar() || (!Elt->isIntOrPtrTy() && !Elt->isFloatingPointTy()))
-    return Elt;
-  return VectorType::get(Elt, VF);
+static Type *maybeVectorizeType(Type *Ty, ElementCount VF) {
+  // Scalar VFs and types the vectorizer cannot widen are returned as-is.
+  bool ShouldWiden = !VF.isScalar() && canWidenType(Ty);
+  return ShouldWiden ? ToWideTy(Ty, VF) : Ty;
+}
 
 InstructionCost
@@ -3630,9 +3630,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
 
       // ExtractValue instructions must be uniform, because the operands are
       // known to be loop-invariant.
-      if (auto *EVI = dyn_cast<ExtractValueInst>(&I)) {
-        assert(IsOutOfScope(EVI->getAggregateOperand()) &&
-               "Expected aggregate value to be loop invariant");
+      if (auto *EVI = dyn_cast<ExtractValueInst>(&I);
+          EVI && IsOutOfScope(EVI->getAggregateOperand())) {
         AddToWorklistIfAllowed(EVI);
         continue;
       }
@@ -5456,10 +5455,13 @@ InstructionCost LoopVectorizationCostModel::computePredInstDiscount(
     // and phi nodes.
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     if (isScalarWithPredication(I, VF) && !I->getType()->isVoidTy()) {
-      ScalarCost += TTI.getScalarizationOverhead(
-          cast<VectorType>(ToVectorTy(I->getType(), VF)),
-          APInt::getAllOnes(VF.getFixedValue()), /*Insert*/ true,
-          /*Extract*/ false, CostKind);
+      Type *WideTy = ToWideTy(I->getType(), VF);
+      for (Type *VectorTy : getContainedTypes(WideTy)) {
+        ScalarCost += TTI.getScalarizationOverhead(
+            cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
+            /*Insert*/ true,
+            /*Extract*/ false, CostKind);
+      }
       ScalarCost +=
           VF.getFixedValue() * TTI.getCFInstrCost(Instruction::PHI, CostKind);
     }
@@ -5948,13 +5950,17 @@ InstructionCost LoopVectorizationCostModel::getScalarizationOverhead(
     return 0;
 
   InstructionCost Cost = 0;
-  Type *RetTy = ToVectorTy(I->getType(), VF);
+  Type *RetTy = ToWideTy(I->getType(), VF);
   if (!RetTy->isVoidTy() &&
-      (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore()))
-    Cost += TTI.getScalarizationOverhead(
-        cast<VectorType>(RetTy), APInt::getAllOnes(VF.getKnownMinValue()),
-        /*Insert*/ true,
-        /*Extract*/ false, CostKind);
+      (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore())) {
+
+    for (Type *VectorTy : getContainedTypes(RetTy)) {
+      Cost += TTI.getScalarizationOverhead(
+          cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getKnownMinValue()),
+          /*Insert*/ true,
+          /*Extract*/ false, CostKind);
+    }
+  }
 
   // Some targets keep addresses scalar.
   if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing())
@@ -6214,9 +6220,9 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
 
       bool MaskRequired = Legal->isMaskRequired(CI);
       // Compute corresponding vector type for return value and arguments.
-      Type *RetTy = ToVectorTy(ScalarRetTy, VF);
+      Type *RetTy = ToWideTy(ScalarRetTy, VF);
       for (Type *ScalarTy : ScalarTys)
-        Tys.push_back(ToVectorTy(ScalarTy, VF));
+        Tys.push_back(ToWideTy(ScalarTy, VF));
 
       // An in-loop reduction using an fmuladd intrinsic is a special case;
       // we don't want the normal cost for that intrinsic.
@@ -6393,7 +6399,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
            HasSingleCopyAfterVectorization(I, VF));
     VectorTy = RetTy;
   } else
-    VectorTy = ToVectorTy(RetTy, VF);
+    VectorTy = ToWideTy(RetTy, VF);
 
   if (VF.isVector() && VectorTy->isVectorTy() &&
       !TTI.getNumberOfParts(VectorTy))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 368d6e58a5578e..49322941f5f069 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1024,7 +1024,7 @@ InstructionCost VPWidenIntrinsicRecipe::computeCost(ElementCount VF,
     Arguments.push_back(V);
   }
 
-  Type *RetTy = ToVectorTy(Ctx.Types.inferScalarType(this), VF);
+  Type *RetTy = ToWideTy(Ctx.Types.inferScalarType(this), VF);
   SmallVector<Type *> ParamTys;
   for (unsigned I = 0; I != getNumOperands(); ++I)
     ParamTys.push_back(

>From dbedb67e957574100b7a368526a064299cbf14ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 15 Oct 2024 14:07:36 +0000
Subject: [PATCH 2/3] Make having flags (FMFs etc) not require inheriting
 VPSingleDefRecipe

In the next patch `VPWidenCallRecipe` is made to return multiple
results. It still needs to support flags, but can't be a
`VPSingleDefRecipe`.

This is fixed by splitting `VPRecipeWithIRFlags` into a standalone flags
class, `VPRecipeIRFlags`, plus a `VPSingleDefRecipeWithIRFlags` base that
combines it with `VPSingleDefRecipe`, and by adding the virtual method
`VPRecipeIRFlags *VPRecipeBase::getIRFlags()`. Having flags therefore no
longer requires deriving from `VPSingleDefRecipe` (or implies a number of
results).
---
 .../Vectorize/LoopVectorizationPlanner.h      |  12 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 204 ++++++++++--------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  47 ++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  31 +--
 .../AArch64/sve2-histcnt-vplan.ll             |   2 +-
 5 files changed, 162 insertions(+), 134 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 1c8d541ef2c51f..e382e73013b8a3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -168,7 +168,7 @@ class VPBuilder {
 
   VPInstruction *createOverflowingOp(unsigned Opcode,
                                      std::initializer_list<VPValue *> Operands,
-                                     VPRecipeWithIRFlags::WrapFlagsTy WrapFlags,
+                                     VPRecipeIRFlags::WrapFlagsTy WrapFlags,
                                      DebugLoc DL = {}, const Twine &Name = "") {
     return tryInsertInstruction(
         new VPInstruction(Opcode, Operands, WrapFlags, DL, Name));
@@ -187,9 +187,9 @@ class VPBuilder {
   VPValue *createOr(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
                     const Twine &Name = "") {
 
-    return tryInsertInstruction(new VPInstruction(
-        Instruction::BinaryOps::Or, {LHS, RHS},
-        VPRecipeWithIRFlags::DisjointFlagsTy(false), DL, Name));
+    return tryInsertInstruction(
+        new VPInstruction(Instruction::BinaryOps::Or, {LHS, RHS},
+                          VPRecipeIRFlags::DisjointFlagsTy(false), DL, Name));
   }
 
   VPValue *createLogicalAnd(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
@@ -223,12 +223,12 @@ class VPBuilder {
   VPInstruction *createPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                               const Twine &Name = "") {
     return tryInsertInstruction(new VPInstruction(
-        Ptr, Offset, VPRecipeWithIRFlags::GEPFlagsTy(false), DL, Name));
+        Ptr, Offset, VPRecipeIRFlags::GEPFlagsTy(false), DL, Name));
   }
   VPValue *createInBoundsPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                                 const Twine &Name = "") {
     return tryInsertInstruction(new VPInstruction(
-        Ptr, Offset, VPRecipeWithIRFlags::GEPFlagsTy(true), DL, Name));
+        Ptr, Offset, VPRecipeIRFlags::GEPFlagsTy(true), DL, Name));
   }
 
   VPDerivedIVRecipe *createDerivedIV(InductionDescriptor::InductionKind Kind,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 6a61ef63c2a054..c0bd5833b21a58 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -720,6 +720,8 @@ struct VPCostContext {
   bool skipCostComputation(Instruction *UI, bool IsVector) const;
 };
 
+class VPRecipeIRFlags;
+
 /// VPRecipeBase is a base class modeling a sequence of one or more output IR
 /// instructions. VPRecipeBase owns the VPValues it defines through VPDef
 /// and is responsible for deleting its defined values. Single-value
@@ -825,6 +827,9 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// Returns the debug location of the recipe.
   DebugLoc getDebugLoc() const { return DL; }
 
+  /// Returns the IR flags for the recipe.
+  virtual VPRecipeIRFlags *getIRFlags() { return nullptr; }
+
 protected:
   /// Compute the cost of this recipe either using a recipe's specialized
   /// implementation or using the legacy cost model and the underlying
@@ -936,8 +941,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
                               VPCostContext &Ctx) const override;
 };
 
-/// Class to record LLVM IR flag for a recipe along with it.
-class VPRecipeWithIRFlags : public VPSingleDefRecipe {
+/// Class to record LLVM IR flag for a recipe.
+class VPRecipeIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
     OverflowingBinOp,
@@ -999,23 +1004,10 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     unsigned AllFlags;
   };
 
-protected:
-  void transferFlags(VPRecipeWithIRFlags &Other) {
-    OpType = Other.OpType;
-    AllFlags = Other.AllFlags;
-  }
-
 public:
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL) {
-    OpType = OperationType::Other;
-    AllFlags = 0;
-  }
+  VPRecipeIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I)
-      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()) {
+  VPRecipeIRFlags(Instruction &I) {
     if (auto *Op = dyn_cast<CmpInst>(&I)) {
       OpType = OperationType::Cmp;
       CmpPredicate = Op->getPredicate();
@@ -1043,53 +1035,22 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     }
   }
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      CmpInst::Predicate Pred, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::Cmp),
-        CmpPredicate(Pred) {}
+  VPRecipeIRFlags(CmpInst::Predicate Pred)
+      : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      WrapFlagsTy WrapFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL),
-        OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
+  VPRecipeIRFlags(WrapFlagsTy WrapFlags)
+      : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      FastMathFlags FMFs, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp),
-        FMFs(FMFs) {}
+  VPRecipeIRFlags(FastMathFlags FMFs)
+      : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
-        DisjointFlags(DisjointFlags) {}
+  VPRecipeIRFlags(DisjointFlagsTy DisjointFlags)
+      : OpType(OperationType::DisjointOp), DisjointFlags(DisjointFlags) {}
 
-protected:
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      GEPFlagsTy GEPFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::GEPOp),
-        GEPFlags(GEPFlags) {}
+  VPRecipeIRFlags(GEPFlagsTy GEPFlags)
+      : OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
 
 public:
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenEVLSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
-           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
-
   /// Drop all poison-generating flags.
   void dropPoisonGeneratingFlags() {
     // NOTE: This needs to be kept in-sync with
@@ -1198,6 +1159,53 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 #endif
 };
 
+/// A VPSingleDefRecipe that also records LLVM IR flags via VPRecipeIRFlags.
+class VPSingleDefRecipeWithIRFlags : public VPSingleDefRecipe,
+                                     public VPRecipeIRFlags {
+public:
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPRecipeIRFlags() {}
+
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               Instruction &I)
+      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()),
+        VPRecipeIRFlags(I) {}
+
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               CmpInst::Predicate Pred, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPRecipeIRFlags(Pred) {}
+
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               WrapFlagsTy WrapFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPRecipeIRFlags(WrapFlags) {}
+
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               FastMathFlags FMFs, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPRecipeIRFlags(FMFs) {}
+
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPRecipeIRFlags(DisjointFlags) {}
+
+  /// Returns the IR flags of this recipe, i.e. its VPRecipeIRFlags base.
+  VPRecipeIRFlags *getIRFlags() override {
+    return static_cast<VPRecipeIRFlags *>(this);
+  }
+
+protected:
+  template <typename IterT>
+  VPSingleDefRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                               GEPFlagsTy GEPFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPRecipeIRFlags(GEPFlags) {}
+};
+
 /// Helper to access the operand that contains the unroll part for this recipe
 /// after unrolling.
 template <unsigned PartOpIdx> class VPUnrollPartAccessor {
@@ -1214,7 +1222,7 @@ template <unsigned PartOpIdx> class VPUnrollPartAccessor {
 /// While as any Recipe it may generate a sequence of IR instructions when
 /// executed, these instructions would always form a single-def expression as
 /// the VPInstruction is also a single def-use vertex.
-class VPInstruction : public VPRecipeWithIRFlags,
+class VPInstruction : public VPSingleDefRecipeWithIRFlags,
                       public VPUnrollPartAccessor<1> {
   friend class VPlanSlp;
 
@@ -1289,7 +1297,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
 public:
   VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
                 const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
         Opcode(Opcode), Name(Name.str()) {}
 
   VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
@@ -1300,22 +1308,27 @@ class VPInstruction : public VPRecipeWithIRFlags,
                 VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
 
   VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
+                VPRecipeIRFlags::WrapFlagsTy WrapFlags, DebugLoc DL = {},
+                const Twine &Name = "")
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPInstructionSC, Operands,
+                                     WrapFlags, DL),
         Opcode(Opcode), Name(Name.str()) {}
 
   VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
+                VPRecipeIRFlags::DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
                 const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPInstructionSC, Operands,
+                                     DisjointFlag, DL),
         Opcode(Opcode), Name(Name.str()) {
     assert(Opcode == Instruction::Or && "only OR opcodes can be disjoint");
   }
 
-  VPInstruction(VPValue *Ptr, VPValue *Offset, GEPFlagsTy Flags,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC,
-                            ArrayRef<VPValue *>({Ptr, Offset}), Flags, DL),
+  VPInstruction(VPValue *Ptr, VPValue *Offset,
+                VPRecipeIRFlags::GEPFlagsTy Flags, DebugLoc DL = {},
+                const Twine &Name = "")
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPInstructionSC,
+                                     ArrayRef<VPValue *>({Ptr, Offset}), Flags,
+                                     DL),
         Opcode(VPInstruction::PtrAdd), Name(Name.str()) {}
 
   VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
@@ -1326,7 +1339,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
     auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name);
-    New->transferFlags(*this);
+    *New->getIRFlags() = *getIRFlags();
     return New;
   }
 
@@ -1438,14 +1451,15 @@ class VPIRInstruction : public VPRecipeBase {
 /// opcode and operands of the recipe. This recipe covers most of the
 /// traditional vectorization cases where each recipe transforms into a
 /// vectorized version of itself.
-class VPWidenRecipe : public VPRecipeWithIRFlags {
+class VPWidenRecipe : public VPSingleDefRecipeWithIRFlags {
   unsigned Opcode;
 
 protected:
   template <typename IterT>
   VPWidenRecipe(unsigned VPDefOpcode, Instruction &I,
                 iterator_range<IterT> Operands)
-      : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
+      : VPSingleDefRecipeWithIRFlags(VPDefOpcode, Operands, I),
+        Opcode(I.getOpcode()) {}
 
 public:
   template <typename IterT>
@@ -1456,7 +1470,7 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
 
   VPWidenRecipe *clone() override {
     auto *R = new VPWidenRecipe(*getUnderlyingInstr(), operands());
-    R->transferFlags(*this);
+    *R->getIRFlags() = *getIRFlags();
     return R;
   }
 
@@ -1490,8 +1504,6 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
 /// A recipe for widening operations with vector-predication intrinsics with
 /// explicit vector length (EVL).
 class VPWidenEVLRecipe : public VPWidenRecipe {
-  using VPRecipeWithIRFlags::transferFlags;
-
 public:
   template <typename IterT>
   VPWidenEVLRecipe(Instruction &I, iterator_range<IterT> Operands, VPValue &EVL)
@@ -1500,7 +1512,7 @@ class VPWidenEVLRecipe : public VPWidenRecipe {
   }
   VPWidenEVLRecipe(VPWidenRecipe &W, VPValue &EVL)
       : VPWidenEVLRecipe(*W.getUnderlyingInstr(), W.operands(), EVL) {
-    transferFlags(W);
+    *getIRFlags() = *W.getIRFlags();
   }
 
   ~VPWidenEVLRecipe() override = default;
@@ -1536,7 +1548,7 @@ class VPWidenEVLRecipe : public VPWidenRecipe {
 };
 
 /// VPWidenCastRecipe is a recipe to create vector cast instructions.
-class VPWidenCastRecipe : public VPRecipeWithIRFlags {
+class VPWidenCastRecipe : public VPSingleDefRecipeWithIRFlags {
   /// Cast instruction opcode.
   Instruction::CastOps Opcode;
 
@@ -1546,14 +1558,14 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
 public:
   VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
                     CastInst &UI)
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, UI), Opcode(Opcode),
-        ResultTy(ResultTy) {
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, UI),
+        Opcode(Opcode), ResultTy(ResultTy) {
     assert(UI.getOpcode() == Opcode &&
            "opcode of underlying cast doesn't match");
   }
 
   VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), Opcode(Opcode),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), Opcode(Opcode),
         ResultTy(ResultTy) {}
 
   ~VPWidenCastRecipe() override = default;
@@ -1623,7 +1635,7 @@ class VPScalarCastRecipe : public VPSingleDefRecipe {
 };
 
 /// A recipe for widening vector intrinsics.
-class VPWidenIntrinsicRecipe : public VPRecipeWithIRFlags {
+class VPWidenIntrinsicRecipe : public VPSingleDefRecipeWithIRFlags {
   /// ID of the vector intrinsic to widen.
   Intrinsic::ID VectorIntrinsicID;
 
@@ -1643,7 +1655,8 @@ class VPWidenIntrinsicRecipe : public VPRecipeWithIRFlags {
   VPWidenIntrinsicRecipe(CallInst &CI, Intrinsic::ID VectorIntrinsicID,
                          ArrayRef<VPValue *> CallArguments, Type *Ty,
                          DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenIntrinsicSC, CallArguments, CI),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPWidenIntrinsicSC, CallArguments,
+                                     CI),
         VectorIntrinsicID(VectorIntrinsicID), ResultTy(Ty),
         MayReadFromMemory(CI.mayReadFromMemory()),
         MayWriteToMemory(CI.mayWriteToMemory()),
@@ -1689,7 +1702,7 @@ class VPWidenIntrinsicRecipe : public VPRecipeWithIRFlags {
 };
 
 /// A recipe for widening Call instructions using library calls.
-class VPWidenCallRecipe : public VPRecipeWithIRFlags {
+class VPWidenCallRecipe : public VPSingleDefRecipeWithIRFlags {
   /// Variant stores a pointer to the chosen function. There is a 1:1 mapping
   /// between a given VF and the chosen vectorized variant, so there will be a
   /// different VPlan for each VF with a valid variant.
@@ -1698,8 +1711,8 @@ class VPWidenCallRecipe : public VPRecipeWithIRFlags {
 public:
   VPWidenCallRecipe(Value *UV, Function *Variant,
                     ArrayRef<VPValue *> CallArguments, DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCallSC, CallArguments,
-                            *cast<Instruction>(UV)),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPWidenCallSC, CallArguments,
+                                     *cast<Instruction>(UV)),
         Variant(Variant) {
     assert(
         isa<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue()) &&
@@ -1820,7 +1833,7 @@ struct VPWidenSelectRecipe : public VPSingleDefRecipe {
 };
 
 /// A recipe for handling GEP instructions.
-class VPWidenGEPRecipe : public VPRecipeWithIRFlags {
+class VPWidenGEPRecipe : public VPSingleDefRecipeWithIRFlags {
   bool isPointerLoopInvariant() const {
     return getOperand(0)->isDefinedOutsideLoopRegions();
   }
@@ -1838,7 +1851,7 @@ class VPWidenGEPRecipe : public VPRecipeWithIRFlags {
 public:
   template <typename IterT>
   VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands)
-      : VPRecipeWithIRFlags(VPDef::VPWidenGEPSC, Operands, *GEP) {}
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPWidenGEPSC, Operands, *GEP) {}
 
   ~VPWidenGEPRecipe() override = default;
 
@@ -1862,7 +1875,7 @@ class VPWidenGEPRecipe : public VPRecipeWithIRFlags {
 /// A recipe to compute the pointers for widened memory accesses of IndexTy for
 /// all parts. If IsReverse is true, compute pointers for accessing the input in
 /// reverse order per part.
-class VPVectorPointerRecipe : public VPRecipeWithIRFlags,
+class VPVectorPointerRecipe : public VPSingleDefRecipeWithIRFlags,
                               public VPUnrollPartAccessor<1> {
   Type *IndexedTy;
   bool IsReverse;
@@ -1870,8 +1883,9 @@ class VPVectorPointerRecipe : public VPRecipeWithIRFlags,
 public:
   VPVectorPointerRecipe(VPValue *Ptr, Type *IndexedTy, bool IsReverse,
                         bool IsInBounds, DebugLoc DL)
-      : VPRecipeWithIRFlags(VPDef::VPVectorPointerSC, ArrayRef<VPValue *>(Ptr),
-                            GEPFlagsTy(IsInBounds), DL),
+      : VPSingleDefRecipeWithIRFlags(
+            VPDef::VPVectorPointerSC, ArrayRef<VPValue *>(Ptr),
+            VPRecipeIRFlags::GEPFlagsTy(IsInBounds), DL),
         IndexedTy(IndexedTy), IsReverse(IsReverse) {}
 
   VP_CLASSOF_IMPL(VPDef::VPVectorPointerSC)
@@ -2549,7 +2563,7 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
 /// uniform only one copy, per lane zero, will be generated.
-class VPReplicateRecipe : public VPRecipeWithIRFlags {
+class VPReplicateRecipe : public VPSingleDefRecipeWithIRFlags {
   /// Indicator if only a single replica per lane is needed.
   bool IsUniform;
 
@@ -2560,7 +2574,7 @@ class VPReplicateRecipe : public VPRecipeWithIRFlags {
   template <typename IterT>
   VPReplicateRecipe(Instruction *I, iterator_range<IterT> Operands,
                     bool IsUniform, VPValue *Mask = nullptr)
-      : VPRecipeWithIRFlags(VPDef::VPReplicateSC, Operands, *I),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPReplicateSC, Operands, *I),
         IsUniform(IsUniform), IsPredicated(Mask) {
     if (Mask)
       addOperand(Mask);
@@ -2572,7 +2586,7 @@ class VPReplicateRecipe : public VPRecipeWithIRFlags {
     auto *Copy =
         new VPReplicateRecipe(getUnderlyingInstr(), operands(), IsUniform,
                               isPredicated() ? getMask() : nullptr);
-    Copy->transferFlags(*this);
+    *Copy->getIRFlags() = *getIRFlags();
     return Copy;
   }
 
@@ -3201,15 +3215,15 @@ class VPDerivedIVRecipe : public VPSingleDefRecipe {
 
 /// A recipe for handling phi nodes of integer and floating-point inductions,
 /// producing their scalar values.
-class VPScalarIVStepsRecipe : public VPRecipeWithIRFlags,
+class VPScalarIVStepsRecipe : public VPSingleDefRecipeWithIRFlags,
                               public VPUnrollPartAccessor<2> {
   Instruction::BinaryOps InductionOpcode;
 
 public:
   VPScalarIVStepsRecipe(VPValue *IV, VPValue *Step,
                         Instruction::BinaryOps Opcode, FastMathFlags FMFs)
-      : VPRecipeWithIRFlags(VPDef::VPScalarIVStepsSC,
-                            ArrayRef<VPValue *>({IV, Step}), FMFs),
+      : VPSingleDefRecipeWithIRFlags(VPDef::VPScalarIVStepsSC,
+                                     ArrayRef<VPValue *>({IV, Step}), FMFs),
         InductionOpcode(Opcode) {}
 
   VPScalarIVStepsRecipe(const InductionDescriptor &IndDesc, VPValue *IV,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 49322941f5f069..07933e67469be1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -327,7 +327,7 @@ InstructionCost VPSingleDefRecipe::computeCost(ElementCount VF,
   return UI ? Ctx.getLegacyCost(UI, VF) : 0;
 }
 
-FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
+FastMathFlags VPRecipeIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
   FastMathFlags Res;
@@ -359,8 +359,8 @@ unsigned VPUnrollPartAccessor<PartOpIdx>::getUnrollPart(VPUser &U) const {
 VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
                              VPValue *A, VPValue *B, DebugLoc DL,
                              const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
-                          Pred, DL),
+    : VPSingleDefRecipeWithIRFlags(VPDef::VPInstructionSC,
+                                   ArrayRef<VPValue *>({A, B}), Pred, DL),
       Opcode(Opcode), Name(Name.str()) {
   assert(Opcode == Instruction::ICmp &&
          "only ICmp predicates supported at the moment");
@@ -369,7 +369,7 @@ VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
 VPInstruction::VPInstruction(unsigned Opcode,
                              std::initializer_list<VPValue *> Operands,
                              FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs, DL),
+    : VPSingleDefRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs, DL),
       Opcode(Opcode), Name(Name.str()) {
   // Make sure the VPInstruction is a floating-point operation.
   assert(isFPMathOp() && "this op can't take fast-math flags");
@@ -836,7 +836,10 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   }
 
   printFlags(O);
-  printOperands(O, SlotTracker);
+  if (getNumOperands() > 0) {
+    O << " ";
+    printOperands(O, SlotTracker);
+  }
 
   if (auto DL = getDebugLoc()) {
     O << ", !dbg ";
@@ -1055,7 +1058,7 @@ void VPWidenIntrinsicRecipe::print(raw_ostream &O, const Twine &Indent,
 
   O << "call";
   printFlags(O);
-  O << getIntrinsicName() << "(";
+  O << " " << getIntrinsicName() << "(";
 
   interleaveComma(operands(), O, [&O, &SlotTracker](VPValue *Op) {
     Op->printAsOperand(O, SlotTracker);
@@ -1183,8 +1186,7 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
   State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
 }
 
-VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
-    const FastMathFlags &FMF) {
+VPRecipeIRFlags::FastMathFlagsTy::FastMathFlagsTy(const FastMathFlags &FMF) {
   AllowReassoc = FMF.allowReassoc();
   NoNaNs = FMF.noNaNs();
   NoInfs = FMF.noInfs();
@@ -1195,7 +1197,7 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
+void VPRecipeIRFlags::printFlags(raw_ostream &O) const {
   switch (OpType) {
   case OperationType::Cmp:
     O << " " << CmpInst::getPredicateName(getPredicate());
@@ -1228,8 +1230,6 @@ void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
   case OperationType::Other:
     break;
   }
-  if (getNumOperands() > 0)
-    O << " ";
 }
 #endif
 
@@ -1436,7 +1436,10 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = " << Instruction::getOpcodeName(Opcode);
   printFlags(O);
-  printOperands(O, SlotTracker);
+  if (getNumOperands() > 0) {
+    O << " ";
+    printOperands(O, SlotTracker);
+  }
 }
 
 void VPWidenEVLRecipe::print(raw_ostream &O, const Twine &Indent,
@@ -1445,7 +1448,10 @@ void VPWidenEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = vp." << Instruction::getOpcodeName(getOpcode());
   printFlags(O);
-  printOperands(O, SlotTracker);
+  if (getNumOperands() > 0) {
+    O << " ";
+    printOperands(O, SlotTracker);
+  }
 }
 #endif
 
@@ -1467,9 +1473,12 @@ void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
   O << Indent << "WIDEN-CAST ";
   printAsOperand(O, SlotTracker);
-  O << " = " << Instruction::getOpcodeName(Opcode) << " ";
+  O << " = " << Instruction::getOpcodeName(Opcode);
   printFlags(O);
-  printOperands(O, SlotTracker);
+  if (getNumOperands() > 0) {
+    O << " ";
+    printOperands(O, SlotTracker);
+  }
   O << " to " << *getResultType();
 }
 #endif
@@ -1853,6 +1862,7 @@ void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = getelementptr";
   printFlags(O);
+  O << " ";
   printOperands(O, SlotTracker);
 }
 #endif
@@ -2176,7 +2186,7 @@ void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent,
   if (auto *CB = dyn_cast<CallBase>(getUnderlyingInstr())) {
     O << "call";
     printFlags(O);
-    O << "@" << CB->getCalledFunction()->getName() << "(";
+    O << " @" << CB->getCalledFunction()->getName() << "(";
     interleaveComma(make_range(op_begin(), op_begin() + (getNumOperands() - 1)),
                     O, [&O, &SlotTracker](VPValue *Op) {
                       Op->printAsOperand(O, SlotTracker);
@@ -2185,7 +2195,10 @@ void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent,
   } else {
     O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode());
     printFlags(O);
-    printOperands(O, SlotTracker);
+    if (getNumOperands() > 0) {
+      O << " ";
+      printOperands(O, SlotTracker);
+    }
   }
 
   if (shouldPack())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 379bfc0a4394bf..9fa6ad0a7bd1dc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -843,8 +843,9 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
       continue;
 
     for (VPUser *U : collectUsersRecursively(PhiR))
-      if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(U)) {
-        RecWithFlags->dropPoisonGeneratingFlags();
+      if (auto *R = dyn_cast<VPRecipeBase>(U)) {
+        if (auto *IRFlags = R->getIRFlags())
+          IRFlags->dropPoisonGeneratingFlags();
       }
   }
 }
@@ -1092,8 +1093,8 @@ void VPlanTransforms::truncateToMinimalBitwidths(
       // Any wrapping introduced by shrinking this operation shouldn't be
       // considered undefined behavior. So, we can't unconditionally copy
       // arithmetic wrapping flags to VPW.
-      if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R))
-        VPW->dropPoisonGeneratingFlags();
+      if (auto *Flags = R.getIRFlags())
+        Flags->dropPoisonGeneratingFlags();
 
       using namespace llvm::VPlanPatternMatch;
       if (OldResSizeInBits != NewResSizeInBits &&
@@ -1520,7 +1521,7 @@ void VPlanTransforms::dropPoisonGeneratingRecipes(
       // This recipe contributes to the address computation of a widen
       // load/store. If the underlying instruction has poison-generating flags,
       // drop them directly.
-      if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) {
+      if (auto *Flags = CurRec->getIRFlags()) {
         VPValue *A, *B;
         using namespace llvm::VPlanPatternMatch;
         // Dropping disjoint from an OR may yield incorrect results, as some
@@ -1528,25 +1529,25 @@ void VPlanTransforms::dropPoisonGeneratingRecipes(
         // for dependence analysis). Instead, replace it with an equivalent Add.
         // This is possible as all users of the disjoint OR only access lanes
         // where the operands are disjoint or poison otherwise.
-        if (match(RecWithFlags, m_BinaryOr(m_VPValue(A), m_VPValue(B))) &&
-            RecWithFlags->isDisjoint()) {
-          VPBuilder Builder(RecWithFlags);
+        if (match(CurRec, m_BinaryOr(m_VPValue(A), m_VPValue(B))) &&
+            Flags->isDisjoint()) {
+          VPValue *OldValue = CurRec->getVPSingleValue();
+          VPBuilder Builder(CurRec);
           VPInstruction *New = Builder.createOverflowingOp(
-              Instruction::Add, {A, B}, {false, false},
-              RecWithFlags->getDebugLoc());
-          New->setUnderlyingValue(RecWithFlags->getUnderlyingValue());
-          RecWithFlags->replaceAllUsesWith(New);
-          RecWithFlags->eraseFromParent();
+              Instruction::Add, {A, B}, {false, false}, CurRec->getDebugLoc());
+          New->setUnderlyingValue(OldValue->getUnderlyingValue());
+          OldValue->replaceAllUsesWith(New);
+          CurRec->eraseFromParent();
           CurRec = New;
         } else
-          RecWithFlags->dropPoisonGeneratingFlags();
+          Flags->dropPoisonGeneratingFlags();
       } else {
         Instruction *Instr = dyn_cast_or_null<Instruction>(
             CurRec->getVPSingleValue()->getUnderlyingValue());
         (void)Instr;
         assert((!Instr || !Instr->hasPoisonGeneratingFlags()) &&
                "found instruction with poison generating flags not covered by "
-               "VPRecipeWithIRFlags");
+               "VPRecipeIRFlags");
       }
 
       // Add new definitions to the worklist.
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve2-histcnt-vplan.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve2-histcnt-vplan.ll
index 9be068ce880ea8..f6f42338959e3f 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve2-histcnt-vplan.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve2-histcnt-vplan.ll
@@ -64,7 +64,7 @@ target triple = "aarch64-unknown-linux-gnu"
 ; CHECK-NEXT:     CLONE [[GEP_IDX:.*]] = getelementptr inbounds ir<%indices>, [[STEPS]]
 ; CHECK-NEXT:     [[VECP_IDX:vp.*]] = vector-pointer [[GEP_IDX]]
 ; CHECK-NEXT:     WIDEN [[IDX:.*]] = load [[VECP_IDX]]
-; CHECK-NEXT:     WIDEN-CAST [[EXT_IDX:.*]] = zext  [[IDX]] to i64
+; CHECK-NEXT:     WIDEN-CAST [[EXT_IDX:.*]] = zext [[IDX]] to i64
 ; CHECK-NEXT:     WIDEN-GEP Inv[Var] [[GEP_BUCKET:.*]] = getelementptr inbounds ir<%buckets>, [[EXT_IDX]]
 ; CHECK-NEXT:     WIDEN-HISTOGRAM buckets: [[GEP_BUCKET]], inc: ir<1>
 ; CHECK-NEXT:     EMIT [[IV_NEXT]] = add nuw [[IV]], [[VFxUF]]

>From 2aaf07c4fb69d21ded3109f5cb408196064b5f0e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 15 Oct 2024 15:33:00 +0000
Subject: [PATCH 3/3] Prototype vectorizing structs via multiple result
 VPWidenCallRecipe

---
 .../Vectorize/LoopVectorizationLegality.cpp   |  19 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    |  21 +-
 .../Transforms/Vectorize/VPRecipeBuilder.h    |  14 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  23 +-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |  16 +-
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |   2 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  20 +-
 .../Transforms/LoopVectorize/struct-return.ll | 202 ++++++++++++++++++
 .../vplan-widen-struct-return.ll              |  58 +++++
 9 files changed, 350 insertions(+), 25 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopVectorize/struct-return.ll
 create mode 100644 llvm/test/Transforms/LoopVectorize/vplan-widen-struct-return.ll

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index cb6327640dbdbb..05d1de9032a6ef 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -946,11 +946,26 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
       if (CI && !VFDatabase::getMappings(*CI).empty())
         VecCallVariantsFound = true;
 
+      // TODO: Tidy up these checks.
+      auto canWidenInst = [](Instruction &I) {
+        Type *InstTy = I.getType();
+        if (isa<CallInst>(I) && isa<StructType>(InstTy) &&
+            canWidenType(InstTy)) {
+          // We can only widen struct calls where the users are extractvalues.
+          for (auto &U : I.uses()) {
+            if (!isa<ExtractValueInst>(U.getUser()))
+              return false;
+          }
+          return true;
+        }
+        return VectorType::isValidElementType(InstTy) || InstTy->isVoidTy();
+      };
+
       // Check that the instruction return type is vectorizable.
       // We can't vectorize casts from vector type to scalar type.
       // Also, we can't vectorize extractelement instructions.
-      Type *InstTy = I.getType();
-      if (!(InstTy->isVoidTy() || canWidenType(InstTy)) ||
+      // TODO: Tidy up these checks.
+      if (!canWidenInst(I) ||
           (isa<CastInst>(I) &&
            !VectorType::isValidElementType(I.getOperand(0)->getType())) ||
           isa<ExtractElementInst>(I)) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 87566387e5d1d2..c3ec4631721b43 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7357,6 +7357,8 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       return dyn_cast_or_null<Instruction>(S->getUnderlyingValue());
     if (auto *WidenMem = dyn_cast<VPWidenMemoryRecipe>(R))
       return &WidenMem->getIngredient();
+    if (auto *WidenCall = dyn_cast<VPWidenCallRecipe>(R))
+      return WidenCall->getUnderlyingCallInstruction();
     return nullptr;
   };
 
@@ -8332,9 +8334,9 @@ VPBlendRecipe *VPRecipeBuilder::tryToBlend(PHINode *Phi,
   return new VPBlendRecipe(Phi, OperandsWithMask);
 }
 
-VPSingleDefRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
-                                                   ArrayRef<VPValue *> Operands,
-                                                   VFRange &Range) {
+VPRecipeBase *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
+                                              ArrayRef<VPValue *> Operands,
+                                              VFRange &Range) {
   bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange(
       [this, CI](ElementCount VF) {
         return CM.isScalarWithPredication(CI, VF);
@@ -9044,6 +9046,19 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
     // TODO: Model and preserve debug intrinsics in VPlan.
     for (Instruction &I : drop_end(BB->instructionsWithoutDebug(false))) {
       Instruction *Instr = &I;
+
+      // A special case. Mapping handled in
+      // VPRecipeBuilder::getVPValueOrAddLiveIn().
+      if (auto *ExtractValue = dyn_cast<ExtractValueInst>(Instr)) {
+        bool IsUniform = LoopVectorizationPlanner::getDecisionAndClampRange(
+            [&](ElementCount VF) {
+              return CM.isUniformAfterVectorization(ExtractValue, VF);
+            },
+            Range);
+        if (!IsUniform)
+          continue;
+      }
+
       SmallVector<VPValue *, 4> Operands;
       auto *Phi = dyn_cast<PHINode>(Instr);
       if (Phi && Phi->getParent() == HeaderBB) {
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 5d4a3b555981ce..c5b2147f571333 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -95,8 +95,8 @@ class VPRecipeBuilder {
   /// Handle call instructions. If \p CI can be widened for \p Range.Start,
   /// return a new VPWidenCallRecipe or VPWidenIntrinsicRecipe. Range.End may be
   /// decreased to ensure same decision from \p Range.Start to \p Range.End.
-  VPSingleDefRecipe *tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands,
-                                    VFRange &Range);
+  VPRecipeBase *tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands,
+                               VFRange &Range);
 
   /// Check if \p I has an opcode that can be widened and return a VPWidenRecipe
   /// if it can. The function should only be called if the cost-model indicates
@@ -182,6 +182,16 @@ class VPRecipeBuilder {
       if (auto *R = Ingredient2Recipe.lookup(I))
         return R->getVPSingleValue();
     }
+    // Ugh: Not sure where to handle this :(
+    if (auto *EVI = dyn_cast<ExtractValueInst>(V)) {
+      Value *AggOp = EVI->getAggregateOperand();
+      if (auto *R = getRecipe(cast<Instruction>(AggOp))) {
+        assert(R->getNumDefinedValues() ==
+               cast<StructType>(AggOp->getType())->getNumElements());
+        unsigned Idx = EVI->getIndices()[0];
+        return R->getVPValue(Idx);
+      }
+    }
     return Plan.getOrAddLiveIn(V);
   }
 };
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c0bd5833b21a58..d1689b81f2e970 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -887,7 +887,6 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
-    case VPRecipeBase::VPWidenCallSC:
     case VPRecipeBase::VPWidenCanonicalIVSC:
     case VPRecipeBase::VPWidenCastSC:
     case VPRecipeBase::VPWidenGEPSC:
@@ -909,6 +908,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPBranchOnMaskSC:
     case VPRecipeBase::VPInterleaveSC:
     case VPRecipeBase::VPIRInstructionSC:
+    case VPRecipeBase::VPWidenCallSC:
     case VPRecipeBase::VPWidenLoadEVLSC:
     case VPRecipeBase::VPWidenLoadSC:
     case VPRecipeBase::VPWidenStoreEVLSC:
@@ -1702,28 +1702,35 @@ class VPWidenIntrinsicRecipe : public VPSingleDefRecipeWithIRFlags {
 };
 
 /// A recipe for widening Call instructions using library calls.
-class VPWidenCallRecipe : public VPSingleDefRecipeWithIRFlags {
+class VPWidenCallRecipe : public VPRecipeBase, public VPRecipeIRFlags {
   /// Variant stores a pointer to the chosen function. There is a 1:1 mapping
   /// between a given VF and the chosen vectorized variant, so there will be a
   /// different VPlan for each VF with a valid variant.
   Function *Variant;
 
+  CallInst *CI;
+
 public:
-  VPWidenCallRecipe(Value *UV, Function *Variant,
+  VPWidenCallRecipe(CallInst *CI, Function *Variant,
                     ArrayRef<VPValue *> CallArguments, DebugLoc DL = {})
-      : VPSingleDefRecipeWithIRFlags(VPDef::VPWidenCallSC, CallArguments,
-                                     *cast<Instruction>(UV)),
-        Variant(Variant) {
+      : VPRecipeBase(VPDef::VPWidenCallSC, CallArguments, DL),
+        VPRecipeIRFlags(*CI), Variant(Variant), CI(CI) {
     assert(
         isa<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue()) &&
         "last operand must be the called function");
+    for (Type *Ty : getContainedTypes(CI->getType())) {
+      (void)Ty;
+      new VPValue(CI, this);
+    }
   }
 
+  CallInst *getUnderlyingCallInstruction() const { return CI; }
+
   ~VPWidenCallRecipe() override = default;
 
   VPWidenCallRecipe *clone() override {
-    return new VPWidenCallRecipe(getUnderlyingValue(), Variant,
-                                 {op_begin(), op_end()}, getDebugLoc());
+    return new VPWidenCallRecipe(CI, Variant, {op_begin(), op_end()},
+                                 getDebugLoc());
   }
 
   VP_CLASSOF_IMPL(VPDef::VPWidenCallSC)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 5a5b3ac19c46ad..664e8daf6a102d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -134,9 +134,14 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   llvm_unreachable("Unhandled opcode!");
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
-  auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
-  return CI.getType();
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R,
+                                               const VPValue *V) {
+  auto &CI = *cast<CallInst>(R->getUnderlyingCallInstruction());
+  for (auto [I, Ty] : enumerate(getContainedTypes(CI.getType()))) {
+    if (R->getVPValue(I) == V)
+      return Ty;
+  }
+  llvm_unreachable("Unexpected call value!");
 }
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R) {
@@ -265,12 +270,13 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPWidenEVLRecipe,
-                VPReplicateRecipe, VPWidenCallRecipe, VPWidenMemoryRecipe,
-                VPWidenSelectRecipe>(
+                VPReplicateRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPWidenIntrinsicRecipe>([](const VPWidenIntrinsicRecipe *R) {
             return R->getResultType();
           })
+          .Case<VPWidenCallRecipe>(
+              [&](const auto *R) { return inferScalarTypeForRecipe(R, V); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index cc21870bee2e3b..140e5ac3359b66 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -47,7 +47,7 @@ class VPTypeAnalysis {
 
   Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
   Type *inferScalarTypeForRecipe(const VPInstruction *R);
-  Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R, const VPValue *V);
   Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 07933e67469be1..de8bde60b0a7bd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -292,6 +292,8 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = IG->getInsertPos();
   else if (auto *WidenMem = dyn_cast<VPWidenMemoryRecipe>(this))
     UI = &WidenMem->getIngredient();
+  else if (auto *WidenCall = dyn_cast<VPWidenCallRecipe>(this))
+    UI = WidenCall->getUnderlyingCallInstruction();
 
   InstructionCost RecipeCost;
   if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
@@ -913,7 +915,7 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
 
   assert(Variant != nullptr && "Can't create vector function.");
 
-  auto *CI = cast_or_null<CallInst>(getUnderlyingValue());
+  auto *CI = getUnderlyingCallInstruction();
   SmallVector<OperandBundleDef, 1> OpBundles;
   if (CI)
     CI->getOperandBundlesAsDefs(OpBundles);
@@ -921,8 +923,16 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
   CallInst *V = State.Builder.CreateCall(Variant, Args, OpBundles);
   setFlags(V);
 
-  if (!V->getType()->isVoidTy())
-    State.set(this, V);
+  if (!V->getType()->isVoidTy()) {
+    // Struct results (including single-member structs) define one VPValue
+    // per member; extract each member so users see the element type.
+    if (isa<StructType>(V->getType())) {
+      for (auto [I, Def] : enumerate(definedValues()))
+        State.set(Def, State.Builder.CreateExtractValue(V, I));
+    } else {
+      State.set(getVPSingleValue(), V);
+    }
+  }
   State.addMetadata(V, CI);
 }
 
@@ -943,7 +953,9 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent,
   if (CalledFn->getReturnType()->isVoidTy())
     O << "void ";
   else {
-    printAsOperand(O, SlotTracker);
+    interleaveComma(definedValues(), O, [&O, &SlotTracker](VPValue *Def) {
+      Def->printAsOperand(O, SlotTracker);
+    });
     O << " = ";
   }
 
diff --git a/llvm/test/Transforms/LoopVectorize/struct-return.ll b/llvm/test/Transforms/LoopVectorize/struct-return.ll
new file mode 100644
index 00000000000000..cf87ad01fcfc84
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/struct-return.ll
@@ -0,0 +1,202 @@
+; RUN: opt < %s -passes=loop-vectorize,dce,instcombine -force-vector-width=2 -force-vector-interleave=1 -S | FileCheck %s
+; RUN: opt < %s -passes=loop-vectorize,dce,instcombine -force-vector-width=2 -force-vector-interleave=1 -pass-remarks='loop-vectorize' -disable-output -S 2>&1 | FileCheck %s --check-prefix=CHECK-REMARKS
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+
+; Tests basic vectorization of homogeneous struct literal returns.
+
+; CHECK-REMARKS-COUNT-3: remark: {{.*}} vectorized loop
+; CHECK-REMARKS-COUNT-2: remark: {{.*}} loop not vectorized: instruction return type cannot be vectorized
+; CHECK-REMARKS:         remark: {{.*}} loop not vectorized: call instruction cannot be vectorized
+
+define void @struct_return_f32_widen(ptr noalias %in, ptr noalias writeonly %out_a, ptr noalias writeonly %out_b) {
+; CHECK-LABEL: define void @struct_return_f32_widen
+; CHECK-SAME:  (ptr noalias [[IN:%.*]], ptr noalias writeonly [[OUT_A:%.*]], ptr noalias writeonly [[OUT_B:%.*]])
+; CHECK:       vector.body:
+; CHECK:         [[WIDE_CALL:%.*]] = call { <2 x float>, <2 x float> } @fixed_vec_foo(<2 x float> [[WIDE_LOAD:%.*]])
+; CHECK:         [[WIDE_A:%.*]] = extractvalue { <2 x float>, <2 x float> } [[WIDE_CALL]], 0
+; CHECK:         [[WIDE_B:%.*]] = extractvalue { <2 x float>, <2 x float> } [[WIDE_CALL]], 1
+; CHECK:         store <2 x float> [[WIDE_A]], ptr {{%.*}}, align 4
+; CHECK:         store <2 x float> [[WIDE_B]], ptr {{%.*}}, align 4
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds float, ptr %in, i64 %iv
+  %in_val = load float, ptr %arrayidx, align 4
+  %call = tail call { float, float } @foo(float %in_val) #0  ; call under test: attribute #0 names @fixed_vec_foo as the vector variant of @foo
+  %extract_a = extractvalue { float, float } %call, 0  ; field 0 is stored to %out_a
+  %extract_b = extractvalue { float, float } %call, 1  ; field 1 is stored to %out_b
+  %arrayidx2 = getelementptr inbounds float, ptr %out_a, i64 %iv
+  store float %extract_a, ptr %arrayidx2, align 4
+  %arrayidx4 = getelementptr inbounds float, ptr %out_b, i64 %iv
+  store float %extract_b, ptr %arrayidx4, align 4
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+define void @struct_return_f64_widen(ptr noalias %in, ptr noalias writeonly %out_a, ptr noalias writeonly %out_b) {
+; CHECK-LABEL: define void @struct_return_f64_widen
+; CHECK-SAME:  (ptr noalias [[IN:%.*]], ptr noalias writeonly [[OUT_A:%.*]], ptr noalias writeonly [[OUT_B:%.*]])
+; CHECK:        vector.body:
+; CHECK:          [[WIDE_CALL:%.*]] = call { <2 x double>, <2 x double> } @fixed_vec_bar(<2 x double> [[WIDE_LOAD:%.*]])
+; CHECK:          [[WIDE_A:%.*]] = extractvalue { <2 x double>, <2 x double> } [[WIDE_CALL]], 0
+; CHECK:          [[WIDE_B:%.*]] = extractvalue { <2 x double>, <2 x double> } [[WIDE_CALL]], 1
+; CHECK:          store <2 x double> [[WIDE_A]], ptr {{%.*}}, align 8
+; CHECK:          store <2 x double> [[WIDE_B]], ptr {{%.*}}, align 8
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds double, ptr %in, i64 %iv
+  %in_val = load double, ptr %arrayidx, align 8
+  %call = tail call { double, double } @bar(double %in_val) #1  ; call under test: attribute #1 names @fixed_vec_bar as the vector variant of @bar
+  %extract_a = extractvalue { double, double } %call, 0  ; field 0 is stored to %out_a
+  %extract_b = extractvalue { double, double } %call, 1  ; field 1 is stored to %out_b
+  %arrayidx2 = getelementptr inbounds double, ptr %out_a, i64 %iv
+  store double %extract_a, ptr %arrayidx2, align 8
+  %arrayidx4 = getelementptr inbounds double, ptr %out_b, i64 %iv
+  store double %extract_b, ptr %arrayidx4, align 8
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+define void @struct_return_f32_widen_rt_checks(ptr %in, ptr writeonly %out_a, ptr writeonly %out_b) {
+; CHECK-LABEL: define void @struct_return_f32_widen_rt_checks
+; CHECK-SAME:  (ptr [[IN:%.*]], ptr writeonly [[OUT_A:%.*]], ptr writeonly [[OUT_B:%.*]])
+; CHECK:       entry:
+; CHECK:         br i1 false, label %scalar.ph, label %vector.memcheck
+; CHECK:       vector.memcheck:
+; CHECK:       vector.body:
+; CHECK:         call { <2 x float>, <2 x float> } @fixed_vec_foo(<2 x float> [[WIDE_LOAD:%.*]])
+; CHECK:       for.body:
+; CHECK:         call { float, float } @foo(float [[LOAD:%.*]])
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds float, ptr %in, i64 %iv
+  %in_val = load float, ptr %arrayidx, align 4
+  %call = tail call { float, float } @foo(float %in_val) #0  ; pointer args lack noalias, so the vector loop is guarded by vector.memcheck
+  %extract_a = extractvalue { float, float } %call, 0
+  %extract_b = extractvalue { float, float } %call, 1
+  %arrayidx2 = getelementptr inbounds float, ptr %out_a, i64 %iv
+  store float %extract_a, ptr %arrayidx2, align 4
+  %arrayidx4 = getelementptr inbounds float, ptr %out_b, i64 %iv
+  store float %extract_b, ptr %arrayidx4, align 4
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+; Negative test. Widening struct returns with mixed element types is not supported.
+define void @negative_mixed_element_type_struct_return(ptr noalias %in, ptr noalias writeonly %out_a, ptr noalias writeonly %out_b) {
+; CHECK-LABEL: define void @negative_mixed_element_type_struct_return
+; CHECK-NOT:   vector.body:
+; CHECK-NOT:   call {{.*}} @fixed_vec_baz
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds float, ptr %in, i64 %iv
+  %in_val = load float, ptr %arrayidx, align 4
+  %call = tail call { float, i32 } @baz(float %in_val) #2  ; fields are float and i32, so this call must stay scalar despite the #2 variant
+  %extract_a = extractvalue { float, i32 } %call, 0
+  %extract_b = extractvalue { float, i32 } %call, 1
+  %arrayidx2 = getelementptr inbounds float, ptr %out_a, i64 %iv
+  store float %extract_a, ptr %arrayidx2, align 4
+  %arrayidx4 = getelementptr inbounds i32, ptr %out_b, i64 %iv
+  store i32 %extract_b, ptr %arrayidx4, align 4
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+%named_struct = type { double, double }  ; identified struct with the same layout as the literal { double, double } returned by @bar
+
+; Negative test. Widening identified (non-literal) struct returns is not supported.
+define void @test_named_struct_return(ptr noalias readonly %in, ptr noalias writeonly %out_a, ptr noalias writeonly %out_b) {
+; CHECK-LABEL: define void @test_named_struct_return
+; CHECK-NOT:   vector.body:
+; CHECK-NOT:   call {{.*}} @fixed_vec_bar
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds double, ptr %in, i64 %iv
+  %in_val = load double, ptr %arrayidx, align 8
+  %call = tail call %named_struct @bar_named(double %in_val) #3  ; returns %named_struct, so it must stay scalar even though #3 names a variant
+  %extract_a = extractvalue %named_struct %call, 0
+  %extract_b = extractvalue %named_struct %call, 1
+  %arrayidx2 = getelementptr inbounds double, ptr %out_a, i64 %iv
+  store double %extract_a, ptr %arrayidx2, align 8
+  %arrayidx4 = getelementptr inbounds double, ptr %out_b, i64 %iv
+  store double %extract_b, ptr %arrayidx4, align 8
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+; TODO: Allow mixed-struct type vectorization and mark overflow intrinsics as trivially vectorizable.
+define void @test_overflow_intrinsic(ptr noalias readonly %in, ptr noalias writeonly %out_a, ptr noalias writeonly %out_b) {
+; CHECK-LABEL: define void @test_overflow_intrinsic
+; CHECK-NOT:   vector.body:
+; CHECK-NOT:   @llvm.sadd.with.overflow.v{{.+}}i32
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds i32, ptr %in, i64 %iv
+  %in_val = load i32, ptr %arrayidx, align 4
+  %call = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %in_val, i32 %in_val)  ; { i32, i1 } has mixed element types (see TODO above)
+  %extract_ret = extractvalue { i32, i1 } %call, 0
+  %extract_overflow = extractvalue { i32, i1 } %call, 1
+  %zext_overflow = zext i1 %extract_overflow to i8
+  %arrayidx2 = getelementptr inbounds i32, ptr %out_a, i64 %iv
+  store i32 %extract_ret, ptr %arrayidx2, align 4
+  %arrayidx4 = getelementptr inbounds i8, ptr %out_b, i64 %iv
+  store i8 %zext_overflow, ptr %arrayidx4, align 1  ; byte-strided address: only 1-byte alignment holds on every iteration
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+declare { float, float } @foo(float)
+declare { double, double } @bar(double)
+declare { float, i32 } @baz(float)  ; mixed element types (negative test)
+declare %named_struct @bar_named(double)  ; identified-struct return (negative test)
+
+declare { <2 x float>, <2 x float> } @fixed_vec_foo(<2 x float>)
+declare { <2 x double>, <2 x double> } @fixed_vec_bar(<2 x double>)
+declare { <2 x float>, <2 x i32> } @fixed_vec_baz(<2 x float>)
+
+attributes #0 = { nounwind "vector-function-abi-variant"="_ZGVnN2v_foo(fixed_vec_foo)" }  ; VFABI: AdvSIMD (n), unmasked (N), VF=2, one vector parameter (v)
+attributes #1 = { nounwind "vector-function-abi-variant"="_ZGVnN2v_bar(fixed_vec_bar)" }
+attributes #2 = { nounwind "vector-function-abi-variant"="_ZGVnN2v_baz(fixed_vec_baz)" }
+attributes #3 = { nounwind "vector-function-abi-variant"="_ZGVnN2v_bar_named(fixed_vec_bar)" }  ; reuses @fixed_vec_bar, whose return type is a literal struct
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-widen-struct-return.ll b/llvm/test/Transforms/LoopVectorize/vplan-widen-struct-return.ll
new file mode 100644
index 00000000000000..7e6d341fd55691
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/vplan-widen-struct-return.ll
@@ -0,0 +1,58 @@
+; REQUIRES: asserts
+; RUN: opt < %s -passes=loop-vectorize,dce,instcombine -force-vector-width=2 -force-vector-interleave=1 -debug-only=loop-vectorize -disable-output -S 2>&1 | FileCheck %s
+
+define void @struct_return_f32_widen(ptr noalias %in, ptr noalias writeonly %out_a, ptr noalias writeonly %out_b) {
+; CHECK-LABEL: LV: Checking a loop in 'struct_return_f32_widen'
+; CHECK:       VPlan 'Initial VPlan for VF={2},UF>=1' {
+; CHECK-NEXT:  Live-in vp<%0> = VF * UF
+; CHECK-NEXT:  Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT:  Live-in ir<1024> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT:  vector.ph:
+; CHECK-NEXT:  Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  <x1> vector loop: {
+; CHECK-NEXT:    vector.body:
+; CHECK-NEXT:      EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%7>
+; CHECK-NEXT:      vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:      CLONE ir<%arrayidx> = getelementptr inbounds ir<%in>, vp<%3>
+; CHECK-NEXT:      vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:      WIDEN ir<%in_val> = load vp<%4>
+; CHECK-NEXT:      WIDEN-CALL ir<%call>, ir<%call>.1 = call @foo(ir<%in_val>) (using library function: fixed_vec_foo)
+; CHECK-NEXT:      CLONE ir<%arrayidx2> = getelementptr inbounds ir<%out_a>, vp<%3>
+; CHECK-NEXT:      vp<%5> = vector-pointer ir<%arrayidx2>
+; CHECK-NEXT:      WIDEN store vp<%5>, ir<%call>
+; CHECK-NEXT:      CLONE ir<%arrayidx4> = getelementptr inbounds ir<%out_b>, vp<%3>
+; CHECK-NEXT:      vp<%6> = vector-pointer ir<%arrayidx4>
+; CHECK-NEXT:      WIDEN store vp<%6>, ir<%call>.1
+; CHECK-NEXT:      EMIT vp<%7> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:      EMIT branch-on-count vp<%7>, vp<%1>
+; CHECK-NEXT:    No successors
+; CHECK-NEXT:  }
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds float, ptr %in, i64 %iv
+  %in_val = load float, ptr %arrayidx, align 4
+  %call = tail call { float, float } @foo(float %in_val) #0  ; becomes one widened call recipe defining one VPValue per struct field
+  %extract_a = extractvalue { float, float } %call, 0
+  %extract_b = extractvalue { float, float } %call, 1
+  %arrayidx2 = getelementptr inbounds float, ptr %out_a, i64 %iv
+  store float %extract_a, ptr %arrayidx2, align 4
+  %arrayidx4 = getelementptr inbounds float, ptr %out_b, i64 %iv
+  store float %extract_b, ptr %arrayidx4, align 4
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+declare { float, float } @foo(float)
+
+declare { <2 x float>, <2 x float> } @fixed_vec_foo(<2 x float>)
+
+attributes #0 = { nounwind "vector-function-abi-variant"="_ZGVnN2v_foo(fixed_vec_foo)" }  ; VFABI: AdvSIMD (n), unmasked (N), VF=2, one vector parameter (v)



More information about the llvm-commits mailing list