[llvm] e902c69 - [SandboxVec][BottomUpVec] Implement InstrMaps (#122848)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 16 15:26:39 PST 2025
Author: vporpo
Date: 2025-01-16T15:26:35-08:00
New Revision: e902c6960cff4372d4b3ef9ae424b24ec6b0ea38
URL: https://github.com/llvm/llvm-project/commit/e902c6960cff4372d4b3ef9ae424b24ec6b0ea38
DIFF: https://github.com/llvm/llvm-project/commit/e902c6960cff4372d4b3ef9ae424b24ec6b0ea38.diff
LOG: [SandboxVec][BottomUpVec] Implement InstrMaps (#122848)
InstrMaps is a helper data structure that maps scalars to vectors and
the reverse. This is used by the vectorizer to figure out which vectors
it can extract scalar values from.
Added:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
Modified:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
llvm/lib/Transforms/Vectorize/CMakeLists.txt
llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
new file mode 100644
index 00000000000000..2c4ba30f6fd052
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -0,0 +1,77 @@
+//===- InstrMaps.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm::sandboxir {
+
+/// Maps the original instructions to the vectorized instrs and the reverse.
+/// For now an original instr can only map to a single vector.
+class InstrMaps {
+ /// A map from the original values that got combined into vectors, to the
+ /// vector value(s).
+ DenseMap<Value *, Value *> OrigToVectorMap;
+ /// A map from the vector value to a map of the original value to its lane.
+ /// Please note that for constant vectors, there may be multiple original
+ /// values with the same lane, as they may be coming from vectorizing
+ /// different original values.
+ DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+
+public:
+ /// \Returns the vector value that we got from vectorizing \p Orig, or
+ /// nullptr if not found.
+ Value *getVectorForOrig(Value *Orig) const {
+ auto It = OrigToVectorMap.find(Orig);
+ return It != OrigToVectorMap.end() ? It->second : nullptr;
+ }
+ /// \Returns the lane of \p Orig before it got vectorized into \p Vec, or
+ /// nullopt if not found.
+ std::optional<unsigned> getOrigLane(Value *Vec, Value *Orig) const {
+ auto It1 = VectorToOrigLaneMap.find(Vec);
+ if (It1 == VectorToOrigLaneMap.end())
+ return std::nullopt;
+ const auto &OrigToLaneMap = It1->second;
+ auto It2 = OrigToLaneMap.find(Orig);
+ if (It2 == OrigToLaneMap.end())
+ return std::nullopt;
+ return It2->second;
+ }
+ /// Update the map to reflect that \p Origs got vectorized into \p Vec.
+ void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
+ auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
+ for (auto [Lane, Orig] : enumerate(Origs)) {
+ auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
+ assert(Pair.second && "Orig already exists in the map!");
+ OrigToLaneMap[Orig] = Lane;
+ }
+ }
+ void clear() {
+ OrigToVectorMap.clear();
+ VectorToOrigLaneMap.clear();
+ }
+#ifndef NDEBUG
+ void print(raw_ostream &OS) const {
+ OS << "OrigToVectorMap:\n";
+ for (auto [Orig, Vec] : OrigToVectorMap)
+ OS << *Orig << " : " << *Vec << "\n";
+ }
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 233cf82a1b3dfb..c03e7a10397ad2 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -23,10 +23,12 @@ namespace llvm::sandboxir {
class LegalityAnalysis;
class Value;
+class InstrMaps;
enum class LegalityResultID {
- Pack, ///> Collect scalar values.
- Widen, ///> Vectorize by combining scalars to a vector.
+ Pack, ///> Collect scalar values.
+ Widen, ///> Vectorize by combining scalars to a vector.
+ DiamondReuse, ///> Don't generate new code, reuse existing vector.
};
/// The reason for vectorizing or not vectorizing.
@@ -50,6 +52,8 @@ struct ToStr {
return "Pack";
case LegalityResultID::Widen:
return "Widen";
+ case LegalityResultID::DiamondReuse:
+ return "DiamondReuse";
}
llvm_unreachable("Unknown LegalityResultID enum");
}
@@ -137,6 +141,19 @@ class Widen final : public LegalityResult {
}
};
+class DiamondReuse final : public LegalityResult {
+ friend class LegalityAnalysis;
+ Value *Vec;
+ DiamondReuse(Value *Vec)
+ : LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}
+
+public:
+ static bool classof(const LegalityResult *From) {
+ return From->getSubclassID() == LegalityResultID::DiamondReuse;
+ }
+ Value *getVector() const { return Vec; }
+};
+
class Pack final : public LegalityResultWithReason {
Pack(ResultReason Reason)
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -148,6 +165,59 @@ class Pack final : public LegalityResultWithReason {
}
};
+/// Describes how to collect the values needed by each lane.
+class CollectDescr {
+public:
+ /// Describes how to get a value element. If the value is a vector then it
+ /// also provides the index to extract it from.
+ class ExtractElementDescr {
+ Value *V;
+ /// The index in `V` that the value can be extracted from.
+ /// This is nullopt if we need to use `V` as a whole.
+ std::optional<int> ExtractIdx;
+
+ public:
+ ExtractElementDescr(Value *V, int ExtractIdx)
+ : V(V), ExtractIdx(ExtractIdx) {}
+ ExtractElementDescr(Value *V) : V(V), ExtractIdx(std::nullopt) {}
+ Value *getValue() const { return V; }
+ bool needsExtract() const { return ExtractIdx.has_value(); }
+ int getExtractIdx() const { return *ExtractIdx; }
+ };
+
+ using DescrVecT = SmallVector<ExtractElementDescr, 4>;
+ DescrVecT Descrs;
+
+public:
+ CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
+ : Descrs(std::move(Descrs)) {}
+ /// If all elements come from a single vector input, then return that vector
+ /// and whether we need a shuffle to get them in order.
+ std::optional<std::pair<Value *, bool>> getSingleInput() const {
+ const auto &Descr0 = *Descrs.begin();
+ Value *V0 = Descr0.getValue();
+ if (!Descr0.needsExtract())
+ return std::nullopt;
+ bool NeedsShuffle = Descr0.getExtractIdx() != 0;
+ int Lane = 1;
+ for (const auto &Descr : drop_begin(Descrs)) {
+ if (!Descr.needsExtract())
+ return std::nullopt;
+ if (Descr.getValue() != V0)
+ return std::nullopt;
+ if (Descr.getExtractIdx() != Lane++)
+ NeedsShuffle = true;
+ }
+ return std::make_pair(V0, NeedsShuffle);
+ }
+ bool hasVectorInputs() const {
+ return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
+ }
+ const SmallVector<ExtractElementDescr, 4> &getDescrs() const {
+ return Descrs;
+ }
+};
+
/// Performs the legality analysis and returns a LegalityResult object.
class LegalityAnalysis {
Scheduler Sched;
@@ -160,11 +230,17 @@ class LegalityAnalysis {
ScalarEvolution &SE;
const DataLayout &DL;
+ InstrMaps &IMaps;
+
+ /// Finds how we can collect the values in \p Bndl from the vectorized or
+ /// non-vectorized code. It returns a CollectDescr that records, for each
+ /// lane, the value to use and the vector lane to extract it from, if any.
+ CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;
public:
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
- Context &Ctx)
- : Sched(AA, Ctx), SE(SE), DL(DL) {}
+ Context &Ctx, InstrMaps &IMaps)
+ : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
@@ -177,7 +253,7 @@ class LegalityAnalysis {
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling = false);
- void clear() { Sched.clear(); }
+ void clear();
};
} // namespace llvm::sandboxir
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 1a53ca6e06f5fd..69cea3c4c7b53b 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -18,6 +18,7 @@
#include "llvm/SandboxIR/Pass.h"
#include "llvm/SandboxIR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
namespace llvm::sandboxir {
@@ -26,6 +27,8 @@ class BottomUpVec final : public FunctionPass {
bool Change = false;
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
+ /// Maps scalars to vectors.
+ InstrMaps IMaps;
/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index d769d5100afd23..6a025652f92f8e 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -4,6 +4,7 @@ add_llvm_component_library(LLVMVectorize
LoopVectorizationLegality.cpp
LoopVectorize.cpp
SandboxVectorizer/DependencyGraph.cpp
+ SandboxVectorizer/InstrMaps.cpp
SandboxVectorizer/Interval.cpp
SandboxVectorizer/Legality.cpp
SandboxVectorizer/Passes/BottomUpVec.cpp
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
new file mode 100644
index 00000000000000..4df4829a04c417
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp
@@ -0,0 +1,21 @@
+//===- InstrMaps.cpp - Maps scalars to vectors and reverse ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "llvm/Support/Debug.h"
+
+namespace llvm::sandboxir {
+
+#ifndef NDEBUG
+void InstrMaps::dump() const {
+ print(dbgs());
+ dbgs() << "\n";
+}
+#endif // NDEBUG
+
+} // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 8c6deeb7df249d..f8149c5bc66363 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -12,6 +12,7 @@
#include "llvm/SandboxIR/Utils.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
namespace llvm::sandboxir {
@@ -184,6 +185,22 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
}
#endif // NDEBUG
+CollectDescr
+LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
+ SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
+ Vec.reserve(Bndl.size());
+ for (auto [Lane, V] : enumerate(Bndl)) {
+ if (auto *VecOp = IMaps.getVectorForOrig(V)) {
+ // If there is a vector containing `V`, then get the lane it came from.
+ std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
+ Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
+ } else {
+ Vec.emplace_back(V);
+ }
+ }
+ return CollectDescr(std::move(Vec));
+}
+
const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling) {
// If Bndl contains values other than instructions, we need to Pack.
@@ -193,11 +210,21 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
return createLegalityResult<Pack>(ResultReason::NotInstructions);
}
+ auto CollectDescrs = getHowToCollectValues(Bndl);
+ if (CollectDescrs.hasVectorInputs()) {
+ if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
+ auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
+ if (!NeedsShuffle)
+ return createLegalityResult<DiamondReuse>(Vec);
+ llvm_unreachable("TODO: Unimplemented");
+ } else {
+ llvm_unreachable("TODO: Unimplemented");
+ }
+ }
+
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
return createLegalityResult<Pack>(*ReasonOpt);
- // TODO: Check for existing vectors containing values in Bndl.
-
if (!SkipScheduling) {
// TODO: Try to remove the IBndl vector.
SmallVector<Instruction *, 8> IBndl;
@@ -210,4 +237,9 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
return createLegalityResult<Widen>();
}
+
+void LegalityAnalysis::clear() {
+ Sched.clear();
+ IMaps.clear();
+}
} // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index d44199609838d7..6b2032be535603 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -56,103 +56,114 @@ getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
ArrayRef<Value *> Operands) {
- Change = true;
- assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
- "Expect Instructions!");
- auto &Ctx = Bndl[0]->getContext();
+ auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
+ ArrayRef<Value *> Operands) -> Value * {
+ assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
+ "Expect Instructions!");
+ auto &Ctx = Bndl[0]->getContext();
- Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
- auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
+ Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
+ auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
- BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
+ BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
- auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
- switch (Opcode) {
- case Instruction::Opcode::ZExt:
- case Instruction::Opcode::SExt:
- case Instruction::Opcode::FPToUI:
- case Instruction::Opcode::FPToSI:
- case Instruction::Opcode::FPExt:
- case Instruction::Opcode::PtrToInt:
- case Instruction::Opcode::IntToPtr:
- case Instruction::Opcode::SIToFP:
- case Instruction::Opcode::UIToFP:
- case Instruction::Opcode::Trunc:
- case Instruction::Opcode::FPTrunc:
- case Instruction::Opcode::BitCast: {
- assert(Operands.size() == 1u && "Casts are unary!");
- return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
- }
- case Instruction::Opcode::FCmp:
- case Instruction::Opcode::ICmp: {
- auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
- assert(all_of(drop_begin(Bndl),
- [Pred](auto *SBV) {
- return cast<CmpInst>(SBV)->getPredicate() == Pred;
- }) &&
- "Expected same predicate across bundle.");
- return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
- "VCmp");
- }
- case Instruction::Opcode::Select: {
- return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
- Ctx, "Vec");
- }
- case Instruction::Opcode::FNeg: {
- auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
- auto OpC = UOp0->getOpcode();
- return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
- Ctx, "Vec");
- }
- case Instruction::Opcode::Add:
- case Instruction::Opcode::FAdd:
- case Instruction::Opcode::Sub:
- case Instruction::Opcode::FSub:
- case Instruction::Opcode::Mul:
- case Instruction::Opcode::FMul:
- case Instruction::Opcode::UDiv:
- case Instruction::Opcode::SDiv:
- case Instruction::Opcode::FDiv:
- case Instruction::Opcode::URem:
- case Instruction::Opcode::SRem:
- case Instruction::Opcode::FRem:
- case Instruction::Opcode::Shl:
- case Instruction::Opcode::LShr:
- case Instruction::Opcode::AShr:
- case Instruction::Opcode::And:
- case Instruction::Opcode::Or:
- case Instruction::Opcode::Xor: {
- auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
- auto *LHS = Operands[0];
- auto *RHS = Operands[1];
- return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
- BinOp0, WhereIt, Ctx, "Vec");
- }
- case Instruction::Opcode::Load: {
- auto *Ld0 = cast<LoadInst>(Bndl[0]);
- Value *Ptr = Ld0->getPointerOperand();
- return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
- }
- case Instruction::Opcode::Store: {
- auto Align = cast<StoreInst>(Bndl[0])->getAlign();
- Value *Val = Operands[0];
- Value *Ptr = Operands[1];
- return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
- }
- case Instruction::Opcode::Br:
- case Instruction::Opcode::Ret:
- case Instruction::Opcode::PHI:
- case Instruction::Opcode::AddrSpaceCast:
- case Instruction::Opcode::Call:
- case Instruction::Opcode::GetElementPtr:
- llvm_unreachable("Unimplemented");
- break;
- default:
- llvm_unreachable("Unimplemented");
- break;
+ auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
+ switch (Opcode) {
+ case Instruction::Opcode::ZExt:
+ case Instruction::Opcode::SExt:
+ case Instruction::Opcode::FPToUI:
+ case Instruction::Opcode::FPToSI:
+ case Instruction::Opcode::FPExt:
+ case Instruction::Opcode::PtrToInt:
+ case Instruction::Opcode::IntToPtr:
+ case Instruction::Opcode::SIToFP:
+ case Instruction::Opcode::UIToFP:
+ case Instruction::Opcode::Trunc:
+ case Instruction::Opcode::FPTrunc:
+ case Instruction::Opcode::BitCast: {
+ assert(Operands.size() == 1u && "Casts are unary!");
+ return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
+ "VCast");
+ }
+ case Instruction::Opcode::FCmp:
+ case Instruction::Opcode::ICmp: {
+ auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
+ assert(all_of(drop_begin(Bndl),
+ [Pred](auto *SBV) {
+ return cast<CmpInst>(SBV)->getPredicate() == Pred;
+ }) &&
+ "Expected same predicate across bundle.");
+ return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
+ "VCmp");
+ }
+ case Instruction::Opcode::Select: {
+ return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
+ Ctx, "Vec");
+ }
+ case Instruction::Opcode::FNeg: {
+ auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
+ auto OpC = UOp0->getOpcode();
+ return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
+ WhereIt, Ctx, "Vec");
+ }
+ case Instruction::Opcode::Add:
+ case Instruction::Opcode::FAdd:
+ case Instruction::Opcode::Sub:
+ case Instruction::Opcode::FSub:
+ case Instruction::Opcode::Mul:
+ case Instruction::Opcode::FMul:
+ case Instruction::Opcode::UDiv:
+ case Instruction::Opcode::SDiv:
+ case Instruction::Opcode::FDiv:
+ case Instruction::Opcode::URem:
+ case Instruction::Opcode::SRem:
+ case Instruction::Opcode::FRem:
+ case Instruction::Opcode::Shl:
+ case Instruction::Opcode::LShr:
+ case Instruction::Opcode::AShr:
+ case Instruction::Opcode::And:
+ case Instruction::Opcode::Or:
+ case Instruction::Opcode::Xor: {
+ auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
+ auto *LHS = Operands[0];
+ auto *RHS = Operands[1];
+ return BinaryOperator::createWithCopiedFlags(
+ BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
+ }
+ case Instruction::Opcode::Load: {
+ auto *Ld0 = cast<LoadInst>(Bndl[0]);
+ Value *Ptr = Ld0->getPointerOperand();
+ return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
+ "VecL");
+ }
+ case Instruction::Opcode::Store: {
+ auto Align = cast<StoreInst>(Bndl[0])->getAlign();
+ Value *Val = Operands[0];
+ Value *Ptr = Operands[1];
+ return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
+ }
+ case Instruction::Opcode::Br:
+ case Instruction::Opcode::Ret:
+ case Instruction::Opcode::PHI:
+ case Instruction::Opcode::AddrSpaceCast:
+ case Instruction::Opcode::Call:
+ case Instruction::Opcode::GetElementPtr:
+ llvm_unreachable("Unimplemented");
+ break;
+ default:
+ llvm_unreachable("Unimplemented");
+ break;
+ }
+ llvm_unreachable("Missing switch case!");
+ // TODO: Propagate debug info.
+ };
+
+ auto *VecI = CreateVectorInstr(Bndl, Operands);
+ if (VecI != nullptr) {
+ Change = true;
+ IMaps.registerVector(Bndl, VecI);
}
- llvm_unreachable("Missing switch case!");
- // TODO: Propagate debug info.
+ return VecI;
}
void BottomUpVec::tryEraseDeadInstrs() {
@@ -280,6 +291,10 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
collectPotentiallyDeadInstrs(Bndl);
break;
}
+ case LegalityResultID::DiamondReuse: {
+ NewVec = cast<DiamondReuse>(LegalityRes).getVector();
+ break;
+ }
case LegalityResultID::Pack: {
// If we can't vectorize the seeds then just return.
if (Depth == 0)
@@ -300,9 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
}
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
+ IMaps.clear();
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
- F.getContext());
+ F.getContext(), IMaps);
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index d34c8f88e4b3c6..7bc6e5ac3d7605 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -201,3 +201,23 @@ define void @pack_vectors(ptr %ptr, ptr %ptr2) {
store float %ld1, ptr %ptr1
ret void
}
+
+define void @diamond(ptr %ptr) {
+; CHECK-LABEL: define void @diamond(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VECL]]
+; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ %sub0 = fsub float %ld0, %ld0
+ %sub1 = fsub float %ld1, %ld1
+ store float %sub0, ptr %ptr0
+ store float %sub1, ptr %ptr1
+ ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
index df689767b77245..bbfbcc730a4cbe 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_unittest(SandboxVectorizerTests
DependencyGraphTest.cpp
+ InstrMapsTest.cpp
IntervalTest.cpp
LegalityTest.cpp
SchedulerTest.cpp
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
new file mode 100644
index 00000000000000..bcfb8db7f86741
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -0,0 +1,78 @@
+//===- InstrMapsTest.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/SandboxIR/Function.h"
+#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct InstrMapsTest : public testing::Test {
+ LLVMContext C;
+ std::unique_ptr<Module> M;
+
+ void parseIR(LLVMContext &C, const char *IR) {
+ SMDiagnostic Err;
+ M = parseAssemblyString(IR, Err, C);
+ if (!M)
+ Err.print("InstrMapsTest", errs());
+ }
+};
+
+TEST_F(InstrMapsTest, Basic) {
+ parseIR(C, R"IR(
+define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
+ %add0 = add i8 %v0, %v0
+ %add1 = add i8 %v1, %v1
+ %add2 = add i8 %v2, %v2
+ %add3 = add i8 %v3, %v3
+ %vadd0 = add <2 x i8> %vec, %vec
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+
+ auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *Add2 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *Add3 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+ [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+ sandboxir::InstrMaps IMaps;
+ // Check with empty IMaps.
+ EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+ EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
+ EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0));
+ // Check with 1 match.
+ IMaps.registerVector({Add0, Add1}, VAdd0);
+ EXPECT_EQ(IMaps.getVectorForOrig(Add0), VAdd0);
+ EXPECT_EQ(IMaps.getVectorForOrig(Add1), VAdd0);
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, VAdd0)); // Bad Orig value
+ EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0)); // Bad Vector value
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add0), 0);
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+ // Check when the same vector maps to different original values (which is
+ // common for vector constants).
+ IMaps.registerVector({Add2, Add3}, VAdd0);
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add2), 0);
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add3), 1);
+ // Check when we register for a second time.
+#ifndef NDEBUG
+ EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
+#endif // NDEBUG
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index b5e2c302f5901e..2e90462a633c17 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -18,6 +18,7 @@
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -110,7 +111,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
- sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
+ llvm::sandboxir::InstrMaps IMaps;
+ sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
@@ -228,7 +230,8 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
- sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
+ llvm::sandboxir::InstrMaps IMaps;
+ sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
{
// Can vectorize St0,St1.
const auto &Result = Legality.canVectorize({St0, St1});
@@ -263,7 +266,8 @@ define void @foo() {
};
sandboxir::Context Ctx(C);
- sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
+ llvm::sandboxir::InstrMaps IMaps;
+ sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
@@ -283,3 +287,68 @@ define void @foo() {
"Pack Reason: DiffWrapFlags"));
}
#endif // NDEBUG
+
+TEST_F(LegalityTest, CollectDescr) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+ %gep0 = getelementptr float, ptr %ptr, i32 0
+ %gep1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %gep0
+ %ld1 = load float, ptr %gep1
+ %vld = load <4 x float>, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ getAnalyses(*LLVMF);
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
+ [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
+ auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+ [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *VLd = cast<sandboxir::LoadInst>(&*It++);
+
+ sandboxir::CollectDescr::DescrVecT Descrs;
+ using EEDescr = sandboxir::CollectDescr::ExtractElementDescr;
+
+ {
+ // Check single input, no shuffle.
+ Descrs.push_back(EEDescr(VLd, 0));
+ Descrs.push_back(EEDescr(VLd, 1));
+ sandboxir::CollectDescr CD(std::move(Descrs));
+ EXPECT_TRUE(CD.getSingleInput());
+ EXPECT_EQ(CD.getSingleInput()->first, VLd);
+ EXPECT_EQ(CD.getSingleInput()->second, false);
+ EXPECT_TRUE(CD.hasVectorInputs());
+ }
+ {
+ // Check single input, shuffle.
+ Descrs.push_back(EEDescr(VLd, 1));
+ Descrs.push_back(EEDescr(VLd, 0));
+ sandboxir::CollectDescr CD(std::move(Descrs));
+ EXPECT_TRUE(CD.getSingleInput());
+ EXPECT_EQ(CD.getSingleInput()->first, VLd);
+ EXPECT_EQ(CD.getSingleInput()->second, true);
+ EXPECT_TRUE(CD.hasVectorInputs());
+ }
+ {
+ // Check multiple inputs.
+ Descrs.push_back(EEDescr(Ld0));
+ Descrs.push_back(EEDescr(VLd, 0));
+ Descrs.push_back(EEDescr(VLd, 1));
+ sandboxir::CollectDescr CD(std::move(Descrs));
+ EXPECT_FALSE(CD.getSingleInput());
+ EXPECT_TRUE(CD.hasVectorInputs());
+ }
+ {
+ // Check multiple inputs only scalars.
+ Descrs.push_back(EEDescr(Ld0));
+ Descrs.push_back(EEDescr(Ld1));
+ sandboxir::CollectDescr CD(std::move(Descrs));
+ EXPECT_FALSE(CD.getSingleInput());
+ EXPECT_FALSE(CD.hasVectorInputs());
+ }
+}
More information about the llvm-commits
mailing list