[llvm] b0b8864 - [VPlan] Add initial anlysis to infer scalar type of VPValues. (#69013)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 06:38:31 PDT 2023


Author: Florian Hahn
Date: 2023-10-27T14:38:28+01:00
New Revision: b0b88643a1facc4fcb06783e295de3d4a91dc924

URL: https://github.com/llvm/llvm-project/commit/b0b88643a1facc4fcb06783e295de3d4a91dc924
DIFF: https://github.com/llvm/llvm-project/commit/b0b88643a1facc4fcb06783e295de3d4a91dc924.diff

LOG: [VPlan] Add initial anlysis to infer scalar type of VPValues. (#69013)

This patch adds initial type inference for VPValues. It infers the
scalar type of a VPValue, by bottom-up traversing through defining
recipes until root nodes with known types are reached (e.g. live-ins or
load recipes). The types are then propagated top down through
operations.

This is intended as a building block for a VPlan-based cost model, which
will need access to type information for VPValues/recipes.

Initial testing is done by asserting the inferred type matches the type
of the result values generated for widen and replicate recipes.

Added: 
    llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
    llvm/lib/Transforms/Vectorize/VPlanAnalysis.h

Modified: 
    llvm/lib/Transforms/Vectorize/CMakeLists.txt
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 998dfd956575d3c..9674094024b9ec7 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMVectorize
   Vectorize.cpp
   VectorCombine.cpp
   VPlan.cpp
+  VPlanAnalysis.cpp
   VPlanHCFGBuilder.cpp
   VPlanRecipes.cpp
   VPlanSLP.cpp

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0807d2a7e5a2671..16c761a91ff2326 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -57,6 +57,7 @@
 #include "LoopVectorizationPlanner.h"
 #include "VPRecipeBuilder.h"
 #include "VPlan.h"
+#include "VPlanAnalysis.h"
 #include "VPlanHCFGBuilder.h"
 #include "VPlanTransforms.h"
 #include "llvm/ADT/APInt.h"
@@ -2702,8 +2703,15 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
   bool IsVoidRetTy = Instr->getType()->isVoidTy();
 
   Instruction *Cloned = Instr->clone();
-  if (!IsVoidRetTy)
+  if (!IsVoidRetTy) {
     Cloned->setName(Instr->getName() + ".cloned");
+#if !defined(NDEBUG)
+    // Verify that VPlan type inference results agree with the type of the
+    // generated values.
+    assert(State.TypeAnalysis.inferScalarType(RepRecipe) == Cloned->getType() &&
+           "inferred type and type from generated instructions do not match");
+#endif
+  }
 
   RepRecipe->setFlags(Cloned);
 
@@ -7660,7 +7668,8 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan(
     VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
 
   // Perform the actual loop transformation.
-  VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan};
+  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+                         OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
   // before making any changes to the CFG.

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e65a7ab2cd028ee..96cfc9354ead806 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -23,6 +23,7 @@
 #ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
 #define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
 
+#include "VPlanAnalysis.h"
 #include "VPlanValue.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -233,9 +234,9 @@ struct VPIteration {
 struct VPTransformState {
   VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
                    DominatorTree *DT, IRBuilderBase &Builder,
-                   InnerLoopVectorizer *ILV, VPlan *Plan)
+                   InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx)
       : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan),
-        LVer(nullptr) {}
+        LVer(nullptr), TypeAnalysis(Ctx) {}
 
   /// The chosen Vectorization and Unroll Factors of the loop being vectorized.
   ElementCount VF;
@@ -413,6 +414,9 @@ struct VPTransformState {
   /// Map SCEVs to their expanded values. Populated when executing
   /// VPExpandSCEVRecipes.
   DenseMap<const SCEV *, Value *> ExpandedSCEVs;
+
+  /// VPlan-based type analysis.
+  VPTypeAnalysis TypeAnalysis;
 };
 
 /// VPBlockBase is the building block of the Hierarchical Control-Flow Graph.
