[llvm] [SandboxVec][BottomUpVec] Implement InstrMaps (PR #122848)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 15 12:48:42 PST 2025


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/122848

>From 02e7a63cd87c62b5c758f9e9ec84266c24d626b6 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 14 Nov 2024 13:07:33 -0800
Subject: [PATCH] [SandboxVec][BottomUpVec] Implement InstrMaps

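InstrMaps is a helper data structure that maps scalars to vectors and vice
versa. The vectorizer uses it to figure out which existing vectors it can
extract scalar values from. The legality analysis now consults it and returns
the new DiamondReuse result when all values of a bundle can be taken, in lane
order, from a single existing vector.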
---
 .../Vectorize/SandboxVectorizer/InstrMaps.h   |  77 +++++++
 .../Vectorize/SandboxVectorizer/Legality.h    |  84 +++++++-
 .../SandboxVectorizer/Passes/BottomUpVec.h    |   3 +
 llvm/lib/Transforms/Vectorize/CMakeLists.txt  |   1 +
 .../Vectorize/SandboxVectorizer/InstrMaps.cpp |  21 ++
 .../Vectorize/SandboxVectorizer/Legality.cpp  |  36 +++-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 204 ++++++++++--------
 .../SandboxVectorizer/bottomup_basic.ll       |  20 ++
 .../SandboxVectorizer/CMakeLists.txt          |   1 +
 .../SandboxVectorizer/InstrMapsTest.cpp       |  78 +++++++
 .../SandboxVectorizer/LegalityTest.cpp        |  75 ++++++-
 11 files changed, 496 insertions(+), 104 deletions(-)
 create mode 100644 llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
 create mode 100644 llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
 create mode 100644 llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
new file mode 100644
index 00000000000000..22304248b92335
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -0,0 +1,88 @@
+//===- InstrMaps.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_INSTRMAPS_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_INSTRMAPS_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+
+namespace llvm::sandboxir {
+
+/// Maps the original instructions to the vectorized instrs and the reverse.
+/// For now an original instr can only map to a single vector.
+class InstrMaps {
+  /// A map from the original values that got combined into vectors, to the
+  /// vector value(s).
+  DenseMap<Value *, Value *> OrigToVectorMap;
+  /// A map from the vector value to a map of the original value to its lane.
+  /// Please note that for constant vectors, there may be multiple original
+  /// values with the same lane, as they may be coming from vectorizing
+  /// different original values.
+  DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+
+public:
+  /// \Returns the vector value that we got from vectorizing \p Orig, or
+  /// nullptr if not found.
+  Value *getVectorForOrig(Value *Orig) const {
+    auto It = OrigToVectorMap.find(Orig);
+    return It != OrigToVectorMap.end() ? It->second : nullptr;
+  }
+  /// \Returns the lane of \p Orig before it got vectorized into \p Vec, or
+  /// nullopt if not found.
+  std::optional<int> getOrigLane(Value *Vec, Value *Orig) const {
+    auto It1 = VectorToOrigLaneMap.find(Vec);
+    if (It1 == VectorToOrigLaneMap.end())
+      return std::nullopt;
+    const auto &OrigToLaneMap = It1->second;
+    auto It2 = OrigToLaneMap.find(Orig);
+    if (It2 == OrigToLaneMap.end())
+      return std::nullopt;
+    return It2->second;
+  }
+  /// Update the map to reflect that \p Origs got vectorized into \p Vec, with
+  /// `Origs[I]` in lane `I`. Asserts if a value in \p Origs is already
+  /// registered.
+  void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
+    auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
+    for (auto [Lane, Orig] : enumerate(Origs)) {
+      [[maybe_unused]] auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
+      assert(Pair.second && "Orig already exists in the map!");
+      OrigToLaneMap[Orig] = Lane;
+    }
+  }
+  /// Removes all mappings, in both directions.
+  void clear() {
+    OrigToVectorMap.clear();
+    VectorToOrigLaneMap.clear();
+  }
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const {
+    OS << "OrigToVectorMap:\n";
+    for (auto [Orig, Vec] : OrigToVectorMap)
+      OS << *Orig << " : " << *Vec << "\n";
+    OS << "VectorToOrigLaneMap:\n";
+    for (const auto &[Vec, OrigToLaneMap] : VectorToOrigLaneMap) {
+      OS << *Vec << " :\n";
+      for (const auto &[Orig, Lane] : OrigToLaneMap)
+        OS << "  " << *Orig << " : lane " << Lane << "\n";
+    }
+  }
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_INSTRMAPS_H
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 233cf82a1b3dfb..8dfb02818fe011 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -23,10 +23,12 @@ namespace llvm::sandboxir {
 
 class LegalityAnalysis;
 class Value;
+class InstrMaps;
 
 enum class LegalityResultID {
-  Pack,  ///> Collect scalar values.
-  Widen, ///> Vectorize by combining scalars to a vector.
+  Pack,         ///< Collect scalar values.
+  Widen,        ///< Vectorize by combining scalars to a vector.
+  DiamondReuse, ///< Reuse an existing vector instead of generating new code.
 };
 
 /// The reason for vectorizing or not vectorizing.
@@ -50,6 +52,8 @@ struct ToStr {
       return "Pack";
     case LegalityResultID::Widen:
       return "Widen";
+    case LegalityResultID::DiamondReuse:
+      return "DiamondReuse";
     }
     llvm_unreachable("Unknown LegalityResultID enum");
   }
@@ -137,6 +141,23 @@ class Widen final : public LegalityResult {
   }
 };

+/// Legality result for a bundle whose values can all be taken, in lane order,
+/// from an existing vector, which gets reused instead of generating new code.
+class DiamondReuse final : public LegalityResult {
+  friend class LegalityAnalysis;
+  /// The existing vector to reuse.
+  Value *Vec;
+  explicit DiamondReuse(Value *Vec)
+      : LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}
+
+public:
+  static bool classof(const LegalityResult *From) {
+    return From->getSubclassID() == LegalityResultID::DiamondReuse;
+  }
+  /// \Returns the existing vector to reuse.
+  Value *getVector() const { return Vec; }
+};
+
 class Pack final : public LegalityResultWithReason {
   Pack(ResultReason Reason)
       : LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -148,6 +165,69 @@ class Pack final : public LegalityResultWithReason {
   }
 };

+/// Describes how to collect the values needed by each lane.
+class CollectDescr {
+public:
+  /// Describes how to get a value element. If the value is a vector then it
+  /// also provides the index to extract it from.
+  class ExtractElementDescr {
+    Value *V;
+    /// The index in `V` that the value can be extracted from.
+    /// This is nullopt if we need to use `V` as a whole.
+    std::optional<int> ExtractIdx;
+
+  public:
+    ExtractElementDescr(Value *V, int ExtractIdx)
+        : V(V), ExtractIdx(ExtractIdx) {}
+    explicit ExtractElementDescr(Value *V) : V(V), ExtractIdx(std::nullopt) {}
+    Value *getValue() const { return V; }
+    bool needsExtract() const { return ExtractIdx.has_value(); }
+    /// \Returns the index to extract from. Only valid if needsExtract().
+    int getExtractIdx() const {
+      assert(needsExtract() && "No extract index for a whole value!");
+      return *ExtractIdx;
+    }
+  };
+
+  using DescrVecT = SmallVector<ExtractElementDescr, 4>;
+
+private:
+  /// One descriptor per lane, in lane order.
+  DescrVecT Descrs;
+
+public:
+  explicit CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
+      : Descrs(std::move(Descrs)) {}
+  /// If every lane is extracted from the same vector, \Returns that vector
+  /// and a flag that is true if a shuffle is needed, i.e., if some lane `I`
+  /// is not extracted from index `I`. \Returns nullopt otherwise, including
+  /// when there are no lanes.
+  std::optional<std::pair<Value *, bool>> getSingleInput() const {
+    if (Descrs.empty())
+      return std::nullopt;
+    const auto &Descr0 = Descrs.front();
+    if (!Descr0.needsExtract())
+      return std::nullopt;
+    Value *V0 = Descr0.getValue();
+    bool NeedsShuffle = Descr0.getExtractIdx() != 0;
+    int Lane = 1;
+    for (const auto &Descr : drop_begin(Descrs)) {
+      if (!Descr.needsExtract())
+        return std::nullopt;
+      if (Descr.getValue() != V0)
+        return std::nullopt;
+      if (Descr.getExtractIdx() != Lane++)
+        NeedsShuffle = true;
+    }
+    return std::make_pair(V0, NeedsShuffle);
+  }
+  /// \Returns true if at least one lane needs to be extracted from a vector.
+  bool hasVectorInputs() const {
+    return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
+  }
+  const DescrVecT &getDescrs() const { return Descrs; }
+};
+
 /// Performs the legality analysis and returns a LegalityResult object.
 class LegalityAnalysis {
   Scheduler Sched;
@@ -160,11 +228,17 @@ class LegalityAnalysis {
 
   ScalarEvolution &SE;
   const DataLayout &DL;
+  InstrMaps &IMaps;
+
+  /// Finds how we can collect the values in \p Bndl from the vectorized or
+  /// non-vectorized code. It returns a map of the value we should extract from
+  /// and the corresponding shuffle mask we need to use.
+  CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;
 
 public:
   LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
-                   Context &Ctx)
-      : Sched(AA, Ctx), SE(SE), DL(DL) {}
+                   Context &Ctx, InstrMaps &IMaps)
+      : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
   ResultT &createLegalityResult(ArgsT... Args) {
@@ -177,7 +251,7 @@ class LegalityAnalysis {
   // TODO: Try to remove the SkipScheduling argument by refactoring the tests.
   const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
                                      bool SkipScheduling = false);
-  void clear() { Sched.clear(); }
+  void clear();
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 1a53ca6e06f5fd..69cea3c4c7b53b 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -18,6 +18,7 @@
 #include "llvm/SandboxIR/Pass.h"
 #include "llvm/SandboxIR/PassManager.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
 
 namespace llvm::sandboxir {
@@ -26,6 +27,8 @@ class BottomUpVec final : public FunctionPass {
   bool Change = false;
   std::unique_ptr<LegalityAnalysis> Legality;
   DenseSet<Instruction *> DeadInstrCandidates;
+  /// Maps scalars to vectors.
+  InstrMaps IMaps;
 
   /// Creates and returns a vector instruction that replaces the instructions in
   /// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index d769d5100afd23..6a025652f92f8e 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -4,6 +4,7 @@ add_llvm_component_library(LLVMVectorize
   LoopVectorizationLegality.cpp
   LoopVectorize.cpp
   SandboxVectorizer/DependencyGraph.cpp
+  SandboxVectorizer/InstrMaps.cpp
   SandboxVectorizer/Interval.cpp
   SandboxVectorizer/Legality.cpp
   SandboxVectorizer/Passes/BottomUpVec.cpp
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
new file mode 100644
index 00000000000000..4df4829a04c417
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
@@ -0,0 +1,21 @@
+//===- InstrMaps.cpp - Maps scalars to vectors and reverse ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "llvm/Support/Debug.h"
+
+namespace llvm::sandboxir {
+
+#ifndef NDEBUG
+void InstrMaps::dump() const {
+  print(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+} // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 8c6deeb7df249d..f8149c5bc66363 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -12,6 +12,7 @@
 #include "llvm/SandboxIR/Utils.h"
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
 
 namespace llvm::sandboxir {
@@ -184,6 +185,26 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
 }
 #endif // NDEBUG

+CollectDescr
+LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
+  // One descriptor per lane: the vector that the lane's value was packed into
+  // along with its lane in it, or the value itself if it was not vectorized.
+  SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
+  Vec.reserve(Bndl.size());
+  for (Value *V : Bndl) {
+    if (auto *VecOp = IMaps.getVectorForOrig(V)) {
+      // `V` is in `VecOp`, so registerVector() must have recorded its lane.
+      // A missing lane is a bug: assert rather than use a bogus index.
+      std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
+      assert(ExtractIdxOpt && "Vectorized value without a lane!");
+      Vec.emplace_back(VecOp, *ExtractIdxOpt);
+    } else {
+      Vec.emplace_back(V);
+    }
+  }
+  return CollectDescr(std::move(Vec));
+}
+
 const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
                                                      bool SkipScheduling) {
   // If Bndl contains values other than instructions, we need to Pack.
@@ -193,11 +210,21 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
     return createLegalityResult<Pack>(ResultReason::NotInstructions);
   }
 
+  auto CollectDescrs = getHowToCollectValues(Bndl);
+  if (CollectDescrs.hasVectorInputs()) {
+    if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
+      auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
+      if (!NeedsShuffle)
+        return createLegalityResult<DiamondReuse>(Vec);
+      llvm_unreachable("TODO: Unimplemented");
+    } else {
+      llvm_unreachable("TODO: Unimplemented");
+    }
+  }
+
   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
     return createLegalityResult<Pack>(*ReasonOpt);
 
-  // TODO: Check for existing vectors containing values in Bndl.
-
   if (!SkipScheduling) {
     // TODO: Try to remove the IBndl vector.
     SmallVector<Instruction *, 8> IBndl;
@@ -210,4 +237,9 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
 
   return createLegalityResult<Widen>();
 }
+
+void LegalityAnalysis::clear() {
+  Sched.clear();
+  IMaps.clear(); // Held by reference: clears the owner's maps too.
+}
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index d44199609838d7..6b2032be535603 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -56,103 +56,114 @@ getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
 
 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                       ArrayRef<Value *> Operands) {
-  Change = true;
-  assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
-         "Expect Instructions!");
-  auto &Ctx = Bndl[0]->getContext();
+  auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
+                              ArrayRef<Value *> Operands) -> Value * {
+    assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
+           "Expect Instructions!");
+    auto &Ctx = Bndl[0]->getContext();

-  Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
-  auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
+    Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
+    auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

-  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
+    BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);

-  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
-  switch (Opcode) {
-  case Instruction::Opcode::ZExt:
-  case Instruction::Opcode::SExt:
-  case Instruction::Opcode::FPToUI:
-  case Instruction::Opcode::FPToSI:
-  case Instruction::Opcode::FPExt:
-  case Instruction::Opcode::PtrToInt:
-  case Instruction::Opcode::IntToPtr:
-  case Instruction::Opcode::SIToFP:
-  case Instruction::Opcode::UIToFP:
-  case Instruction::Opcode::Trunc:
-  case Instruction::Opcode::FPTrunc:
-  case Instruction::Opcode::BitCast: {
-    assert(Operands.size() == 1u && "Casts are unary!");
-    return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
-  }
-  case Instruction::Opcode::FCmp:
-  case Instruction::Opcode::ICmp: {
-    auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
-    assert(all_of(drop_begin(Bndl),
-                  [Pred](auto *SBV) {
-                    return cast<CmpInst>(SBV)->getPredicate() == Pred;
-                  }) &&
-           "Expected same predicate across bundle.");
-    return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
-                           "VCmp");
-  }
-  case Instruction::Opcode::Select: {
-    return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
-                              Ctx, "Vec");
-  }
-  case Instruction::Opcode::FNeg: {
-    auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
-    auto OpC = UOp0->getOpcode();
-    return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
-                                                Ctx, "Vec");
-  }
-  case Instruction::Opcode::Add:
-  case Instruction::Opcode::FAdd:
-  case Instruction::Opcode::Sub:
-  case Instruction::Opcode::FSub:
-  case Instruction::Opcode::Mul:
-  case Instruction::Opcode::FMul:
-  case Instruction::Opcode::UDiv:
-  case Instruction::Opcode::SDiv:
-  case Instruction::Opcode::FDiv:
-  case Instruction::Opcode::URem:
-  case Instruction::Opcode::SRem:
-  case Instruction::Opcode::FRem:
-  case Instruction::Opcode::Shl:
-  case Instruction::Opcode::LShr:
-  case Instruction::Opcode::AShr:
-  case Instruction::Opcode::And:
-  case Instruction::Opcode::Or:
-  case Instruction::Opcode::Xor: {
-    auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
-    auto *LHS = Operands[0];
-    auto *RHS = Operands[1];
-    return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
-                                                 BinOp0, WhereIt, Ctx, "Vec");
-  }
-  case Instruction::Opcode::Load: {
-    auto *Ld0 = cast<LoadInst>(Bndl[0]);
-    Value *Ptr = Ld0->getPointerOperand();
-    return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
-  }
-  case Instruction::Opcode::Store: {
-    auto Align = cast<StoreInst>(Bndl[0])->getAlign();
-    Value *Val = Operands[0];
-    Value *Ptr = Operands[1];
-    return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
-  }
-  case Instruction::Opcode::Br:
-  case Instruction::Opcode::Ret:
-  case Instruction::Opcode::PHI:
-  case Instruction::Opcode::AddrSpaceCast:
-  case Instruction::Opcode::Call:
-  case Instruction::Opcode::GetElementPtr:
-    llvm_unreachable("Unimplemented");
-    break;
-  default:
-    llvm_unreachable("Unimplemented");
-    break;
+    auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
+    switch (Opcode) {
+    case Instruction::Opcode::ZExt:
+    case Instruction::Opcode::SExt:
+    case Instruction::Opcode::FPToUI:
+    case Instruction::Opcode::FPToSI:
+    case Instruction::Opcode::FPExt:
+    case Instruction::Opcode::PtrToInt:
+    case Instruction::Opcode::IntToPtr:
+    case Instruction::Opcode::SIToFP:
+    case Instruction::Opcode::UIToFP:
+    case Instruction::Opcode::Trunc:
+    case Instruction::Opcode::FPTrunc:
+    case Instruction::Opcode::BitCast: {
+      assert(Operands.size() == 1u && "Casts are unary!");
+      return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
+                              "VCast");
+    }
+    case Instruction::Opcode::FCmp:
+    case Instruction::Opcode::ICmp: {
+      auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
+      assert(all_of(drop_begin(Bndl),
+                    [Pred](auto *SBV) {
+                      return cast<CmpInst>(SBV)->getPredicate() == Pred;
+                    }) &&
+             "Expected same predicate across bundle.");
+      return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
+                             "VCmp");
+    }
+    case Instruction::Opcode::Select: {
+      return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
+                                Ctx, "Vec");
+    }
+    case Instruction::Opcode::FNeg: {
+      auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
+      auto OpC = UOp0->getOpcode();
+      return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
+                                                  WhereIt, Ctx, "Vec");
+    }
+    case Instruction::Opcode::Add:
+    case Instruction::Opcode::FAdd:
+    case Instruction::Opcode::Sub:
+    case Instruction::Opcode::FSub:
+    case Instruction::Opcode::Mul:
+    case Instruction::Opcode::FMul:
+    case Instruction::Opcode::UDiv:
+    case Instruction::Opcode::SDiv:
+    case Instruction::Opcode::FDiv:
+    case Instruction::Opcode::URem:
+    case Instruction::Opcode::SRem:
+    case Instruction::Opcode::FRem:
+    case Instruction::Opcode::Shl:
+    case Instruction::Opcode::LShr:
+    case Instruction::Opcode::AShr:
+    case Instruction::Opcode::And:
+    case Instruction::Opcode::Or:
+    case Instruction::Opcode::Xor: {
+      auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
+      auto *LHS = Operands[0];
+      auto *RHS = Operands[1];
+      return BinaryOperator::createWithCopiedFlags(
+          BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
+    }
+    case Instruction::Opcode::Load: {
+      auto *Ld0 = cast<LoadInst>(Bndl[0]);
+      Value *Ptr = Ld0->getPointerOperand();
+      return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
+                              "VecL");
+    }
+    case Instruction::Opcode::Store: {
+      auto Align = cast<StoreInst>(Bndl[0])->getAlign();
+      Value *Val = Operands[0];
+      Value *Ptr = Operands[1];
+      return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
+    }
+    case Instruction::Opcode::Br:
+    case Instruction::Opcode::Ret:
+    case Instruction::Opcode::PHI:
+    case Instruction::Opcode::AddrSpaceCast:
+    case Instruction::Opcode::Call:
+    case Instruction::Opcode::GetElementPtr:
+      llvm_unreachable("Unimplemented");
+      break;
+    default:
+      llvm_unreachable("Unimplemented");
+      break;
+    }
+    llvm_unreachable("Missing switch case!");
+    // TODO: Propagate debug info to the new vector instruction.
+  };
+
+  auto *VecI = CreateVectorInstr(Bndl, Operands);
+  if (VecI != nullptr) {
+    Change = true; // The IR has been modified.
+    IMaps.registerVector(Bndl, VecI); // Record which scalars VecI replaces.
   }
-  llvm_unreachable("Missing switch case!");
-  // TODO: Propagate debug info.
+  return VecI;
 }
 
 void BottomUpVec::tryEraseDeadInstrs() {
@@ -280,6 +291,10 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
       collectPotentiallyDeadInstrs(Bndl);
     break;
   }
+  case LegalityResultID::DiamondReuse: {
+    NewVec = cast<DiamondReuse>(LegalityRes).getVector();
+    break;
+  }
   case LegalityResultID::Pack: {
     // If we can't vectorize the seeds then just return.
     if (Depth == 0)
@@ -300,9 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
 }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
+  IMaps.clear();
   Legality = std::make_unique<LegalityAnalysis>(
       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
-      F.getContext());
+      F.getContext(), IMaps);
   Change = false;
   const auto &DL = F.getParent()->getDataLayout();
   unsigned VecRegBits =
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index d34c8f88e4b3c6..7bc6e5ac3d7605 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -201,3 +201,25 @@ define void @pack_vectors(ptr %ptr, ptr %ptr2) {
   store float %ld1, ptr %ptr1
   ret void
 }
+
+; Both operands of each fsub are the same load, so the vectorized fsub is
+; expected to reuse the one <2 x float> vector load for both of its operands.
+define void @diamond(ptr %ptr) {
+; CHECK-LABEL: define void @diamond(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VECL]]
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+  %sub0 = fsub float %ld0, %ld0
+  %sub1 = fsub float %ld1, %ld1
+  store float %sub0, ptr %ptr0
+  store float %sub1, ptr %ptr1
+  ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
index df689767b77245..bbfbcc730a4cbe 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(SandboxVectorizerTests
   DependencyGraphTest.cpp
+  InstrMapsTest.cpp
   IntervalTest.cpp
   LegalityTest.cpp
   SchedulerTest.cpp
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
new file mode 100644
index 00000000000000..bcfb8db7f86741
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -0,0 +1,84 @@
+//===- InstrMapsTest.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/SandboxIR/Function.h"
+#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct InstrMapsTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M)
+      Err.print("InstrMapsTest", errs());
+  }
+};
+
+TEST_F(InstrMapsTest, Basic) {
+  parseIR(C, R"IR(
+define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
+  %add0 = add i8 %v0, %v0
+  %add1 = add i8 %v1, %v1
+  %add2 = add i8 %v2, %v2
+  %add3 = add i8 %v3, %v3
+  %vadd0 = add <2 x i8> %vec, %vec
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Add2 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Add3 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+  [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::InstrMaps IMaps;
+  // Check with empty IMaps.
+  EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
+  EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0));
+  // Check with 1 match.
+  IMaps.registerVector({Add0, Add1}, VAdd0);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add0), VAdd0);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add1), VAdd0);
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, VAdd0)); // Bad Orig value
+  EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0));   // Bad Vector value
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add0), 0);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+  // Check when the same vector maps to different original values (which is
+  // common for vector constants).
+  IMaps.registerVector({Add2, Add3}, VAdd0);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add2), 0);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add3), 1);
+  // Check when we register for a second time.
+#ifndef NDEBUG
+  EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
+#endif // NDEBUG
+  // Check that clear() drops all mappings, in both directions.
+  IMaps.clear();
+  EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add2), nullptr);
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add2));
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index b5e2c302f5901e..2e90462a633c17 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -18,6 +18,7 @@
 #include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -110,7 +111,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
+  llvm::sandboxir::InstrMaps IMaps;
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   const auto &Result =
       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
@@ -228,7 +230,8 @@ define void @foo(ptr %ptr) {
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
+  llvm::sandboxir::InstrMaps IMaps;
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   {
     // Can vectorize St0,St1.
     const auto &Result = Legality.canVectorize({St0, St1});
@@ -263,7 +266,8 @@ define void @foo() {
   };
 
   sandboxir::Context Ctx(C);
-  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
+  llvm::sandboxir::InstrMaps IMaps;
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
@@ -283,3 +287,74 @@ define void @foo() {
                       "Pack Reason: DiffWrapFlags"));
 }
 #endif // NDEBUG
+
+TEST_F(LegalityTest, CollectDescr) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr float, ptr %ptr, i32 0
+  %gep1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %vld = load <4 x float>, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  getAnalyses(*LLVMF);
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *VLd = cast<sandboxir::LoadInst>(&*It++);
+
+  // Each scope below builds its own descriptor vector, so no check depends on
+  // the state of a vector that was already moved into a CollectDescr.
+  using DescrVecT = sandboxir::CollectDescr::DescrVecT;
+  using EEDescr = sandboxir::CollectDescr::ExtractElementDescr;
+
+  {
+    // Check single input, no shuffle.
+    DescrVecT Descrs;
+    Descrs.push_back(EEDescr(VLd, 0));
+    Descrs.push_back(EEDescr(VLd, 1));
+    sandboxir::CollectDescr CD(std::move(Descrs));
+    EXPECT_TRUE(CD.getSingleInput());
+    EXPECT_EQ(CD.getSingleInput()->first, VLd);
+    EXPECT_EQ(CD.getSingleInput()->second, false);
+    EXPECT_TRUE(CD.hasVectorInputs());
+  }
+  {
+    // Check single input, shuffle.
+    DescrVecT Descrs;
+    Descrs.push_back(EEDescr(VLd, 1));
+    Descrs.push_back(EEDescr(VLd, 0));
+    sandboxir::CollectDescr CD(std::move(Descrs));
+    EXPECT_TRUE(CD.getSingleInput());
+    EXPECT_EQ(CD.getSingleInput()->first, VLd);
+    EXPECT_EQ(CD.getSingleInput()->second, true);
+    EXPECT_TRUE(CD.hasVectorInputs());
+  }
+  {
+    // Check multiple inputs.
+    DescrVecT Descrs;
+    Descrs.push_back(EEDescr(Ld0));
+    Descrs.push_back(EEDescr(VLd, 0));
+    Descrs.push_back(EEDescr(VLd, 1));
+    sandboxir::CollectDescr CD(std::move(Descrs));
+    EXPECT_FALSE(CD.getSingleInput());
+    EXPECT_TRUE(CD.hasVectorInputs());
+  }
+  {
+    // Check multiple inputs only scalars.
+    DescrVecT Descrs;
+    Descrs.push_back(EEDescr(Ld0));
+    Descrs.push_back(EEDescr(Ld1));
+    sandboxir::CollectDescr CD(std::move(Descrs));
+    EXPECT_FALSE(CD.getSingleInput());
+    EXPECT_FALSE(CD.hasVectorInputs());
+  }
+}



More information about the llvm-commits mailing list