[llvm] [VPlan] Add initial analysis to infer scalar type of VPValues. (PR #69013)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 03:39:21 PDT 2023


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/69013

>From 22b23c7de44939e470a47be8af077a7c4516b3c5 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Oct 2023 20:57:17 +0100
Subject: [PATCH 01/12] [VPlan] Add initial analysis to infer scalar type of
 VPValues.

This patch adds initial type inference for VPValues. It infers the
scalar type of a VPValue by traversing bottom-up through its defining
recipes until root nodes with known types are reached (e.g. live-ins or
memory recipes). The types are then propagated top-down through operations.

This is intended as a building block for a VPlan-based cost model, which
will need access to type information for VPValues/recipes.

Initial testing is done by asserting that the inferred type matches the
type of the result value generated for a widen recipe.
---
 llvm/lib/Transforms/Vectorize/CMakeLists.txt  |   1 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |   8 +-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 225 ++++++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |  56 +++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  12 +
 5 files changed, 299 insertions(+), 3 deletions(-)
 create mode 100644 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
 create mode 100644 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h

diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 998dfd956575d3c..9674094024b9ec7 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMVectorize
   Vectorize.cpp
   VectorCombine.cpp
   VPlan.cpp
+  VPlanAnalysis.cpp
   VPlanHCFGBuilder.cpp
   VPlanRecipes.cpp
   VPlanSLP.cpp
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e65a7ab2cd028ee..ea1f8a5b9d1e9ab 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1167,6 +1167,8 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue {
   /// Produce widened copies of all Ingredients.
   void execute(VPTransformState &State) override;
 
+  unsigned getOpcode() const { return Opcode; }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
@@ -1458,7 +1460,7 @@ class VPWidenIntOrFpInductionRecipe : public VPHeaderPHIRecipe {
   bool isCanonical() const;
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return Trunc ? Trunc->getType() : IV->getType();
   }
 };
@@ -2080,7 +2082,7 @@ class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe {
 #endif
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return getOperand(0)->getLiveInIRValue()->getType();
   }
 
@@ -2149,7 +2151,7 @@ class VPWidenCanonicalIVRecipe : public VPRecipeBase, public VPValue {
 #endif
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return cast<VPCanonicalIVPHIRecipe>(getOperand(0)->getDefiningRecipe())
         ->getScalarType();
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
new file mode 100644
index 000000000000000..088da81f950425c
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -0,0 +1,225 @@
+//===- VPlanAnalysis.cpp - Various Analyses working on VPlan ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "VPlanAnalysis.h"
+#include "VPlan.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "vplan"
+
+Type *VPTypeAnalysis::inferType(const VPBlendRecipe *R) {
+  return inferType(R->getIncomingValue(0));
+}
+
+Type *VPTypeAnalysis::inferType(const VPInstruction *R) {
+  switch (R->getOpcode()) {
+  case Instruction::Select:
+    return inferType(R->getOperand(1));
+  case VPInstruction::FirstOrderRecurrenceSplice:
+    return inferType(R->getOperand(0));
+  default:
+    llvm_unreachable("Unhandled instruction!");
+  }
+}
+
+Type *VPTypeAnalysis::inferType(const VPInterleaveRecipe *R) { return nullptr; }
+
+Type *VPTypeAnalysis::inferType(const VPReductionPHIRecipe *R) {
+  return R->getOperand(0)->getLiveInIRValue()->getType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenRecipe *R) {
+  unsigned Opcode = R->getOpcode();
+  switch (Opcode) {
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return IntegerType::get(Ctx, 1);
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::SRem:
+  case Instruction::URem:
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::FNeg:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor: {
+    Type *ResTy = inferType(R->getOperand(0));
+    if (Opcode != Instruction::FNeg) {
+      assert(ResTy == inferType(R->getOperand(1)));
+      CachedTypes[R->getOperand(1)] = ResTy;
+    }
+    return ResTy;
+  }
+  case Instruction::Freeze:
+    return inferType(R->getOperand(0));
+  default:
+    // This instruction is not vectorized by simple widening.
+    //    LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I);
+    llvm_unreachable("Unhandled instruction!");
+  }
+
+  return nullptr;
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenCallRecipe *R) {
+  auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
+  return CI.getType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenIntOrFpInductionRecipe *R) {
+  return R->getScalarType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenMemoryInstructionRecipe *R) {
+  if (R->isStore())
+    return cast<StoreInst>(&R->getIngredient())->getValueOperand()->getType();
+
+  return cast<LoadInst>(&R->getIngredient())->getType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenSelectRecipe *R) {
+  return inferType(R->getOperand(1));
+}
+
+Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
+  switch (R->getUnderlyingInstr()->getOpcode()) {
+  case Instruction::Call: {
+    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+        ->getReturnType();
+  }
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::SRem:
+  case Instruction::URem:
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::FNeg:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::ICmp:
+  case Instruction::FCmp: {
+    Type *ResTy = inferType(R->getOperand(0));
+    assert(ResTy == inferType(R->getOperand(1)));
+    CachedTypes[R->getOperand(1)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+  case Instruction::FPExt:
+  case Instruction::FPTrunc:
+    return R->getUnderlyingInstr()->getType();
+  case Instruction::ExtractValue: {
+    return R->getUnderlyingValue()->getType();
+  }
+  case Instruction::Freeze:
+    return inferType(R->getOperand(0));
+  case Instruction::Load:
+    return cast<LoadInst>(R->getUnderlyingInstr())->getType();
+  case Instruction::Store:
+    return cast<StoreInst>(R->getUnderlyingInstr())
+        ->getValueOperand()
+        ->getType();
+  default:
+    llvm_unreachable("Unhandled instruction");
+  }
+
+  return nullptr;
+}
+
+Type *VPTypeAnalysis::inferType(const VPValue *V) {
+  auto Iter = CachedTypes.find(V);
+  if (Iter != CachedTypes.end())
+    return Iter->second;
+
+  Type *ResultTy = nullptr;
+  if (V->isLiveIn())
+    ResultTy = V->getLiveInIRValue()->getType();
+  else {
+    const VPRecipeBase *Def = V->getDefiningRecipe();
+    switch (Def->getVPDefID()) {
+    case VPDef::VPBlendSC:
+      ResultTy = inferType(cast<VPBlendRecipe>(Def));
+      break;
+    case VPDef::VPCanonicalIVPHISC:
+      ResultTy = cast<VPCanonicalIVPHIRecipe>(Def)->getScalarType();
+      break;
+    case VPDef::VPFirstOrderRecurrencePHISC:
+      ResultTy = Def->getOperand(0)->getLiveInIRValue()->getType();
+      break;
+    case VPDef::VPInstructionSC:
+      ResultTy = inferType(cast<VPInstruction>(Def));
+      break;
+    case VPDef::VPInterleaveSC:
+      ResultTy = V->getUnderlyingValue()
+                     ->getType(); // inferType(cast<VPInterleaveRecipe>(Def));
+      break;
+    case VPDef::VPPredInstPHISC:
+      ResultTy = inferType(Def->getOperand(0));
+      break;
+    case VPDef::VPReductionPHISC:
+      ResultTy = inferType(cast<VPReductionPHIRecipe>(Def));
+      break;
+    case VPDef::VPReplicateSC:
+      ResultTy = inferType(cast<VPReplicateRecipe>(Def));
+      break;
+    case VPDef::VPScalarIVStepsSC:
+      return inferType(Def->getOperand(0));
+      break;
+    case VPDef::VPWidenSC:
+      ResultTy = inferType(cast<VPWidenRecipe>(Def));
+      break;
+    case VPDef::VPWidenPHISC:
+      return inferType(Def->getOperand(0));
+    case VPDef::VPWidenPointerInductionSC:
+      return inferType(Def->getOperand(0));
+    case VPDef::VPWidenCallSC:
+      ResultTy = inferType(cast<VPWidenCallRecipe>(Def));
+      break;
+    case VPDef::VPWidenCastSC:
+      ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
+      break;
+    case VPDef::VPWidenGEPSC:
+      ResultTy = PointerType::get(Ctx, 0);
+      break;
+    case VPDef::VPWidenIntOrFpInductionSC:
+      ResultTy = inferType(cast<VPWidenIntOrFpInductionRecipe>(Def));
+      break;
+    case VPDef::VPWidenMemoryInstructionSC:
+      ResultTy = inferType(cast<VPWidenMemoryInstructionRecipe>(Def));
+      break;
+    case VPDef::VPWidenSelectSC:
+      ResultTy = inferType(cast<VPWidenSelectRecipe>(Def));
+      break;
+    }
+  }
+  CachedTypes[V] = ResultTy;
+  return ResultTy;
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
new file mode 100644
index 000000000000000..8fcbe9ca99bb4d5
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -0,0 +1,56 @@
+//===- VPlanAnalysis.h - Various Analyses working on VPlan ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace llvm {
+
+class LLVMContext;
+class VPValue;
+class VPBlendRecipe;
+class VPInterleaveRecipe;
+class VPInstruction;
+class VPReductionPHIRecipe;
+class VPWidenRecipe;
+class VPWidenCallRecipe;
+class VPWidenCastRecipe;
+class VPWidenIntOrFpInductionRecipe;
+class VPWidenMemoryInstructionRecipe;
+struct VPWidenSelectRecipe;
+class VPReplicateRecipe;
+class Type;
+
+/// An analysis for type-inferrence for VPValues.
+class VPTypeAnalysis {
+  DenseMap<const VPValue *, Type *> CachedTypes;
+  LLVMContext &Ctx;
+
+  Type *inferType(const VPBlendRecipe *R);
+  Type *inferType(const VPInstruction *R);
+  Type *inferType(const VPInterleaveRecipe *R);
+  Type *inferType(const VPWidenCallRecipe *R);
+  Type *inferType(const VPReductionPHIRecipe *R);
+  Type *inferType(const VPWidenRecipe *R);
+  Type *inferType(const VPWidenIntOrFpInductionRecipe *R);
+  Type *inferType(const VPWidenMemoryInstructionRecipe *R);
+  Type *inferType(const VPWidenSelectRecipe *R);
+  Type *inferType(const VPReplicateRecipe *R);
+
+public:
+  VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {}
+
+  /// Infer the type of \p V. Returns the scalar type of \p V.
+  Type *inferType(const VPValue *V);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2a1213a98095907..b616abddb00f99a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "VPlan.h"
+#include "VPlanAnalysis.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
@@ -738,7 +739,18 @@ void VPWidenRecipe::execute(VPTransformState &State) {
                       << Instruction::getOpcodeName(Opcode));
     llvm_unreachable("Unhandled instruction!");
   } // end of switch.
+
+#if !defined(NDEBUG)
+  // Verify that VPlan type infererrence results agree with the type of the
+  // generated values.
+  VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
+  for (unsigned Part = 0; Part < State.UF; ++Part) {
+    assert(VectorType::get(A.inferType(getVPSingleValue()), State.VF) ==
+           State.get(this, Part)->getType());
+  }
+#endif
 }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
                           VPSlotTracker &SlotTracker) const {

>From 30a8968d487376efc90c3ec4cfae91d055ab3928 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 20 Oct 2023 12:03:59 +0100
Subject: [PATCH 02/12] Remove Store case for VPReplicateRecipe, as it is dead
 code.

---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 088da81f950425c..8c27787c425cf75 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -143,10 +143,6 @@ Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
     return inferType(R->getOperand(0));
   case Instruction::Load:
     return cast<LoadInst>(R->getUnderlyingInstr())->getType();
-  case Instruction::Store:
-    return cast<StoreInst>(R->getUnderlyingInstr())
-        ->getValueOperand()
-        ->getType();
   default:
     llvm_unreachable("Unhandled instruction");
   }

>From 58b145e92d8c018f86ed54cae5b5211e71c89441 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 23 Oct 2023 17:19:37 +0100
Subject: [PATCH 03/12] Update llvm/lib/Transforms/Vectorize/VPlanAnalysis.h

Co-authored-by: ayalz <47719489+ayalz at users.noreply.github.com>
---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 8fcbe9ca99bb4d5..277cf271d51d041 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -28,7 +28,7 @@ struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
 class Type;
 
-/// An analysis for type-inferrence for VPValues.
+/// An analysis for type-inference for VPValues.
 class VPTypeAnalysis {
   DenseMap<const VPValue *, Type *> CachedTypes;
   LLVMContext &Ctx;

>From d89b7d890a4724b940dde7536e75b4d4beef0475 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 23 Oct 2023 17:20:12 +0100
Subject: [PATCH 04/12] Update llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Co-authored-by: ayalz <47719489+ayalz at users.noreply.github.com>
---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b616abddb00f99a..178e1571960a840 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -741,7 +741,7 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   } // end of switch.
 
 #if !defined(NDEBUG)
-  // Verify that VPlan type infererrence results agree with the type of the
+  // Verify that VPlan type inference results agree with the type of the
   // generated values.
   VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
   for (unsigned Part = 0; Part < State.UF; ++Part) {

>From 4ed3b6dfe165aec75f4da8735d392c7aa7ee0822 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 23 Oct 2023 17:24:43 +0100
Subject: [PATCH 05/12] Address comments

---
 llvm/lib/Transforms/Vectorize/VPlan.h         |   4 +-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 174 +++++++++---------
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |  24 +--
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   2 +-
 4 files changed, 105 insertions(+), 99 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index ea1f8a5b9d1e9ab..653290a036dd235 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2083,7 +2083,7 @@ class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe {
 
   /// Returns the scalar type of the induction.
   Type *getScalarType() const {
-    return getOperand(0)->getLiveInIRValue()->getType();
+    return getStartValue()->getLiveInIRValue()->getType();
   }
 
   /// Returns true if the recipe only uses the first lane of operand \p Op.
@@ -2151,7 +2151,7 @@ class VPWidenCanonicalIVRecipe : public VPRecipeBase, public VPValue {
 #endif
 
   /// Returns the scalar type of the induction.
-  Type *getScalarType() const {
+  const Type *getScalarType() const {
     return cast<VPCanonicalIVPHIRecipe>(getOperand(0)->getDefiningRecipe())
         ->getScalarType();
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8c27787c425cf75..f2e73ca4f63a7eb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -13,28 +13,35 @@ using namespace llvm;
 
 #define DEBUG_TYPE "vplan"
 
-Type *VPTypeAnalysis::inferType(const VPBlendRecipe *R) {
-  return inferType(R->getIncomingValue(0));
+Type *VPTypeAnalysis::inferScalarType(const VPBlendRecipe *R) {
+  Type *ResTy = inferScalarType(R->getIncomingValue(0));
+  for (unsigned I = 1, E = R->getNumIncomingValues(); I != E; ++I) {
+    VPValue *Inc = R->getIncomingValue(I);
+    assert(inferScalarType(Inc) == ResTy &&
+           "different types inferred for different incoming values");
+    CachedTypes[Inc] = ResTy;
+  }
+  return inferScalarType(R->getIncomingValue(0));
 }
 
-Type *VPTypeAnalysis::inferType(const VPInstruction *R) {
+Type *VPTypeAnalysis::inferScalarType(const VPInstruction *R) {
   switch (R->getOpcode()) {
-  case Instruction::Select:
-    return inferType(R->getOperand(1));
+  case Instruction::Select: {
+    Type *ResTy = inferScalarType(R->getOperand(1));
+    VPValue *OtherV = R->getOperand(2);
+    assert(inferScalarType(OtherV) == ResTy &&
+           "different types inferred for different operands");
+    CachedTypes[OtherV] = ResTy;
+    return ResTy;
+  }
   case VPInstruction::FirstOrderRecurrenceSplice:
-    return inferType(R->getOperand(0));
+    return inferScalarType(R->getOperand(0));
   default:
     llvm_unreachable("Unhandled instruction!");
   }
 }
 
-Type *VPTypeAnalysis::inferType(const VPInterleaveRecipe *R) { return nullptr; }
-
-Type *VPTypeAnalysis::inferType(const VPReductionPHIRecipe *R) {
-  return R->getOperand(0)->getLiveInIRValue()->getType();
-}
-
-Type *VPTypeAnalysis::inferType(const VPWidenRecipe *R) {
+Type *VPTypeAnalysis::inferScalarType(const VPWidenRecipe *R) {
   unsigned Opcode = R->getOpcode();
   switch (Opcode) {
   case Instruction::ICmp:
@@ -48,7 +55,6 @@ Type *VPTypeAnalysis::inferType(const VPWidenRecipe *R) {
   case Instruction::FAdd:
   case Instruction::Sub:
   case Instruction::FSub:
-  case Instruction::FNeg:
   case Instruction::Mul:
   case Instruction::FMul:
   case Instruction::FDiv:
@@ -59,45 +65,47 @@ Type *VPTypeAnalysis::inferType(const VPWidenRecipe *R) {
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
-    Type *ResTy = inferType(R->getOperand(0));
-    if (Opcode != Instruction::FNeg) {
-      assert(ResTy == inferType(R->getOperand(1)));
-      CachedTypes[R->getOperand(1)] = ResTy;
-    }
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    assert(ResTy == inferScalarType(R->getOperand(1)) &&
+           "types for both operands must match for binary op");
+    CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
+  case Instruction::FNeg:
   case Instruction::Freeze:
-    return inferType(R->getOperand(0));
+    return inferScalarType(R->getOperand(0));
   default:
-    // This instruction is not vectorized by simple widening.
-    //    LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I);
-    llvm_unreachable("Unhandled instruction!");
+    break;
   }
 
-  return nullptr;
+  // Type inferrence not implemented for opcode.
+  LLVM_DEBUG(dbgs() << "LV: Found unhandled opcode: "
+                    << Instruction::getOpcodeName(Opcode));
+  llvm_unreachable("Unhandled opcode!");
 }
 
-Type *VPTypeAnalysis::inferType(const VPWidenCallRecipe *R) {
+Type *VPTypeAnalysis::inferScalarType(const VPWidenCallRecipe *R) {
   auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
   return CI.getType();
 }
 
-Type *VPTypeAnalysis::inferType(const VPWidenIntOrFpInductionRecipe *R) {
-  return R->getScalarType();
-}
-
-Type *VPTypeAnalysis::inferType(const VPWidenMemoryInstructionRecipe *R) {
+Type *VPTypeAnalysis::inferScalarType(const VPWidenMemoryInstructionRecipe *R) {
   if (R->isStore())
     return cast<StoreInst>(&R->getIngredient())->getValueOperand()->getType();
 
   return cast<LoadInst>(&R->getIngredient())->getType();
 }
 
-Type *VPTypeAnalysis::inferType(const VPWidenSelectRecipe *R) {
-  return inferType(R->getOperand(1));
+Type *VPTypeAnalysis::inferScalarType(const VPWidenSelectRecipe *R) {
+  Type *ResTy = inferScalarType(R->getOperand(1));
+  VPValue *OtherV = R->getOperand(2);
+  assert(inferScalarType(OtherV) == ResTy &&
+         "different types inferred for different operands");
+  CachedTypes[OtherV] = ResTy;
+  return ResTy;
 }
 
-Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
+Type *VPTypeAnalysis::inferScalarType(const VPReplicateRecipe *R) {
   switch (R->getUnderlyingInstr()->getOpcode()) {
   case Instruction::Call: {
     unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
@@ -112,7 +120,6 @@ Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
   case Instruction::FAdd:
   case Instruction::Sub:
   case Instruction::FSub:
-  case Instruction::FNeg:
   case Instruction::Mul:
   case Instruction::FMul:
   case Instruction::FDiv:
@@ -125,8 +132,8 @@ Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
   case Instruction::Xor:
   case Instruction::ICmp:
   case Instruction::FCmp: {
-    Type *ResTy = inferType(R->getOperand(0));
-    assert(ResTy == inferType(R->getOperand(1)));
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    assert(ResTy == inferScalarType(R->getOperand(1)));
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
@@ -137,85 +144,82 @@ Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
   case Instruction::FPTrunc:
     return R->getUnderlyingInstr()->getType();
   case Instruction::ExtractValue: {
-    return R->getUnderlyingValue()->getType();
+    return R->getUnderlyingInstr()->getType();
   }
   case Instruction::Freeze:
-    return inferType(R->getOperand(0));
+  case Instruction::FNeg:
+    return inferScalarType(R->getOperand(0));
   case Instruction::Load:
     return cast<LoadInst>(R->getUnderlyingInstr())->getType();
   default:
-    llvm_unreachable("Unhandled instruction");
+    break;
   }
 
-  return nullptr;
+  llvm_unreachable("Unhandled instruction");
 }
 
-Type *VPTypeAnalysis::inferType(const VPValue *V) {
-  auto Iter = CachedTypes.find(V);
-  if (Iter != CachedTypes.end())
-    return Iter->second;
+Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
+  if (Type *CachedTy = CachedTypes.lookup(V))
+    return CachedTy;
 
   Type *ResultTy = nullptr;
   if (V->isLiveIn())
-    ResultTy = V->getLiveInIRValue()->getType();
-  else {
+    return V->getLiveInIRValue()->getType();
+
     const VPRecipeBase *Def = V->getDefiningRecipe();
     switch (Def->getVPDefID()) {
-    case VPDef::VPBlendSC:
-      ResultTy = inferType(cast<VPBlendRecipe>(Def));
-      break;
     case VPDef::VPCanonicalIVPHISC:
-      ResultTy = cast<VPCanonicalIVPHIRecipe>(Def)->getScalarType();
-      break;
     case VPDef::VPFirstOrderRecurrencePHISC:
-      ResultTy = Def->getOperand(0)->getLiveInIRValue()->getType();
-      break;
+    case VPDef::VPReductionPHISC:
+    case VPDef::VPWidenPointerInductionSC:
+    // Handle header phi recipes, except VPWienIntOrFpInduction which needs
+    // special handling due it being possibly truncated.
+    ResultTy = cast<VPHeaderPHIRecipe>(Def)
+                   ->getStartValue()
+                   ->getLiveInIRValue()
+                   ->getType();
+    break;
+    case VPDef::VPWidenIntOrFpInductionSC:
+    ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
+    break;
+    case VPDef::VPPredInstPHISC:
+    case VPDef::VPScalarIVStepsSC:
+    case VPDef::VPWidenPHISC:
+    ResultTy = inferScalarType(Def->getOperand(0));
+    break;
+    case VPDef::VPBlendSC:
+    ResultTy = inferScalarType(cast<VPBlendRecipe>(Def));
+    break;
     case VPDef::VPInstructionSC:
-      ResultTy = inferType(cast<VPInstruction>(Def));
-      break;
+    ResultTy = inferScalarType(cast<VPInstruction>(Def));
+    break;
     case VPDef::VPInterleaveSC:
-      ResultTy = V->getUnderlyingValue()
-                     ->getType(); // inferType(cast<VPInterleaveRecipe>(Def));
-      break;
-    case VPDef::VPPredInstPHISC:
-      ResultTy = inferType(Def->getOperand(0));
-      break;
-    case VPDef::VPReductionPHISC:
-      ResultTy = inferType(cast<VPReductionPHIRecipe>(Def));
-      break;
+    // TODO: Use info from interleave group.
+    ResultTy = V->getUnderlyingValue()->getType();
+    break;
     case VPDef::VPReplicateSC:
-      ResultTy = inferType(cast<VPReplicateRecipe>(Def));
-      break;
-    case VPDef::VPScalarIVStepsSC:
-      return inferType(Def->getOperand(0));
-      break;
+    ResultTy = inferScalarType(cast<VPReplicateRecipe>(Def));
+    break;
     case VPDef::VPWidenSC:
-      ResultTy = inferType(cast<VPWidenRecipe>(Def));
-      break;
-    case VPDef::VPWidenPHISC:
-      return inferType(Def->getOperand(0));
-    case VPDef::VPWidenPointerInductionSC:
-      return inferType(Def->getOperand(0));
+    ResultTy = inferScalarType(cast<VPWidenRecipe>(Def));
+    break;
     case VPDef::VPWidenCallSC:
-      ResultTy = inferType(cast<VPWidenCallRecipe>(Def));
-      break;
+    ResultTy = inferScalarType(cast<VPWidenCallRecipe>(Def));
+    break;
     case VPDef::VPWidenCastSC:
       ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
       break;
     case VPDef::VPWidenGEPSC:
       ResultTy = PointerType::get(Ctx, 0);
       break;
-    case VPDef::VPWidenIntOrFpInductionSC:
-      ResultTy = inferType(cast<VPWidenIntOrFpInductionRecipe>(Def));
-      break;
     case VPDef::VPWidenMemoryInstructionSC:
-      ResultTy = inferType(cast<VPWidenMemoryInstructionRecipe>(Def));
+      ResultTy = inferScalarType(cast<VPWidenMemoryInstructionRecipe>(Def));
       break;
     case VPDef::VPWidenSelectSC:
-      ResultTy = inferType(cast<VPWidenSelectRecipe>(Def));
+      ResultTy = inferScalarType(cast<VPWidenSelectRecipe>(Def));
       break;
     }
-  }
-  CachedTypes[V] = ResultTy;
-  return ResultTy;
+    assert(ResultTy && "could not infer type for the given VPValue");
+    CachedTypes[V] = ResultTy;
+    return ResultTy;
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 277cf271d51d041..8f1cdf00f5cc0b6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -29,26 +29,28 @@ class VPReplicateRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
+/// It infers the scalar type for a given VPValue by bottom-up traversing
+/// through defining recipes until root nodes with known types are reached (e.g.
+/// live-ins or memory recipes). The types are then propagated top down through
+/// operations.
 class VPTypeAnalysis {
   DenseMap<const VPValue *, Type *> CachedTypes;
   LLVMContext &Ctx;
 
-  Type *inferType(const VPBlendRecipe *R);
-  Type *inferType(const VPInstruction *R);
-  Type *inferType(const VPInterleaveRecipe *R);
-  Type *inferType(const VPWidenCallRecipe *R);
-  Type *inferType(const VPReductionPHIRecipe *R);
-  Type *inferType(const VPWidenRecipe *R);
-  Type *inferType(const VPWidenIntOrFpInductionRecipe *R);
-  Type *inferType(const VPWidenMemoryInstructionRecipe *R);
-  Type *inferType(const VPWidenSelectRecipe *R);
-  Type *inferType(const VPReplicateRecipe *R);
+  Type *inferScalarType(const VPBlendRecipe *R);
+  Type *inferScalarType(const VPInstruction *R);
+  Type *inferScalarType(const VPWidenCallRecipe *R);
+  Type *inferScalarType(const VPWidenRecipe *R);
+  Type *inferScalarType(const VPWidenIntOrFpInductionRecipe *R);
+  Type *inferScalarType(const VPWidenMemoryInstructionRecipe *R);
+  Type *inferScalarType(const VPWidenSelectRecipe *R);
+  Type *inferScalarType(const VPReplicateRecipe *R);
 
 public:
   VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {}
 
   /// Infer the type of \p V. Returns the scalar type of \p V.
-  Type *inferType(const VPValue *V);
+  Type *inferScalarType(const VPValue *V);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 178e1571960a840..12ab2e3e0df7020 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -745,7 +745,7 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   // generated values.
   VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
   for (unsigned Part = 0; Part < State.UF; ++Part) {
-    assert(VectorType::get(A.inferType(getVPSingleValue()), State.VF) ==
+    assert(VectorType::get(A.inferScalarType(getVPSingleValue()), State.VF) ==
            State.get(this, Part)->getType());
   }
 #endif

>From 529805c13b2dbd62fc860174c2a231b0c7a93f8b Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 23 Oct 2023 19:42:35 +0100
Subject: [PATCH 06/12] Fix formatting.

---
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 48 +++++++++----------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index f2e73ca4f63a7eb..1b73b270d0aa2cb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -172,40 +172,40 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
     case VPDef::VPFirstOrderRecurrencePHISC:
     case VPDef::VPReductionPHISC:
     case VPDef::VPWidenPointerInductionSC:
-    // Handle header phi recipes, except VPWienIntOrFpInduction which needs
-    // special handling due it being possibly truncated.
-    ResultTy = cast<VPHeaderPHIRecipe>(Def)
-                   ->getStartValue()
-                   ->getLiveInIRValue()
-                   ->getType();
-    break;
+      // Handle header phi recipes, except VPWienIntOrFpInduction which needs
+      // special handling due it being possibly truncated.
+      ResultTy = cast<VPHeaderPHIRecipe>(Def)
+                     ->getStartValue()
+                     ->getLiveInIRValue()
+                     ->getType();
+      break;
     case VPDef::VPWidenIntOrFpInductionSC:
-    ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
-    break;
+      ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
+      break;
     case VPDef::VPPredInstPHISC:
     case VPDef::VPScalarIVStepsSC:
     case VPDef::VPWidenPHISC:
-    ResultTy = inferScalarType(Def->getOperand(0));
-    break;
+      ResultTy = inferScalarType(Def->getOperand(0));
+      break;
     case VPDef::VPBlendSC:
-    ResultTy = inferScalarType(cast<VPBlendRecipe>(Def));
-    break;
+      ResultTy = inferScalarType(cast<VPBlendRecipe>(Def));
+      break;
     case VPDef::VPInstructionSC:
-    ResultTy = inferScalarType(cast<VPInstruction>(Def));
-    break;
+      ResultTy = inferScalarType(cast<VPInstruction>(Def));
+      break;
     case VPDef::VPInterleaveSC:
-    // TODO: Use info from interleave group.
-    ResultTy = V->getUnderlyingValue()->getType();
-    break;
+      // TODO: Use info from interleave group.
+      ResultTy = V->getUnderlyingValue()->getType();
+      break;
     case VPDef::VPReplicateSC:
-    ResultTy = inferScalarType(cast<VPReplicateRecipe>(Def));
-    break;
+      ResultTy = inferScalarType(cast<VPReplicateRecipe>(Def));
+      break;
     case VPDef::VPWidenSC:
-    ResultTy = inferScalarType(cast<VPWidenRecipe>(Def));
-    break;
+      ResultTy = inferScalarType(cast<VPWidenRecipe>(Def));
+      break;
     case VPDef::VPWidenCallSC:
-    ResultTy = inferScalarType(cast<VPWidenCallRecipe>(Def));
-    break;
+      ResultTy = inferScalarType(cast<VPWidenCallRecipe>(Def));
+      break;
     case VPDef::VPWidenCastSC:
       ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
       break;

>From a75c5624e7a6c60d2ca79023a9984889342b73bd Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 25 Oct 2023 13:04:40 +0100
Subject: [PATCH 07/12] Address comments

---
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 125 +++++++++---------
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |   3 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   3 +-
 3 files changed, 68 insertions(+), 63 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 1b73b270d0aa2cb..5d325fb70e1b3b8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -21,7 +21,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPBlendRecipe *R) {
            "different types inferred for different incoming values");
     CachedTypes[Inc] = ResTy;
   }
-  return inferScalarType(R->getIncomingValue(0));
+  return ResTy;
 }
 
 Type *VPTypeAnalysis::inferScalarType(const VPInstruction *R) {
@@ -37,8 +37,9 @@ Type *VPTypeAnalysis::inferScalarType(const VPInstruction *R) {
   case VPInstruction::FirstOrderRecurrenceSplice:
     return inferScalarType(R->getOperand(0));
   default:
-    llvm_unreachable("Unhandled instruction!");
+    break;
   }
+  llvm_unreachable("Unhandled opcode!");
 }
 
 Type *VPTypeAnalysis::inferScalarType(const VPWidenRecipe *R) {
@@ -133,7 +134,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPReplicateRecipe *R) {
   case Instruction::ICmp:
   case Instruction::FCmp: {
     Type *ResTy = inferScalarType(R->getOperand(0));
-    assert(ResTy == inferScalarType(R->getOperand(1)));
+    assert(ResTy == inferScalarType(R->getOperand(1)) &&
+           "inferred types for operands of binary op don't match");
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
@@ -142,10 +144,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPReplicateRecipe *R) {
   case Instruction::ZExt:
   case Instruction::FPExt:
   case Instruction::FPTrunc:
+  case Instruction::ExtractValue:
     return R->getUnderlyingInstr()->getType();
-  case Instruction::ExtractValue: {
-    return R->getUnderlyingInstr()->getType();
-  }
   case Instruction::Freeze:
   case Instruction::FNeg:
     return inferScalarType(R->getOperand(0));
@@ -166,60 +166,61 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (V->isLiveIn())
     return V->getLiveInIRValue()->getType();
 
-    const VPRecipeBase *Def = V->getDefiningRecipe();
-    switch (Def->getVPDefID()) {
-    case VPDef::VPCanonicalIVPHISC:
-    case VPDef::VPFirstOrderRecurrencePHISC:
-    case VPDef::VPReductionPHISC:
-    case VPDef::VPWidenPointerInductionSC:
-      // Handle header phi recipes, except VPWienIntOrFpInduction which needs
-      // special handling due it being possibly truncated.
-      ResultTy = cast<VPHeaderPHIRecipe>(Def)
-                     ->getStartValue()
-                     ->getLiveInIRValue()
-                     ->getType();
-      break;
-    case VPDef::VPWidenIntOrFpInductionSC:
-      ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
-      break;
-    case VPDef::VPPredInstPHISC:
-    case VPDef::VPScalarIVStepsSC:
-    case VPDef::VPWidenPHISC:
-      ResultTy = inferScalarType(Def->getOperand(0));
-      break;
-    case VPDef::VPBlendSC:
-      ResultTy = inferScalarType(cast<VPBlendRecipe>(Def));
-      break;
-    case VPDef::VPInstructionSC:
-      ResultTy = inferScalarType(cast<VPInstruction>(Def));
-      break;
-    case VPDef::VPInterleaveSC:
-      // TODO: Use info from interleave group.
-      ResultTy = V->getUnderlyingValue()->getType();
-      break;
-    case VPDef::VPReplicateSC:
-      ResultTy = inferScalarType(cast<VPReplicateRecipe>(Def));
-      break;
-    case VPDef::VPWidenSC:
-      ResultTy = inferScalarType(cast<VPWidenRecipe>(Def));
-      break;
-    case VPDef::VPWidenCallSC:
-      ResultTy = inferScalarType(cast<VPWidenCallRecipe>(Def));
-      break;
-    case VPDef::VPWidenCastSC:
-      ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
-      break;
-    case VPDef::VPWidenGEPSC:
-      ResultTy = PointerType::get(Ctx, 0);
-      break;
-    case VPDef::VPWidenMemoryInstructionSC:
-      ResultTy = inferScalarType(cast<VPWidenMemoryInstructionRecipe>(Def));
-      break;
-    case VPDef::VPWidenSelectSC:
-      ResultTy = inferScalarType(cast<VPWidenSelectRecipe>(Def));
-      break;
-    }
-    assert(ResultTy && "could not infer type for the given VPValue");
-    CachedTypes[V] = ResultTy;
-    return ResultTy;
+  const VPRecipeBase *Def = V->getDefiningRecipe();
+
+  switch (Def->getVPDefID()) {
+  case VPDef::VPCanonicalIVPHISC:
+  case VPDef::VPFirstOrderRecurrencePHISC:
+  case VPDef::VPReductionPHISC:
+  case VPDef::VPWidenPointerInductionSC:
+    // Handle header phi recipes, except VPWienIntOrFpInduction which needs
+    // special handling due it being possibly truncated.
+    ResultTy = cast<VPHeaderPHIRecipe>(Def)
+                   ->getStartValue()
+                   ->getLiveInIRValue()
+                   ->getType();
+    break;
+  case VPDef::VPWidenIntOrFpInductionSC:
+    ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
+    break;
+  case VPDef::VPPredInstPHISC:
+  case VPDef::VPScalarIVStepsSC:
+  case VPDef::VPWidenPHISC:
+    ResultTy = inferScalarType(Def->getOperand(0));
+    break;
+  case VPDef::VPBlendSC:
+    ResultTy = inferScalarType(cast<VPBlendRecipe>(Def));
+    break;
+  case VPDef::VPInstructionSC:
+    ResultTy = inferScalarType(cast<VPInstruction>(Def));
+    break;
+  case VPDef::VPInterleaveSC:
+    // TODO: Use info from interleave group.
+    ResultTy = V->getUnderlyingValue()->getType();
+    break;
+  case VPDef::VPReplicateSC:
+    ResultTy = inferScalarType(cast<VPReplicateRecipe>(Def));
+    break;
+  case VPDef::VPWidenSC:
+    ResultTy = inferScalarType(cast<VPWidenRecipe>(Def));
+    break;
+  case VPDef::VPWidenCallSC:
+    ResultTy = inferScalarType(cast<VPWidenCallRecipe>(Def));
+    break;
+  case VPDef::VPWidenCastSC:
+    ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
+    break;
+  case VPDef::VPWidenGEPSC:
+    ResultTy = PointerType::get(Ctx, 0);
+    break;
+  case VPDef::VPWidenMemoryInstructionSC:
+    ResultTy = inferScalarType(cast<VPWidenMemoryInstructionRecipe>(Def));
+    break;
+  case VPDef::VPWidenSelectSC:
+    ResultTy = inferScalarType(cast<VPWidenSelectRecipe>(Def));
+    break;
+  }
+  assert(ResultTy && "could not infer type for the given VPValue");
+  CachedTypes[V] = ResultTy;
+  return ResultTy;
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 8f1cdf00f5cc0b6..7c223084669596b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -33,6 +33,9 @@ class Type;
 /// through defining recipes until root nodes with known types are reached (e.g.
 /// live-ins or memory recipes). The types are then propagated top down through
 /// operations.
+/// Note that the analysis caches the infered types. A new analysis object must
+/// be constructed once a VPlan has been modified in a way that invalidates any
+/// of the previously infered types.
 class VPTypeAnalysis {
   DenseMap<const VPValue *, Type *> CachedTypes;
   LLVMContext &Ctx;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 12ab2e3e0df7020..3b44aa993e9b480 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -746,7 +746,8 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
   for (unsigned Part = 0; Part < State.UF; ++Part) {
     assert(VectorType::get(A.inferScalarType(getVPSingleValue()), State.VF) ==
-           State.get(this, Part)->getType());
+               State.get(this, Part)->getType() &&
+           "infered type and type from generated instructions do not match");
   }
 #endif
 }

>From 1483f63d427310341ada5a6b9f9a41515cdb8ebd Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 25 Oct 2023 19:51:58 +0100
Subject: [PATCH 08/12] Add assert to rule out stores in inferScalarType

---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 5d325fb70e1b3b8..baf3b1112356b4c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -91,9 +91,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPWidenCallRecipe *R) {
 }
 
 Type *VPTypeAnalysis::inferScalarType(const VPWidenMemoryInstructionRecipe *R) {
-  if (R->isStore())
-    return cast<StoreInst>(&R->getIngredient())->getValueOperand()->getType();
-
+  assert(!R->isStore() && "Store recipes should not define any values");
   return cast<LoadInst>(&R->getIngredient())->getType();
 }
 

>From 79bcce98f4ebe15b5739468c149620df29aa4deb Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 26 Oct 2023 18:36:01 +0100
Subject: [PATCH 09/12] Add a few more missed cases, verify types for
 VPReplicateRecipe codegen.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 12 +++++++-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  2 ++
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 30 ++++++++++++++++---
 3 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 88f064b6d57cebc..75402aa0381ac53 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -57,6 +57,7 @@
 #include "LoopVectorizationPlanner.h"
 #include "VPRecipeBuilder.h"
 #include "VPlan.h"
+#include "VPlanAnalysis.h"
 #include "VPlanHCFGBuilder.h"
 #include "VPlanTransforms.h"
 #include "llvm/ADT/APInt.h"
@@ -2702,8 +2703,17 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
   bool IsVoidRetTy = Instr->getType()->isVoidTy();
 
   Instruction *Cloned = Instr->clone();
-  if (!IsVoidRetTy)
+  if (!IsVoidRetTy) {
     Cloned->setName(Instr->getName() + ".cloned");
+#if !defined(NDEBUG)
+    // Verify that VPlan type inference results agree with the type of the
+    // generated values.
+    VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
+    assert(A.inferScalarType(RepRecipe->getVPSingleValue()) ==
+               Cloned->getType() &&
+           "infered type and type from generated instructions do not match");
+#endif
+  }
 
   RepRecipe->setFlags(Cloned);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 653290a036dd235..ca74f406043449a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2194,6 +2194,8 @@ class VPDerivedIVRecipe : public VPRecipeBase, public VPValue {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
+  Type *getScalarType() const { return TruncResultTy; }
+
   VPValue *getStartValue() const { return getOperand(0); }
   VPValue *getCanonicalIV() const { return getOperand(1); }
   VPValue *getStepValue() const { return getOperand(2); }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index baf3b1112356b4c..a41727f22f66b2d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -128,24 +128,39 @@ Type *VPTypeAnalysis::inferScalarType(const VPReplicateRecipe *R) {
   case Instruction::AShr:
   case Instruction::And:
   case Instruction::Or:
-  case Instruction::Xor:
-  case Instruction::ICmp:
-  case Instruction::FCmp: {
+  case Instruction::Xor: {
     Type *ResTy = inferScalarType(R->getOperand(0));
     assert(ResTy == inferScalarType(R->getOperand(1)) &&
            "inferred types for operands of binary op don't match");
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
+  case Instruction::Select: {
+    Type *ResTy = inferScalarType(R->getOperand(1));
+    assert(ResTy == inferScalarType(R->getOperand(2)) &&
+           "inferred types for operands of select op don't match");
+    CachedTypes[R->getOperand(2)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return IntegerType::get(Ctx, 1);
+  case Instruction::Alloca:
+  case Instruction::BitCast:
   case Instruction::Trunc:
   case Instruction::SExt:
   case Instruction::ZExt:
   case Instruction::FPExt:
   case Instruction::FPTrunc:
   case Instruction::ExtractValue:
+  case Instruction::SIToFP:
+  case Instruction::UIToFP:
+  case Instruction::FPToSI:
+  case Instruction::FPToUI:
     return R->getUnderlyingInstr()->getType();
   case Instruction::Freeze:
   case Instruction::FNeg:
+  case Instruction::GetElementPtr:
     return inferScalarType(R->getOperand(0));
   case Instruction::Load:
     return cast<LoadInst>(R->getUnderlyingInstr())->getType();
@@ -178,12 +193,19 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
                    ->getLiveInIRValue()
                    ->getType();
     break;
+  case VPDef::VPDerivedIVSC: {
+    // VPDerivedIV may truncate the IV to a specified scalar type or use the
+    // type of the first operand (the step).
+    Type *T = cast<VPDerivedIVRecipe>(Def)->getScalarType();
+    ResultTy = T ? T : inferScalarType(Def->getOperand(0));
+    break;
+  }
   case VPDef::VPWidenIntOrFpInductionSC:
     ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
     break;
   case VPDef::VPPredInstPHISC:
-  case VPDef::VPScalarIVStepsSC:
   case VPDef::VPWidenPHISC:
+  case VPDef::VPScalarIVStepsSC:
     ResultTy = inferScalarType(Def->getOperand(0));
     break;
   case VPDef::VPBlendSC:

>From 532d75eb7ddbe58d1266fa26f14b4cd6b291a5e4 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 26 Oct 2023 21:03:17 +0100
Subject: [PATCH 10/12] Use TypeSwitch.

---
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 97 +++++++------------
 1 file changed, 35 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index a41727f22f66b2d..a7f34bc332cee52 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -8,6 +8,7 @@
 
 #include "VPlanAnalysis.h"
 #include "VPlan.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace llvm;
 
@@ -175,71 +176,43 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
 
-  Type *ResultTy = nullptr;
   if (V->isLiveIn())
     return V->getLiveInIRValue()->getType();
 
-  const VPRecipeBase *Def = V->getDefiningRecipe();
-
-  switch (Def->getVPDefID()) {
-  case VPDef::VPCanonicalIVPHISC:
-  case VPDef::VPFirstOrderRecurrencePHISC:
-  case VPDef::VPReductionPHISC:
-  case VPDef::VPWidenPointerInductionSC:
-    // Handle header phi recipes, except VPWienIntOrFpInduction which needs
-    // special handling due it being possibly truncated.
-    ResultTy = cast<VPHeaderPHIRecipe>(Def)
-                   ->getStartValue()
-                   ->getLiveInIRValue()
-                   ->getType();
-    break;
-  case VPDef::VPDerivedIVSC: {
-    // VPDerivedIV may truncate the IV to a specified scalar type or use the
-    // type of the first operand (the step).
-    Type *T = cast<VPDerivedIVRecipe>(Def)->getScalarType();
-    ResultTy = T ? T : inferScalarType(Def->getOperand(0));
-    break;
-  }
-  case VPDef::VPWidenIntOrFpInductionSC:
-    ResultTy = cast<VPWidenIntOrFpInductionRecipe>(Def)->getScalarType();
-    break;
-  case VPDef::VPPredInstPHISC:
-  case VPDef::VPWidenPHISC:
-  case VPDef::VPScalarIVStepsSC:
-    ResultTy = inferScalarType(Def->getOperand(0));
-    break;
-  case VPDef::VPBlendSC:
-    ResultTy = inferScalarType(cast<VPBlendRecipe>(Def));
-    break;
-  case VPDef::VPInstructionSC:
-    ResultTy = inferScalarType(cast<VPInstruction>(Def));
-    break;
-  case VPDef::VPInterleaveSC:
-    // TODO: Use info from interleave group.
-    ResultTy = V->getUnderlyingValue()->getType();
-    break;
-  case VPDef::VPReplicateSC:
-    ResultTy = inferScalarType(cast<VPReplicateRecipe>(Def));
-    break;
-  case VPDef::VPWidenSC:
-    ResultTy = inferScalarType(cast<VPWidenRecipe>(Def));
-    break;
-  case VPDef::VPWidenCallSC:
-    ResultTy = inferScalarType(cast<VPWidenCallRecipe>(Def));
-    break;
-  case VPDef::VPWidenCastSC:
-    ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
-    break;
-  case VPDef::VPWidenGEPSC:
-    ResultTy = PointerType::get(Ctx, 0);
-    break;
-  case VPDef::VPWidenMemoryInstructionSC:
-    ResultTy = inferScalarType(cast<VPWidenMemoryInstructionRecipe>(Def));
-    break;
-  case VPDef::VPWidenSelectSC:
-    ResultTy = inferScalarType(cast<VPWidenSelectRecipe>(Def));
-    break;
-  }
+  Type *ResultTy =
+      TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe())
+          .Case<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe,
+                VPReductionPHIRecipe, VPWidenPointerInductionRecipe>(
+              [this](const auto *R) {
+                // Handle header phi recipes, except VPWienIntOrFpInduction
+                // which needs special handling due it being possibly truncated.
+                return inferScalarType(R->getStartValue());
+              })
+          .Case<VPWidenIntOrFpInductionRecipe>(
+              [](const VPWidenIntOrFpInductionRecipe *R) {
+                return R->getScalarType();
+              })
+          .Case<VPDerivedIVRecipe>([this](const VPDerivedIVRecipe *R) {
+            // VPDerivedIV may truncate the IV to a specified scalar type or use
+            // the
+            // type of the first operand (the step).
+            Type *T = R->getScalarType();
+            return T ? T : inferScalarType(R->getOperand(0));
+          })
+          .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe,
+                VPWidenGEPRecipe>([this](const VPRecipeBase *R) {
+            return inferScalarType(R->getOperand(0));
+          })
+          .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
+                VPWidenCallRecipe, VPWidenMemoryInstructionRecipe,
+                VPWidenSelectRecipe>(
+              [this](const auto *R) { return inferScalarType(R); })
+          .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
+            // TODO: Use info from interleave group.
+            return V->getUnderlyingValue()->getType();
+          })
+          .Case<VPWidenCastRecipe>(
+              [](const VPWidenCastRecipe *R) { return R->getResultType(); });
   assert(ResultTy && "could not infer type for the given VPValue");
   CachedTypes[V] = ResultTy;
   return ResultTy;

>From 5bbd764adbbcf55627a36a03b5c85165c94038b2 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 27 Oct 2023 11:24:08 +0100
Subject: [PATCH 11/12] Apply suggestions from code review

Co-authored-by: ayalz <47719489+ayalz at users.noreply.github.com>
---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 2 +-
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp | 5 ++---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h   | 6 +++---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp  | 2 +-
 4 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 75402aa0381ac53..9dd1b1a3542b137 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2711,7 +2711,7 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
     VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
     assert(A.inferScalarType(RepRecipe->getVPSingleValue()) ==
                Cloned->getType() &&
-           "infered type and type from generated instructions do not match");
+           "inferred type and type from generated instructions do not match");
 #endif
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index a7f34bc332cee52..ba06377248af0e2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -80,7 +80,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPWidenRecipe *R) {
     break;
   }
 
-  // Type inferrence not implemented for opcode.
+  // Type inference not implemented for opcode.
   LLVM_DEBUG(dbgs() << "LV: Found unhandled opcode: "
                     << Instruction::getOpcodeName(Opcode));
   llvm_unreachable("Unhandled opcode!");
@@ -194,8 +194,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
               })
           .Case<VPDerivedIVRecipe>([this](const VPDerivedIVRecipe *R) {
             // VPDerivedIV may truncate the IV to a specified scalar type or use
-            // the
-            // type of the first operand (the step).
+            // the type of the first operand (the start).
             Type *T = R->getScalarType();
             return T ? T : inferScalarType(R->getOperand(0));
           })
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7c223084669596b..5df737fbbf90b9b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -31,11 +31,11 @@ class Type;
 /// An analysis for type-inference for VPValues.
 /// It infers the scalar type for a given VPValue by bottom-up traversing
 /// through defining recipes until root nodes with known types are reached (e.g.
-/// live-ins or memory recipes). The types are then propagated top down through
+/// live-ins or load recipes). The types are then propagated top down through
 /// operations.
-/// Note that the analysis caches the infered types. A new analysis object must
+/// Note that the analysis caches the inferred types. A new analysis object must
 /// be constructed once a VPlan has been modified in a way that invalidates any
-/// of the previously infered types.
+/// of the previously inferred types.
 class VPTypeAnalysis {
   DenseMap<const VPValue *, Type *> CachedTypes;
   LLVMContext &Ctx;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 3b44aa993e9b480..c100384b33b2321 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -747,7 +747,7 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   for (unsigned Part = 0; Part < State.UF; ++Part) {
     assert(VectorType::get(A.inferScalarType(getVPSingleValue()), State.VF) ==
                State.get(this, Part)->getType() &&
-           "infered type and type from generated instructions do not match");
+           "inferred type and type from generated instructions do not match");
   }
 #endif
 }

>From 74525adae0b7026b305dbac828ba55b5dd45a2a0 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 27 Oct 2023 11:38:23 +0100
Subject: [PATCH 12/12] Address latest comments, thanks!

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  3 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  5 ++-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 41 ++++++++++---------
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h | 16 ++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  2 +-
 5 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9dd1b1a3542b137..0b506d95fbc1bdb 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2709,8 +2709,7 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
     // Verify that VPlan type inference results agree with the type of the
     // generated values.
     VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
-    assert(A.inferScalarType(RepRecipe->getVPSingleValue()) ==
-               Cloned->getType() &&
+    assert(A.inferScalarType(RepRecipe) == Cloned->getType() &&
            "inferred type and type from generated instructions do not match");
 #endif
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index ca74f406043449a..3158f18cdae4c06 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2194,7 +2194,10 @@ class VPDerivedIVRecipe : public VPRecipeBase, public VPValue {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  Type *getScalarType() const { return TruncResultTy; }
+  Type *getScalarType() const {
+    return TruncResultTy ? TruncResultTy
+                         : getStartValue()->getLiveInIRValue()->getType();
+  }
 
   VPValue *getStartValue() const { return getOperand(0); }
   VPValue *getCanonicalIV() const { return getOperand(1); }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index ba06377248af0e2..8e347e56a1f2cb7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -14,7 +14,7 @@ using namespace llvm;
 
 #define DEBUG_TYPE "vplan"
 
-Type *VPTypeAnalysis::inferScalarType(const VPBlendRecipe *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) {
   Type *ResTy = inferScalarType(R->getIncomingValue(0));
   for (unsigned I = 1, E = R->getNumIncomingValues(); I != E; ++I) {
     VPValue *Inc = R->getIncomingValue(I);
@@ -25,7 +25,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPBlendRecipe *R) {
   return ResTy;
 }
 
-Type *VPTypeAnalysis::inferScalarType(const VPInstruction *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
   switch (R->getOpcode()) {
   case Instruction::Select: {
     Type *ResTy = inferScalarType(R->getOperand(1));
@@ -35,15 +35,21 @@ Type *VPTypeAnalysis::inferScalarType(const VPInstruction *R) {
     CachedTypes[OtherV] = ResTy;
     return ResTy;
   }
-  case VPInstruction::FirstOrderRecurrenceSplice:
-    return inferScalarType(R->getOperand(0));
+  case VPInstruction::FirstOrderRecurrenceSplice: {
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    VPValue *OtherV = R->getOperand(1);
+    assert(inferScalarType(OtherV) == ResTy &&
+           "different types inferred for different operands");
+    CachedTypes[OtherV] = ResTy;
+    return ResTy;
+  }
   default:
     break;
   }
   llvm_unreachable("Unhandled opcode!");
 }
 
-Type *VPTypeAnalysis::inferScalarType(const VPWidenRecipe *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   unsigned Opcode = R->getOpcode();
   switch (Opcode) {
   case Instruction::ICmp:
@@ -86,17 +92,18 @@ Type *VPTypeAnalysis::inferScalarType(const VPWidenRecipe *R) {
   llvm_unreachable("Unhandled opcode!");
 }
 
-Type *VPTypeAnalysis::inferScalarType(const VPWidenCallRecipe *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
   auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
   return CI.getType();
 }
 
-Type *VPTypeAnalysis::inferScalarType(const VPWidenMemoryInstructionRecipe *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(
+    const VPWidenMemoryInstructionRecipe *R) {
   assert(!R->isStore() && "Store recipes should not define any values");
   return cast<LoadInst>(&R->getIngredient())->getType();
 }
 
-Type *VPTypeAnalysis::inferScalarType(const VPWidenSelectRecipe *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
   Type *ResTy = inferScalarType(R->getOperand(1));
   VPValue *OtherV = R->getOperand(2);
   assert(inferScalarType(OtherV) == ResTy &&
@@ -105,7 +112,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPWidenSelectRecipe *R) {
   return ResTy;
 }
 
-Type *VPTypeAnalysis::inferScalarType(const VPReplicateRecipe *R) {
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   switch (R->getUnderlyingInstr()->getOpcode()) {
   case Instruction::Call: {
     unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
@@ -158,6 +165,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPReplicateRecipe *R) {
   case Instruction::UIToFP:
   case Instruction::FPToSI:
   case Instruction::FPToUI:
+  case Instruction::PtrToInt:
+  case Instruction::IntToPtr:
     return R->getUnderlyingInstr()->getType();
   case Instruction::Freeze:
   case Instruction::FNeg:
@@ -188,16 +197,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
                 // which needs special handling due it being possibly truncated.
                 return inferScalarType(R->getStartValue());
               })
-          .Case<VPWidenIntOrFpInductionRecipe>(
-              [](const VPWidenIntOrFpInductionRecipe *R) {
-                return R->getScalarType();
-              })
-          .Case<VPDerivedIVRecipe>([this](const VPDerivedIVRecipe *R) {
-            // VPDerivedIV may truncate the IV to a specified scalar type or use
-            // the type of the first operand (the start).
-            Type *T = R->getScalarType();
-            return T ? T : inferScalarType(R->getOperand(0));
-          })
+          .Case<VPWidenIntOrFpInductionRecipe, VPDerivedIVRecipe>(
+              [](const auto *R) { return R->getScalarType(); })
           .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe,
                 VPWidenGEPRecipe>([this](const VPRecipeBase *R) {
             return inferScalarType(R->getOperand(0));
@@ -205,7 +206,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
                 VPWidenCallRecipe, VPWidenMemoryInstructionRecipe,
                 VPWidenSelectRecipe>(
-              [this](const auto *R) { return inferScalarType(R); })
+              [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 5df737fbbf90b9b..34b6b74588325bc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -40,14 +40,14 @@ class VPTypeAnalysis {
   DenseMap<const VPValue *, Type *> CachedTypes;
   LLVMContext &Ctx;
 
-  Type *inferScalarType(const VPBlendRecipe *R);
-  Type *inferScalarType(const VPInstruction *R);
-  Type *inferScalarType(const VPWidenCallRecipe *R);
-  Type *inferScalarType(const VPWidenRecipe *R);
-  Type *inferScalarType(const VPWidenIntOrFpInductionRecipe *R);
-  Type *inferScalarType(const VPWidenMemoryInstructionRecipe *R);
-  Type *inferScalarType(const VPWidenSelectRecipe *R);
-  Type *inferScalarType(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPInstruction *R);
+  Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenMemoryInstructionRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
 
 public:
   VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c100384b33b2321..3b6077c615bcd2f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -745,7 +745,7 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   // generated values.
   VPTypeAnalysis A(State.Builder.GetInsertBlock()->getContext());
   for (unsigned Part = 0; Part < State.UF; ++Part) {
-    assert(VectorType::get(A.inferScalarType(getVPSingleValue()), State.VF) ==
+    assert(VectorType::get(A.inferScalarType(this), State.VF) ==
                State.get(this, Part)->getType() &&
            "inferred type and type from generated instructions do not match");
   }



More information about the llvm-commits mailing list