@@ -1167,6 +1171,8 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue {
   /// Produce widened copies of all Ingredients.
   void execute(VPTransformState &State) override;
 
+  unsigned getOpcode() const { return Opcode; }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
@@ -1458,7 +1464,7 @@ class VPWidenIntOrFpInductionRecipe : public VPHeaderPHIRecipe {
   bool isCanonical() const;
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
+  Type *getScalarType() const {
     return Trunc ? Trunc->getType() : IV->getType();
   }
 };
@@ -2080,8 +2086,8 @@ class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe {
 #endif
 
   /// Returns the scalar type of the induction.
-  const Type *getScalarType() const {
-    return getOperand(0)->getLiveInIRValue()->getType();
+  Type *getScalarType() const {
+    return getStartValue()->getLiveInIRValue()->getType();
   }
 
   /// Returns true if the recipe only uses the first lane of operand \p Op.
@@ -2192,6 +2198,11 @@ class VPDerivedIVRecipe : public VPRecipeBase, public VPValue {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
+  Type *getScalarType() const {
+    return TruncResultTy ? TruncResultTy
+                         : getStartValue()->getLiveInIRValue()->getType();
+  }
+
   VPValue *getStartValue() const { return getOperand(0); }
   VPValue *getCanonicalIV() const { return getOperand(1); }
   VPValue *getStepValue() const { return getOperand(2); }

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
new file mode 100644
index 000000000000000..ff1b6b30aa3e125
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -0,0 +1,232 @@
+//===- VPlanAnalysis.cpp - Various Analyses working on VPlan ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "VPlanAnalysis.h"
+#include "VPlan.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "vplan"
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) {
+  Type *ResTy = inferScalarType(R->getIncomingValue(0));
+  for (unsigned I = 1, E = R->getNumIncomingValues(); I != E; ++I) {
+    VPValue *Inc = R->getIncomingValue(I);
+    assert(inferScalarType(Inc) == ResTy &&
+           "different types inferred for different incoming values");
+    CachedTypes[Inc] = ResTy;
+  }
+  return ResTy;
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
+  switch (R->getOpcode()) {
+  case Instruction::Select: {
+    Type *ResTy = inferScalarType(R->getOperand(1));
+    VPValue *OtherV = R->getOperand(2);
+    assert(inferScalarType(OtherV) == ResTy &&
+           "different types inferred for different operands");
+    CachedTypes[OtherV] = ResTy;
+    return ResTy;
+  }
+  case VPInstruction::FirstOrderRecurrenceSplice: {
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    VPValue *OtherV = R->getOperand(1);
+    assert(inferScalarType(OtherV) == ResTy &&
+           "different types inferred for different operands");
+    CachedTypes[OtherV] = ResTy;
+    return ResTy;
+  }
+  default:
+    break;
+  }
+  // Type inference not implemented for opcode.
+  LLVM_DEBUG({
+    dbgs() << "LV: Found unhandled opcode for: ";
+    R->getVPSingleValue()->dump();
+  });
+  llvm_unreachable("Unhandled opcode!");
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
+  unsigned Opcode = R->getOpcode();
+  switch (Opcode) {
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return IntegerType::get(Ctx, 1);
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::SRem:
+  case Instruction::URem:
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor: {
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    assert(ResTy == inferScalarType(R->getOperand(1)) &&
+           "types for both operands must match for binary op");
+    CachedTypes[R->getOperand(1)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::FNeg:
+  case Instruction::Freeze:
+    return inferScalarType(R->getOperand(0));
+  default:
+    break;
+  }
+
+  // Type inference not implemented for opcode.
+  LLVM_DEBUG({
+    dbgs() << "LV: Found unhandled opcode for: ";
+    R->getVPSingleValue()->dump();
+  });
+  llvm_unreachable("Unhandled opcode!");
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
+  auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
+  return CI.getType();
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(
+    const VPWidenMemoryInstructionRecipe *R) {
+  assert(!R->isStore() && "Store recipes should not define any values");
+  return cast<LoadInst>(&R->getIngredient())->getType();
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
+  Type *ResTy = inferScalarType(R->getOperand(1));
+  VPValue *OtherV = R->getOperand(2);
+  assert(inferScalarType(OtherV) == ResTy &&
+         "different types inferred for different operands");
+  CachedTypes[OtherV] = ResTy;
+  return ResTy;
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
+  switch (R->getUnderlyingInstr()->getOpcode()) {
+  case Instruction::Call: {
+    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+        ->getReturnType();
+  }
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::SRem:
+  case Instruction::URem:
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor: {
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    assert(ResTy == inferScalarType(R->getOperand(1)) &&
+           "inferred types for operands of binary op don't match");
+    CachedTypes[R->getOperand(1)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::Select: {
+    Type *ResTy = inferScalarType(R->getOperand(1));
+    assert(ResTy == inferScalarType(R->getOperand(2)) &&
+           "inferred types for operands of select op don't match");
+    CachedTypes[R->getOperand(2)] = ResTy;
+    return ResTy;
+  }
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return IntegerType::get(Ctx, 1);
+  case Instruction::Alloca:
+  case Instruction::BitCast:
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+  case Instruction::FPExt:
+  case Instruction::FPTrunc:
+  case Instruction::ExtractValue:
+  case Instruction::SIToFP:
+  case Instruction::UIToFP:
+  case Instruction::FPToSI:
+  case Instruction::FPToUI:
+  case Instruction::PtrToInt:
+  case Instruction::IntToPtr:
+    return R->getUnderlyingInstr()->getType();
+  case Instruction::Freeze:
+  case Instruction::FNeg:
+  case Instruction::GetElementPtr:
+    return inferScalarType(R->getOperand(0));
+  case Instruction::Load:
+    return cast<LoadInst>(R->getUnderlyingInstr())->getType();
+  default:
+    break;
+  }
+  // Type inference not implemented for opcode.
+  LLVM_DEBUG({
+    dbgs() << "LV: Found unhandled opcode for: ";
+    R->getVPSingleValue()->dump();
+  });
+  llvm_unreachable("Unhandled opcode");
+}
+
+Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
+  if (Type *CachedTy = CachedTypes.lookup(V))
+    return CachedTy;
+
+  if (V->isLiveIn())
+    return V->getLiveInIRValue()->getType();
+
+  Type *ResultTy =
+      TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe())
+          .Case<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe,
+                VPReductionPHIRecipe, VPWidenPointerInductionRecipe>(
+              [this](const auto *R) {
+                // Handle header phi recipes, except VPWienIntOrFpInduction
+                // which needs special handling due it being possibly truncated.
+                // TODO: consider inferring/caching type of siblings, e.g.,
+                // backedge value, here and in cases below.
+                return inferScalarType(R->getStartValue());
+              })
+          .Case<VPWidenIntOrFpInductionRecipe, VPDerivedIVRecipe>(
+              [](const auto *R) { return R->getScalarType(); })
+          .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe,
+                VPWidenGEPRecipe>([this](const VPRecipeBase *R) {
+            return inferScalarType(R->getOperand(0));
+          })
+          .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
+                VPWidenCallRecipe, VPWidenMemoryInstructionRecipe,
+                VPWidenSelectRecipe>(
+              [this](const auto *R) { return inferScalarTypeForRecipe(R); })
+          .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
+            // TODO: Use info from interleave group.
+            return V->getUnderlyingValue()->getType();
+          })
+          .Case<VPWidenCastRecipe>(
+              [](const VPWidenCastRecipe *R) { return R->getResultType(); });
+  assert(ResultTy && "could not infer type for the given VPValue");
+  CachedTypes[V] = ResultTy;
+  return ResultTy;
+}

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
new file mode 100644
index 000000000000000..34b6b74588325bc
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -0,0 +1,61 @@
+//===- VPlanAnalysis.h - Various Analyses working on VPlan ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace llvm {
+
+class LLVMContext;
+class VPValue;
+class VPBlendRecipe;
+class VPInterleaveRecipe;
+class VPInstruction;
+class VPReductionPHIRecipe;
+class VPWidenRecipe;
+class VPWidenCallRecipe;
+class VPWidenCastRecipe;
+class VPWidenIntOrFpInductionRecipe;
+class VPWidenMemoryInstructionRecipe;
+struct VPWidenSelectRecipe;
+class VPReplicateRecipe;
+class Type;
+
+/// An analysis for type-inference for VPValues.
+/// It infers the scalar type for a given VPValue by bottom-up traversing
+/// through defining recipes until root nodes with known types are reached (e.g.
+/// live-ins or load recipes). The types are then propagated top down through
+/// operations.
+/// Note that the analysis caches the inferred types. A new analysis object must
+/// be constructed once a VPlan has been modified in a way that invalidates any
+/// of the previously inferred types.
+class VPTypeAnalysis {
+  DenseMap<const VPValue *, Type *> CachedTypes;
+  LLVMContext &Ctx;
+
+  Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPInstruction *R);
+  Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenMemoryInstructionRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+
+public:
+  VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {}
+
+  /// Infer the type of \p V. Returns the scalar type of \p V.
+  Type *inferScalarType(const VPValue *V);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index efc95c1cd08c6fd..6b3218dca1b18b0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "VPlan.h"
+#include "VPlanAnalysis.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
@@ -738,7 +739,18 @@ void VPWidenRecipe::execute(VPTransformState &State) {
                       << Instruction::getOpcodeName(Opcode));
     llvm_unreachable("Unhandled instruction!");
   } // end of switch.
+
+#if !defined(NDEBUG)
+  // Verify that VPlan type inference results agree with the type of the
+  // generated values.
+  for (unsigned Part = 0; Part < State.UF; ++Part) {
+    assert(VectorType::get(State.TypeAnalysis.inferScalarType(this),
+                           State.VF) == State.get(this, Part)->getType() &&
+           "inferred type and type from generated instructions do not match");
+  }
+#endif
 }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
                           VPSlotTracker &SlotTracker) const {


        


More information about the llvm-commits mailing list