[llvm] [VPlan] Add initial analysis to infer scalar type of VPValues. (PR #69013)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 13 10:34:11 PDT 2023


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/69013

This patch adds initial type inference for VPValues. It infers the scalar type of a VPValue by traversing bottom-up through its defining recipes until root nodes with known types are reached (e.g. live-ins or memory recipes). The types are then propagated top-down through operations.

This is intended as a building block for a VPlan-based cost model, which will need access to type information for VPValues/recipes.

Initial testing is done by asserting that the inferred type matches the type of the result value generated for a widen recipe.

>From 22b23c7de44939e470a47be8af077a7c4516b3c5 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Oct 2023 20:57:17 +0100
Subject: [PATCH] [VPlan] Add initial analysis to infer scalar type of VPValues.

This patch adds initial type inference for VPValues. It infers the
scalar type of a VPValue by traversing bottom-up through its defining
recipes until root nodes with known types are reached (e.g. live-ins or
memory recipes). The types are then propagated top-down through operations.

This is intended as a building block for a VPlan-based cost model, which
will need access to type information for VPValues/recipes.

Initial testing is done by asserting that the inferred type matches the
type of the result value generated for a widen recipe.
---
 llvm/lib/Transforms/Vectorize/CMakeLists.txt  |   1 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |   8 +-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    | 225 ++++++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |  56 +++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  12 +
 5 files changed, 299 insertions(+), 3 deletions(-)
 create mode 100644 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
 create mode 100644 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h

diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 998dfd956575d3c..9674094024b9ec7 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMVectorize
   Vectorize.cpp
   VectorCombine.cpp
   VPlan.cpp
+  VPlanAnalysis.cpp
   VPlanHCFGBuilder.cpp
   VPlanRecipes.cpp
   VPlanSLP.cpp
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e65a7ab2cd028ee..ea1f8a5b9d1e9ab 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1167,6 +1167,8 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue {
   /// Produce widened copies of all Ingredients.
   void execute(VPTransformState &State) override;
 
+  unsigned getOpcode() const { return Opcode; } ///< Returns the IR opcode.
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
@@ -1458,7 +1460,7 @@ class VPWidenIntOrFpInductionRecipe : public VPHeaderPHIRecipe {
   bool isCanonical() const;
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return Trunc ? Trunc->getType() : IV->getType();
   }
 };
@@ -2080,7 +2082,7 @@ class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe {
 #endif
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return getOperand(0)->getLiveInIRValue()->getType();
   }
 
@@ -2149,7 +2151,7 @@ class VPWidenCanonicalIVRecipe : public VPRecipeBase, public VPValue {
 #endif
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return cast<VPCanonicalIVPHIRecipe>(getOperand(0)->getDefiningRecipe())
         ->getScalarType();
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
new file mode 100644
index 000000000000000..088da81f950425c
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -0,0 +1,255 @@
+//===- VPlanAnalysis.cpp - Various Analyses working on VPlan ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements VPTypeAnalysis, which infers the scalar type of a
+// VPValue by walking up through its defining recipes until a value with a
+// known type is reached (e.g. a live-in or a memory recipe). Inferred types
+// are cached per VPValue.
+//
+//===----------------------------------------------------------------------===//
+
+#include "VPlanAnalysis.h"
+#include "VPlan.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "vplan"
+
+Type *VPTypeAnalysis::inferType(const VPBlendRecipe *R) {
+  // Use the first incoming value; all incoming values are assumed to have
+  // the same type.
+  return inferType(R->getIncomingValue(0));
+}
+
+Type *VPTypeAnalysis::inferType(const VPInstruction *R) {
+  switch (R->getOpcode()) {
+  case Instruction::Select:
+    // Operand 0 is the condition; the result has the type of the true value.
+    return inferType(R->getOperand(1));
+  case VPInstruction::FirstOrderRecurrenceSplice:
+    return inferType(R->getOperand(0));
+  default:
+    llvm_unreachable("Unhandled instruction!");
+  }
+}
+
+Type *VPTypeAnalysis::inferType(const VPInterleaveRecipe *R) {
+  // Not used: an interleave recipe can define several VPValues, so
+  // inferType(const VPValue *) takes the type from the underlying IR value of
+  // the requested result instead.
+  return nullptr;
+}
+
+Type *VPTypeAnalysis::inferType(const VPReductionPHIRecipe *R) {
+  // The phi has the type of operand 0 (its start value). Unlike
+  // getLiveInIRValue(), inferType also handles non-live-in operands.
+  return inferType(R->getOperand(0));
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenRecipe *R) {
+  unsigned Opcode = R->getOpcode();
+  switch (Opcode) {
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    // Compares produce i1, independent of the type of their operands.
+    return IntegerType::get(Ctx, 1);
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::SRem:
+  case Instruction::URem:
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor: {
+    // Binary operators produce a value of the same type as both operands.
+    Type *ResTy = inferType(R->getOperand(0));
+    assert(ResTy == inferType(R->getOperand(1)) &&
+           "types for both operands must match for binary op");
+    // Seed the cache for operand 1 so it need not be inferred separately
+    // when assertions are disabled.
+    CachedTypes[R->getOperand(1)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::FNeg:
+  case Instruction::Freeze:
+    // Unary operations keep the type of their only operand.
+    return inferType(R->getOperand(0));
+  default:
+    llvm_unreachable("Unhandled instruction!");
+  }
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenCallRecipe *R) {
+  return cast<CallInst>(R->getUnderlyingInstr())->getType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenIntOrFpInductionRecipe *R) {
+  return R->getScalarType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenMemoryInstructionRecipe *R) {
+  // For stores, report the type of the stored value.
+  if (R->isStore())
+    return cast<StoreInst>(&R->getIngredient())->getValueOperand()->getType();
+
+  return cast<LoadInst>(&R->getIngredient())->getType();
+}
+
+Type *VPTypeAnalysis::inferType(const VPWidenSelectRecipe *R) {
+  // Operand 0 is the condition; the result has the type of the true value.
+  return inferType(R->getOperand(1));
+}
+
+Type *VPTypeAnalysis::inferType(const VPReplicateRecipe *R) {
+  switch (R->getUnderlyingInstr()->getOpcode()) {
+  case Instruction::Call: {
+    // The called function is the last operand, or the second-to-last one if
+    // the recipe is predicated and carries a mask.
+    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+        ->getReturnType();
+  }
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::SRem:
+  case Instruction::URem:
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor: {
+    // Binary operators produce a value of the same type as both operands.
+    Type *ResTy = inferType(R->getOperand(0));
+    assert(ResTy == inferType(R->getOperand(1)) &&
+           "types for both operands must match for binary op");
+    CachedTypes[R->getOperand(1)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    // Compares produce i1, not a value of their operand type.
+    return IntegerType::get(Ctx, 1);
+  case Instruction::FNeg:
+  case Instruction::Freeze:
+    // Unary operations keep the type of their only operand. For predicated
+    // recipes, operand 1 is the mask and must not be inspected here.
+    return inferType(R->getOperand(0));
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+  case Instruction::FPExt:
+  case Instruction::FPTrunc:
+    return R->getUnderlyingInstr()->getType();
+  case Instruction::ExtractValue:
+    return R->getUnderlyingValue()->getType();
+  case Instruction::Load:
+    return cast<LoadInst>(R->getUnderlyingInstr())->getType();
+  case Instruction::Store:
+    return cast<StoreInst>(R->getUnderlyingInstr())
+        ->getValueOperand()
+        ->getType();
+  default:
+    llvm_unreachable("Unhandled instruction");
+  }
+}
+
+Type *VPTypeAnalysis::inferType(const VPValue *V) {
+  if (Type *CachedTy = CachedTypes.lookup(V))
+    return CachedTy;
+
+  Type *ResultTy = nullptr;
+  if (V->isLiveIn()) {
+    ResultTy = V->getLiveInIRValue()->getType();
+  } else {
+    const VPRecipeBase *Def = V->getDefiningRecipe();
+    switch (Def->getVPDefID()) {
+    case VPDef::VPBlendSC:
+      ResultTy = inferType(cast<VPBlendRecipe>(Def));
+      break;
+    case VPDef::VPCanonicalIVPHISC:
+      ResultTy = cast<VPCanonicalIVPHIRecipe>(Def)->getScalarType();
+      break;
+    case VPDef::VPFirstOrderRecurrencePHISC:
+      // The recurrence phi has the type of operand 0 (its start value).
+      ResultTy = inferType(Def->getOperand(0));
+      break;
+    case VPDef::VPInstructionSC:
+      ResultTy = inferType(cast<VPInstruction>(Def));
+      break;
+    case VPDef::VPInterleaveSC:
+      // An interleave recipe can define several values; use the type of the
+      // underlying IR value of the requested one.
+      ResultTy = V->getUnderlyingValue()->getType();
+      break;
+    case VPDef::VPPredInstPHISC:
+    case VPDef::VPScalarIVStepsSC:
+    case VPDef::VPWidenPHISC:
+    case VPDef::VPWidenPointerInductionSC:
+      // Infer these recipes from the type of operand 0.
+      ResultTy = inferType(Def->getOperand(0));
+      break;
+    case VPDef::VPReductionPHISC:
+      ResultTy = inferType(cast<VPReductionPHIRecipe>(Def));
+      break;
+    case VPDef::VPReplicateSC:
+      ResultTy = inferType(cast<VPReplicateRecipe>(Def));
+      break;
+    case VPDef::VPWidenSC:
+      ResultTy = inferType(cast<VPWidenRecipe>(Def));
+      break;
+    case VPDef::VPWidenCallSC:
+      ResultTy = inferType(cast<VPWidenCallRecipe>(Def));
+      break;
+    case VPDef::VPWidenCanonicalIVSC:
+      ResultTy = cast<VPWidenCanonicalIVRecipe>(Def)->getScalarType();
+      break;
+    case VPDef::VPWidenCastSC:
+      ResultTy = cast<VPWidenCastRecipe>(Def)->getResultType();
+      break;
+    case VPDef::VPWidenGEPSC:
+      // Take the type from the underlying GEP so that the address space of
+      // the resulting pointer is preserved (it need not be 0).
+      ResultTy = V->getUnderlyingValue()->getType();
+      break;
+    case VPDef::VPWidenIntOrFpInductionSC:
+      ResultTy = inferType(cast<VPWidenIntOrFpInductionRecipe>(Def));
+      break;
+    case VPDef::VPWidenMemoryInstructionSC:
+      ResultTy = inferType(cast<VPWidenMemoryInstructionRecipe>(Def));
+      break;
+    case VPDef::VPWidenSelectSC:
+      ResultTy = inferType(cast<VPWidenSelectRecipe>(Def));
+      break;
+    default:
+      // Recipe kind not handled yet; reported by the assertion below.
+      break;
+    }
+  }
+  assert(ResultTy && "could not infer type for the given VPValue");
+  CachedTypes[V] = ResultTy;
+  return ResultTy;
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
new file mode 100644
index 000000000000000..8fcbe9ca99bb4d5
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -0,0 +1,65 @@
+//===- VPlanAnalysis.h - Various Analyses working on VPlan ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace llvm {
+
+class LLVMContext;
+class VPValue;
+class VPBlendRecipe;
+class VPInterleaveRecipe;
+class VPInstruction;
+class VPReductionPHIRecipe;
+class VPWidenRecipe;
+class VPWidenCallRecipe;
+class VPWidenCastRecipe;
+class VPWidenIntOrFpInductionRecipe;
+class VPWidenMemoryInstructionRecipe;
+struct VPWidenSelectRecipe;
+class VPReplicateRecipe;
+class Type;
+
+/// An analysis for type-inference for VPValues.
+/// It infers the scalar type of a given VPValue by traversing bottom-up
+/// through its defining recipes until root nodes with known types are reached
+/// (e.g. live-ins or memory recipes). The types are then propagated top down
+/// through operations.
+/// Inferred types are cached and never invalidated, so an analysis object
+/// should not be reused after VPlan transformations that change types.
+class VPTypeAnalysis {
+  /// Types inferred so far, keyed by VPValue.
+  DenseMap<const VPValue *, Type *> CachedTypes;
+  /// Context used to create types not taken from IR values (e.g. i1).
+  LLVMContext &Ctx;
+
+  /// Recipe-specific helpers for inferType(const VPValue *).
+  Type *inferType(const VPBlendRecipe *R);
+  Type *inferType(const VPInstruction *R);
+  Type *inferType(const VPInterleaveRecipe *R);
+  Type *inferType(const VPWidenCallRecipe *R);
+  Type *inferType(const VPReductionPHIRecipe *R);
+  Type *inferType(const VPWidenRecipe *R);
+  Type *inferType(const VPWidenIntOrFpInductionRecipe *R);
+  Type *inferType(const VPWidenMemoryInstructionRecipe *R);
+  Type *inferType(const VPWidenSelectRecipe *R);
+  Type *inferType(const VPReplicateRecipe *R);
+
+public:
+  explicit VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {}
+
+  /// Infer the type of \p V. Returns the scalar type of \p V.
+  Type *inferType(const VPValue *V);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2a1213a98095907..b616abddb00f99a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "VPlan.h"
+#include "VPlanAnalysis.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
@@ -738,7 +739,18 @@ void VPWidenRecipe::execute(VPTransformState &State) {
                       << Instruction::getOpcodeName(Opcode));
     llvm_unreachable("Unhandled instruction!");
   } // end of switch.
+
+#if !defined(NDEBUG)
+  // Verify that VPlan type inference results agree with the type of the
+  // generated values; the inferred type is the same for all parts.
+  VPTypeAnalysis A(State.Builder.getContext());
+  Type *ExpectedTy =
+      VectorType::get(A.inferType(getVPSingleValue()), State.VF);
+  for (unsigned Part = 0; Part < State.UF; ++Part)
+    assert(ExpectedTy == State.get(this, Part)->getType() &&
+           "inferred type and type of generated value do not match");
+#endif
 }
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
                           VPSlotTracker &SlotTracker) const {



More information about the llvm-commits mailing list