[llvm] [SandboxVec][Legality] Per opcode checks (PR #114145)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 29 15:58:00 PDT 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/114145
This patch adds more opcode-specific legality checks.
From 77e6485a6a31d5d14d815540021bb54159a61df3 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 28 Oct 2024 11:21:34 -0700
Subject: [PATCH] [SandboxVec][Legality] Per opcode checks
This patch adds more opcode-specific legality checks.
---
.../Vectorize/SandboxVectorizer/Legality.h | 17 +-
.../SandboxVectorizer/Passes/BottomUpVec.h | 2 +-
.../Vectorize/SandboxVectorizer/VecUtils.h | 37 ++
.../Vectorize/SandboxVectorizer/Legality.cpp | 104 +++++-
.../SandboxVectorizer/Passes/BottomUpVec.cpp | 5 +-
.../SandboxVectorizer/LegalityTest.cpp | 92 ++++-
.../SandboxVectorizer/VecUtilsTest.cpp | 333 ++++++++++++++++++
7 files changed, 580 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 77ba5cd7f002e9..f43e033e3cc7e3 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -13,6 +13,8 @@
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
@@ -33,6 +35,9 @@ enum class ResultReason {
DiffTypes,
DiffMathFlags,
DiffWrapFlags,
+ NotConsecutive,
+ Unimplemented,
+ Infeasible,
};
#ifndef NDEBUG
@@ -59,6 +64,12 @@ struct ToStr {
return "DiffMathFlags";
case ResultReason::DiffWrapFlags:
return "DiffWrapFlags";
+ case ResultReason::NotConsecutive:
+ return "NotConsecutive";
+ case ResultReason::Unimplemented:
+ return "Unimplemented";
+ case ResultReason::Infeasible:
+ return "Infeasible";
}
llvm_unreachable("Unknown ResultReason enum");
}
@@ -142,8 +153,12 @@ class LegalityAnalysis {
std::optional<ResultReason>
notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value *> Bndl);
+ ScalarEvolution &SE;
+ const DataLayout &DL;
+
public:
- LegalityAnalysis() = default;
+ LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
+ : SE(SE), DL(DL) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 2b0b3f8192c048..7e0b88ae7197d4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -24,7 +24,7 @@ namespace llvm::sandboxir {
class BottomUpVec final : public FunctionPass {
bool Change = false;
- LegalityAnalysis Legality;
+ std::unique_ptr<LegalityAnalysis> Legality;
void vectorizeRec(ArrayRef<Value *> Bndl);
void tryVectorize(ArrayRef<Value *> Seeds);
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 9577e8ef7b37cb..8b64ec58da345d 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -12,7 +12,10 @@
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/DataLayout.h"
#include "llvm/SandboxIR/Type.h"
+#include "llvm/SandboxIR/Utils.h"
namespace llvm::sandboxir {
@@ -29,6 +32,40 @@ class VecUtils {
static Type *getElementType(Type *Ty) {
return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
}
+
+ /// \Returns true if \p I1 and \p I2 are load/stores accessing consecutive
+ /// memory addresses.
+ template <typename LoadOrStoreT>
+ static bool areConsecutive(LoadOrStoreT *I1, LoadOrStoreT *I2,
+ ScalarEvolution &SE, const DataLayout &DL) {
+ static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
+ std::is_same<LoadOrStoreT, StoreInst>::value,
+ "Expected Load or Store!");
+ auto Diff = Utils::getPointerDiffInBytes(I1, I2, SE);
+ if (!Diff)
+ return false;
+ int ElmBytes = Utils::getNumBits(I1) / 8;
+ return *Diff == ElmBytes;
+ }
+
+ template <typename LoadOrStoreT>
+ static bool areConsecutive(ArrayRef<Value *> &Bndl, ScalarEvolution &SE,
+ const DataLayout &DL) {
+ static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
+ std::is_same<LoadOrStoreT, StoreInst>::value,
+ "Expected Load or Store!");
+ assert(isa<LoadOrStoreT>(Bndl[0]) && "Expected Load or Store!");
+ auto *LastLS = cast<LoadOrStoreT>(Bndl[0]);
+ for (Value *V : drop_begin(Bndl)) {
+ assert(isa<LoadOrStoreT>(V) &&
+ "Unimplemented: we only support StoreInst!");
+ auto *LS = cast<LoadOrStoreT>(V);
+ if (!VecUtils::areConsecutive(LastLS, LS, SE, DL))
+ return false;
+ LastLS = LS;
+ }
+ return true;
+ }
};
} // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 1cc6356300e492..1efd178778b9f6 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -70,7 +70,109 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
}
}
- // TODO: Missing checks
+ // Now we need to do further checks for specific opcodes.
+ switch (Opcode) {
+ case Instruction::Opcode::ZExt:
+ case Instruction::Opcode::SExt:
+ case Instruction::Opcode::FPToUI:
+ case Instruction::Opcode::FPToSI:
+ case Instruction::Opcode::FPExt:
+ case Instruction::Opcode::PtrToInt:
+ case Instruction::Opcode::IntToPtr:
+ case Instruction::Opcode::SIToFP:
+ case Instruction::Opcode::UIToFP:
+ case Instruction::Opcode::Trunc:
+ case Instruction::Opcode::FPTrunc:
+ case Instruction::Opcode::BitCast: {
+ // We have already checked that they are of the same opcode.
+ assert(all_of(Bndl,
+ [Opcode](Value *V) {
+ return cast<Instruction>(V)->getOpcode() == Opcode;
+ }) &&
+ "Different opcodes, should have early returned!");
+ // But for these opcodes we should also check the operand type.
+ Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
+ if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
+ return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
+ FromTy0;
+ }))
+ return ResultReason::DiffTypes;
+ return std::nullopt;
+ }
+ case Instruction::Opcode::FCmp:
+ case Instruction::Opcode::ICmp: {
+ // We need the same predicate.
+ auto Pred0 = cast<CmpInst>(I0)->getPredicate();
+ bool Same = all_of(Bndl, [Pred0](Value *V) {
+ return cast<CmpInst>(V)->getPredicate() == Pred0;
+ });
+ if (Same)
+ return std::nullopt;
+ return ResultReason::DiffOpcodes;
+ }
+ case Instruction::Opcode::Select:
+ case Instruction::Opcode::FNeg:
+ case Instruction::Opcode::Add:
+ case Instruction::Opcode::FAdd:
+ case Instruction::Opcode::Sub:
+ case Instruction::Opcode::FSub:
+ case Instruction::Opcode::Mul:
+ case Instruction::Opcode::FMul:
+ case Instruction::Opcode::FRem:
+ case Instruction::Opcode::UDiv:
+ case Instruction::Opcode::SDiv:
+ case Instruction::Opcode::FDiv:
+ case Instruction::Opcode::URem:
+ case Instruction::Opcode::SRem:
+ case Instruction::Opcode::Shl:
+ case Instruction::Opcode::LShr:
+ case Instruction::Opcode::AShr:
+ case Instruction::Opcode::And:
+ case Instruction::Opcode::Or:
+ case Instruction::Opcode::Xor:
+ return std::nullopt;
+ case Instruction::Opcode::Load:
+ if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
+ return std::nullopt;
+ return ResultReason::NotConsecutive;
+ case Instruction::Opcode::Store:
+ if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
+ return std::nullopt;
+ return ResultReason::NotConsecutive;
+ case Instruction::Opcode::PHI:
+ return ResultReason::Unimplemented;
+ case Instruction::Opcode::Opaque:
+ return ResultReason::Unimplemented;
+ case Instruction::Opcode::Br:
+ case Instruction::Opcode::Ret:
+ case Instruction::Opcode::AddrSpaceCast:
+ case Instruction::Opcode::InsertElement:
+ case Instruction::Opcode::InsertValue:
+ case Instruction::Opcode::ExtractElement:
+ case Instruction::Opcode::ExtractValue:
+ case Instruction::Opcode::ShuffleVector:
+ case Instruction::Opcode::Call:
+ case Instruction::Opcode::GetElementPtr:
+ case Instruction::Opcode::Switch:
+ return ResultReason::Unimplemented;
+ case Instruction::Opcode::VAArg:
+ case Instruction::Opcode::Freeze:
+ case Instruction::Opcode::Fence:
+ case Instruction::Opcode::Invoke:
+ case Instruction::Opcode::CallBr:
+ case Instruction::Opcode::LandingPad:
+ case Instruction::Opcode::CatchPad:
+ case Instruction::Opcode::CleanupPad:
+ case Instruction::Opcode::CatchRet:
+ case Instruction::Opcode::CleanupRet:
+ case Instruction::Opcode::Resume:
+ case Instruction::Opcode::CatchSwitch:
+ case Instruction::Opcode::AtomicRMW:
+ case Instruction::Opcode::AtomicCmpXchg:
+ case Instruction::Opcode::Alloca:
+ case Instruction::Opcode::Unreachable:
+ return ResultReason::Infeasible;
+ }
return std::nullopt;
}
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 66d631edfc4076..339330c64f0caa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -11,6 +11,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/SandboxIR/Module.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
namespace llvm::sandboxir {
@@ -40,7 +41,7 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
}
void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
- const auto &LegalityRes = Legality.canVectorize(Bndl);
+ const auto &LegalityRes = Legality->canVectorize(Bndl);
switch (LegalityRes.getSubclassID()) {
case LegalityResultID::Widen: {
auto *I = cast<Instruction>(Bndl[0]);
@@ -60,6 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
+ Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
+ F.getParent()->getDataLayout());
Change = false;
// TODO: Start from innermost BBs first
for (auto &BB : F) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 50b78f6f48afdf..68557cb8b129f2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -7,7 +7,13 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Support/SourceMgr.h"
@@ -18,6 +24,22 @@ using namespace llvm;
struct LegalityTest : public testing::Test {
LLVMContext C;
std::unique_ptr<Module> M;
+ std::unique_ptr<DominatorTree> DT;
+ std::unique_ptr<TargetLibraryInfoImpl> TLII;
+ std::unique_ptr<TargetLibraryInfo> TLI;
+ std::unique_ptr<AssumptionCache> AC;
+ std::unique_ptr<LoopInfo> LI;
+ std::unique_ptr<ScalarEvolution> SE;
+
+ ScalarEvolution &getSE(llvm::Function &LLVMF) {
+ DT = std::make_unique<DominatorTree>(LLVMF);
+ TLII = std::make_unique<TargetLibraryInfoImpl>();
+ TLI = std::make_unique<TargetLibraryInfo>(*TLII);
+ AC = std::make_unique<AssumptionCache>(LLVMF);
+ LI = std::make_unique<LoopInfo>(*DT);
+ SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
+ return *SE;
+ }
void parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
@@ -29,12 +51,14 @@ struct LegalityTest : public testing::Test {
TEST_F(LegalityTest, Legality) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
%gep3 = getelementptr float, ptr %ptr, i32 3
%ld0 = load float, ptr %gep0
- %ld1 = load float, ptr %gep0
+ %ld0b = load float, ptr %gep0
+ %ld1 = load float, ptr %gep1
+ %ld3 = load float, ptr %gep3
store float %ld0, ptr %gep0
store float %ld1, ptr %gep1
store <2 x float> %vec2, ptr %gep1
@@ -44,10 +68,17 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
%fadd1 = fadd fast float %farg1, %farg1
%trunc0 = trunc nuw nsw i64 %v0 to i8
%trunc1 = trunc nsw i64 %v1 to i8
+ %trunc64to8 = trunc i64 %v0 to i8
+ %trunc32to8 = trunc i32 %v2 to i8
+ %cmpSLT = icmp slt i64 %v0, %v1
+ %cmpSGT = icmp sgt i64 %v0, %v1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
+ auto &SE = getSE(*LLVMF);
+ const auto &DL = M->getDataLayout();
+
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
@@ -55,8 +86,10 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
[[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
- [[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
- [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
+ auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *Ld3 = cast<sandboxir::LoadInst>(&*It++);
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
@@ -66,8 +99,12 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
+ auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++);
+ auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
+ auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
+ auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
- sandboxir::LegalityAnalysis Legality;
+ sandboxir::LegalityAnalysis Legality(SE, DL);
const auto &Result = Legality.canVectorize({St0, St1});
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
@@ -109,10 +146,52 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffWrapFlags);
}
+ {
+ // Check DiffTypes for unary operands that have a different type.
+ const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::DiffTypes);
+ }
+ {
+ // Check DiffOpcodes for CMPs with different predicates.
+ const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::DiffOpcodes);
+ }
+ {
+ // Check NotConsecutive Ld0,Ld0b
+ const auto &Result = Legality.canVectorize({Ld0, Ld0b});
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::NotConsecutive);
+ }
+ {
+ // Check NotConsecutive Ld0,Ld3
+ const auto &Result = Legality.canVectorize({Ld0, Ld3});
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::NotConsecutive);
+ }
+ {
+ // Check Widen Ld0,Ld1
+ const auto &Result = Legality.canVectorize({Ld0, Ld1});
+ EXPECT_TRUE(isa<sandboxir::Widen>(Result));
+ }
}
#ifndef NDEBUG
TEST_F(LegalityTest, LegalityResultDump) {
+ parseIR(C, R"IR(
+define void @foo() {
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ auto &SE = getSE(*LLVMF);
+ const auto &DL = M->getDataLayout();
+
auto Matches = [](const sandboxir::LegalityResult &Result,
const std::string &ExpectedStr) -> bool {
std::string Buff;
@@ -120,7 +199,8 @@ TEST_F(LegalityTest, LegalityResultDump) {
Result.print(OS);
return Buff == ExpectedStr;
};
- sandboxir::LegalityAnalysis Legality;
+
+ sandboxir::LegalityAnalysis Legality(SE, DL);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index e0b08284964392..75f72ce23fbaac 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -7,15 +7,47 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Type.h"
+#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
using namespace llvm;
struct VecUtilsTest : public testing::Test {
LLVMContext C;
+ std::unique_ptr<Module> M;
+ std::unique_ptr<AssumptionCache> AC;
+ std::unique_ptr<TargetLibraryInfoImpl> TLII;
+ std::unique_ptr<TargetLibraryInfo> TLI;
+ std::unique_ptr<DominatorTree> DT;
+ std::unique_ptr<LoopInfo> LI;
+ std::unique_ptr<ScalarEvolution> SE;
+ void parseIR(const char *IR) {
+ SMDiagnostic Err;
+ M = parseAssemblyString(IR, Err, C);
+ if (!M)
+ Err.print("VecUtilsTest", errs());
+ }
+ ScalarEvolution &getSE(llvm::Function &LLVMF) {
+ TLII = std::make_unique<TargetLibraryInfoImpl>();
+ TLI = std::make_unique<TargetLibraryInfo>(*TLII);
+ AC = std::make_unique<AssumptionCache>(LLVMF);
+ DT = std::make_unique<DominatorTree>(LLVMF);
+ LI = std::make_unique<LoopInfo>(*DT);
+ SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
+ return *SE;
+ }
};
TEST_F(VecUtilsTest, GetNumElements) {
@@ -35,3 +67,304 @@ TEST_F(VecUtilsTest, GetElementType) {
auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
EXPECT_EQ(sandboxir::VecUtils::getElementType(VTy), ElemTy);
}
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_float) {
+ parseIR(R"IR(
+define void @foo(ptr %ptr) {
+ %gep0 = getelementptr inbounds float, ptr %ptr, i64 0
+ %gep1 = getelementptr inbounds float, ptr %ptr, i64 1
+ %gep2 = getelementptr inbounds float, ptr %ptr, i64 2
+ %gep3 = getelementptr inbounds float, ptr %ptr, i64 3
+
+ %ld0 = load float, ptr %gep0
+ %ld1 = load float, ptr %gep1
+ %ld2 = load float, ptr %gep2
+ %ld3 = load float, ptr %gep3
+
+ %v2ld0 = load <2 x float>, ptr %gep0
+ %v2ld1 = load <2 x float>, ptr %gep1
+ %v2ld2 = load <2 x float>, ptr %gep2
+ %v2ld3 = load <2 x float>, ptr %gep3
+
+ %v3ld0 = load <3 x float>, ptr %gep0
+ %v3ld1 = load <3 x float>, ptr %gep1
+ %v3ld2 = load <3 x float>, ptr %gep2
+ %v3ld3 = load <3 x float>, ptr %gep3
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ const DataLayout &DL = M->getDataLayout();
+ auto &SE = getSE(LLVMF);
+
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+
+ auto &BB = *F.begin();
+ auto It = std::next(BB.begin(), 4);
+ auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ auto *V2L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ auto *V3L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ // Scalar
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL));
+
+ // Check 2-wide loads
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL));
+
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+
+ // Check 3-wide loads
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL));
+
+ // Check mixes of vectors and scalar
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL));
+
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
+}
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_i8) {
+ parseIR(R"IR(
+define void @foo(ptr %ptr) {
+ %gep0 = getelementptr inbounds i8, ptr %ptr, i64 0
+ %gep1 = getelementptr inbounds i8, ptr %ptr, i64 4
+ %gep2 = getelementptr inbounds i8, ptr %ptr, i64 8
+ %gep3 = getelementptr inbounds i8, ptr %ptr, i64 12
+
+ %ld0 = load float, ptr %gep0
+ %ld1 = load float, ptr %gep1
+ %ld2 = load float, ptr %gep2
+ %ld3 = load float, ptr %gep3
+
+ %v2ld0 = load <2 x float>, ptr %gep0
+ %v2ld1 = load <2 x float>, ptr %gep1
+ %v2ld2 = load <2 x float>, ptr %gep2
+ %v2ld3 = load <2 x float>, ptr %gep3
+
+ %v3ld0 = load <3 x float>, ptr %gep0
+ %v3ld1 = load <3 x float>, ptr %gep1
+ %v3ld2 = load <3 x float>, ptr %gep2
+ %v3ld3 = load <3 x float>, ptr %gep3
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ const DataLayout &DL = M->getDataLayout();
+ auto &SE = getSE(LLVMF);
+
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto &BB = *F.begin();
+ auto It = std::next(BB.begin(), 4);
+ auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ auto *V2L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ auto *V3L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ // Scalar
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL));
+
+ // Check 2-wide loads
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL));
+
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+
+ // Check 3-wide loads
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL));
+
+ // Check mixes of vectors and scalar
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL));
+
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
+}
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_i1) {
+ parseIR(R"IR(
+define void @foo(ptr %ptr) {
+ %gep0 = getelementptr inbounds i1, ptr %ptr, i64 0
+ %gep1 = getelementptr inbounds i2, ptr %ptr, i64 4
+ %gep2 = getelementptr inbounds i3, ptr %ptr, i64 8
+ %gep3 = getelementptr inbounds i7, ptr %ptr, i64 12
+
+ %ld0 = load float, ptr %gep0
+ %ld1 = load float, ptr %gep1
+ %ld2 = load float, ptr %gep2
+ %ld3 = load float, ptr %gep3
+
+ %v2ld0 = load <2 x float>, ptr %gep0
+ %v2ld1 = load <2 x float>, ptr %gep1
+ %v2ld2 = load <2 x float>, ptr %gep2
+ %v2ld3 = load <2 x float>, ptr %gep3
+
+ %v3ld0 = load <3 x float>, ptr %gep0
+ %v3ld1 = load <3 x float>, ptr %gep1
+ %v3ld2 = load <3 x float>, ptr %gep2
+ %v3ld3 = load <3 x float>, ptr %gep3
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ const DataLayout &DL = M->getDataLayout();
+ auto &SE = getSE(LLVMF);
+
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto &BB = *F.begin();
+ auto It = std::next(BB.begin(), 4);
+ auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ auto *V2L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V2L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ auto *V3L0 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L2 = cast<sandboxir::LoadInst>(&*It++);
+ auto *V3L3 = cast<sandboxir::LoadInst>(&*It++);
+
+ // Scalar
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL));
+
+ // Check 2-wide loads
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL));
+
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+
+ // Check 3-wide loads
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL));
+
+ // Check mixes of vectors and scalar
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL));
+ EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL));
+
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
+ EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
+}
More information about the llvm-commits
mailing list