[llvm] [SandboxVec][Legality] Per opcode checks (PR #114145)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 29 15:58:00 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/114145

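This patch adds more opcode-specific legality checks: matching source operand types for casts, matching predicates for compares, consecutive memory accesses for loads/stores, and rejection of opcodes that are not yet implemented or are infeasible to vectorize.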

>From 77e6485a6a31d5d14d815540021bb54159a61df3 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 28 Oct 2024 11:21:34 -0700
Subject: [PATCH] [SandboxVec][Legality] Per opcode checks

This patch adds more opcode-specific legality checks.
---
 .../Vectorize/SandboxVectorizer/Legality.h    |  17 +-
 .../SandboxVectorizer/Passes/BottomUpVec.h    |   2 +-
 .../Vectorize/SandboxVectorizer/VecUtils.h    |  37 ++
 .../Vectorize/SandboxVectorizer/Legality.cpp  | 104 +++++-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  |   5 +-
 .../SandboxVectorizer/LegalityTest.cpp        |  92 ++++-
 .../SandboxVectorizer/VecUtilsTest.cpp        | 333 ++++++++++++++++++
 7 files changed, 580 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 77ba5cd7f002e9..f43e033e3cc7e3 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -13,6 +13,8 @@
 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -33,6 +35,9 @@ enum class ResultReason {
   DiffTypes,
   DiffMathFlags,
   DiffWrapFlags,
+  NotConsecutive,
+  Unimplemented,
+  Infeasible,
 };
 
 #ifndef NDEBUG
@@ -59,6 +64,12 @@ struct ToStr {
       return "DiffMathFlags";
     case ResultReason::DiffWrapFlags:
       return "DiffWrapFlags";
+    case ResultReason::NotConsecutive:
+      return "NotConsecutive";
+    case ResultReason::Unimplemented:
+      return "Unimplemented";
+    case ResultReason::Infeasible:
+      return "Infeasible";
     }
     llvm_unreachable("Unknown ResultReason enum");
   }
@@ -142,8 +153,12 @@ class LegalityAnalysis {
   std::optional<ResultReason>
   notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value *> Bndl);
 
+  ScalarEvolution &SE;
+  const DataLayout &DL;
+
 public:
-  LegalityAnalysis() = default;
+  LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
+      : SE(SE), DL(DL) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
   ResultT &createLegalityResult(ArgsT... Args) {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 2b0b3f8192c048..7e0b88ae7197d4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -24,7 +24,7 @@ namespace llvm::sandboxir {
 
 class BottomUpVec final : public FunctionPass {
   bool Change = false;
-  LegalityAnalysis Legality;
+  std::unique_ptr<LegalityAnalysis> Legality;
   void vectorizeRec(ArrayRef<Value *> Bndl);
   void tryVectorize(ArrayRef<Value *> Seeds);
 
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 9577e8ef7b37cb..8b64ec58da345d 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -12,7 +12,10 @@
 #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
 
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/SandboxIR/Type.h"
+#include "llvm/SandboxIR/Utils.h"
 
 namespace llvm::sandboxir {
 
@@ -29,6 +32,50 @@ class VecUtils {
   static Type *getElementType(Type *Ty) {
     return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
   }
+
+  /// \Returns true if \p I1 and \p I2 are load/stores accessing consecutive
+  /// memory addresses, i.e., \p I2 accesses the bytes that immediately follow
+  /// the ones accessed by \p I1. Accesses whose size is not a whole number of
+  /// bytes (e.g., i1) are conservatively reported as not consecutive.
+  template <typename LoadOrStoreT>
+  static bool areConsecutive(LoadOrStoreT *I1, LoadOrStoreT *I2,
+                             ScalarEvolution &SE, const DataLayout &DL) {
+    static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
+                      std::is_same<LoadOrStoreT, StoreInst>::value,
+                  "Expected Load or Store!");
+    auto Diff = Utils::getPointerDiffInBytes(I1, I2, SE);
+    if (!Diff)
+      return false;
+    unsigned ElmBits = Utils::getNumBits(I1);
+    // Dividing a size that is not a multiple of 8 would truncate (e.g., i1
+    // becomes 0 bytes), making two accesses to the same address look
+    // consecutive.
+    if (ElmBits % 8 != 0)
+      return false;
+    int ElmBytes = ElmBits / 8;
+    return *Diff == ElmBytes;
+  }
+
+  /// \Returns true if every load/store in \p Bndl accesses the memory that
+  /// immediately follows the memory accessed by its predecessor in \p Bndl.
+  /// \p Bndl must be non-empty and all its elements must be \p LoadOrStoreT.
+  template <typename LoadOrStoreT>
+  static bool areConsecutive(ArrayRef<Value *> Bndl, ScalarEvolution &SE,
+                             const DataLayout &DL) {
+    static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
+                      std::is_same<LoadOrStoreT, StoreInst>::value,
+                  "Expected Load or Store!");
+    assert(isa<LoadOrStoreT>(Bndl[0]) && "Expected Load or Store!");
+    auto *LastLS = cast<LoadOrStoreT>(Bndl[0]);
+    for (Value *V : drop_begin(Bndl)) {
+      assert(isa<LoadOrStoreT>(V) && "Expected Load or Store!");
+      auto *LS = cast<LoadOrStoreT>(V);
+      if (!VecUtils::areConsecutive(LastLS, LS, SE, DL))
+        return false;
+      LastLS = LS;
+    }
+    return true;
+  }
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 1cc6356300e492..1efd178778b9f6 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -70,7 +70,109 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
     }
   }
 
-  // TODO: Missing checks
+  // Now we need to do further checks for specific opcodes.
+  switch (Opcode) {
+  case Instruction::Opcode::ZExt:
+  case Instruction::Opcode::SExt:
+  case Instruction::Opcode::FPToUI:
+  case Instruction::Opcode::FPToSI:
+  case Instruction::Opcode::FPExt:
+  case Instruction::Opcode::PtrToInt:
+  case Instruction::Opcode::IntToPtr:
+  case Instruction::Opcode::SIToFP:
+  case Instruction::Opcode::UIToFP:
+  case Instruction::Opcode::Trunc:
+  case Instruction::Opcode::FPTrunc:
+  case Instruction::Opcode::BitCast: {
+    // We have already checked that they are of the same opcode.
+    assert(all_of(Bndl,
+                  [Opcode](Value *V) {
+                    return cast<Instruction>(V)->getOpcode() == Opcode;
+                  }) &&
+           "Different opcodes, should have early returned!");
+    // But for these opcodes we should also check the operand type.
+    Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
+    if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
+          return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
+                 FromTy0;
+        }))
+      return ResultReason::DiffTypes;
+    return std::nullopt;
+  }
+  case Instruction::Opcode::FCmp:
+  case Instruction::Opcode::ICmp: {
+    // We need the same predicate..
+    auto Pred0 = cast<CmpInst>(I0)->getPredicate();
+    bool Same = all_of(Bndl, [Pred0](Value *V) {
+      return cast<CmpInst>(V)->getPredicate() == Pred0;
+    });
+    if (Same)
+      return std::nullopt;
+    return ResultReason::DiffOpcodes;
+  }
+  case Instruction::Opcode::Select:
+  case Instruction::Opcode::FNeg:
+  case Instruction::Opcode::Add:
+  case Instruction::Opcode::FAdd:
+  case Instruction::Opcode::Sub:
+  case Instruction::Opcode::FSub:
+  case Instruction::Opcode::Mul:
+  case Instruction::Opcode::FMul:
+  case Instruction::Opcode::FRem:
+  case Instruction::Opcode::UDiv:
+  case Instruction::Opcode::SDiv:
+  case Instruction::Opcode::FDiv:
+  case Instruction::Opcode::URem:
+  case Instruction::Opcode::SRem:
+  case Instruction::Opcode::Shl:
+  case Instruction::Opcode::LShr:
+  case Instruction::Opcode::AShr:
+  case Instruction::Opcode::And:
+  case Instruction::Opcode::Or:
+  case Instruction::Opcode::Xor:
+    return std::nullopt;
+  case Instruction::Opcode::Load:
+    if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
+      return std::nullopt;
+    return ResultReason::NotConsecutive;
+  case Instruction::Opcode::Store:
+    if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
+      return std::nullopt;
+    return ResultReason::NotConsecutive;
+  case Instruction::Opcode::PHI:
+    return ResultReason::Unimplemented;
+  case Instruction::Opcode::Opaque:
+    return ResultReason::Unimplemented;
+  case Instruction::Opcode::Br:
+  case Instruction::Opcode::Ret:
+  case Instruction::Opcode::AddrSpaceCast:
+  case Instruction::Opcode::InsertElement:
+  case Instruction::Opcode::InsertValue:
+  case Instruction::Opcode::ExtractElement:
+  case Instruction::Opcode::ExtractValue:
+  case Instruction::Opcode::ShuffleVector:
+  case Instruction::Opcode::Call:
+  case Instruction::Opcode::GetElementPtr:
+  case Instruction::Opcode::Switch:
+    return ResultReason::Unimplemented;
+  case Instruction::Opcode::VAArg:
+  case Instruction::Opcode::Freeze:
+  case Instruction::Opcode::Fence:
+  case Instruction::Opcode::Invoke:
+  case Instruction::Opcode::CallBr:
+  case Instruction::Opcode::LandingPad:
+  case Instruction::Opcode::CatchPad:
+  case Instruction::Opcode::CleanupPad:
+  case Instruction::Opcode::CatchRet:
+  case Instruction::Opcode::CleanupRet:
+  case Instruction::Opcode::Resume:
+  case Instruction::Opcode::CatchSwitch:
+  case Instruction::Opcode::AtomicRMW:
+  case Instruction::Opcode::AtomicCmpXchg:
+  case Instruction::Opcode::Alloca:
+  case Instruction::Opcode::Unreachable:
+    return ResultReason::Infeasible;
+  }
 
   return std::nullopt;
 }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 66d631edfc4076..339330c64f0caa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -11,6 +11,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Instruction.h"
+#include "llvm/SandboxIR/Module.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
 
 namespace llvm::sandboxir {
@@ -40,7 +41,7 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
 }
 
 void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
-  const auto &LegalityRes = Legality.canVectorize(Bndl);
+  const auto &LegalityRes = Legality->canVectorize(Bndl);
   switch (LegalityRes.getSubclassID()) {
   case LegalityResultID::Widen: {
     auto *I = cast<Instruction>(Bndl[0]);
@@ -60,6 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
 void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
+  Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
+                                                F.getParent()->getDataLayout());
   Change = false;
   // TODO: Start from innermost BBs first
   for (auto &BB : F) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 50b78f6f48afdf..68557cb8b129f2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -7,7 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/Support/SourceMgr.h"
@@ -18,6 +24,22 @@ using namespace llvm;
 struct LegalityTest : public testing::Test {
   LLVMContext C;
   std::unique_ptr<Module> M;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<TargetLibraryInfoImpl> TLII;
+  std::unique_ptr<TargetLibraryInfo> TLI;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<LoopInfo> LI;
+  std::unique_ptr<ScalarEvolution> SE;
+
+  ScalarEvolution &getSE(llvm::Function &LLVMF) {
+    TLII = std::make_unique<TargetLibraryInfoImpl>();
+    TLI = std::make_unique<TargetLibraryInfo>(*TLII);
+    AC = std::make_unique<AssumptionCache>(LLVMF);
+    DT = std::make_unique<DominatorTree>(LLVMF);
+    LI = std::make_unique<LoopInfo>(*DT);
+    SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
+    return *SE;
+  }
 
   void parseIR(LLVMContext &C, const char *IR) {
     SMDiagnostic Err;
@@ -29,12 +51,14 @@ struct LegalityTest : public testing::Test {
 
 TEST_F(LegalityTest, Legality) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
   %gep0 = getelementptr float, ptr %ptr, i32 0
   %gep1 = getelementptr float, ptr %ptr, i32 1
   %gep3 = getelementptr float, ptr %ptr, i32 3
   %ld0 = load float, ptr %gep0
-  %ld1 = load float, ptr %gep0
+  %ld0b = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %ld3 = load float, ptr %gep3
   store float %ld0, ptr %gep0
   store float %ld1, ptr %gep1
   store <2 x float> %vec2, ptr %gep1
@@ -44,10 +68,17 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   %fadd1 = fadd fast float %farg1, %farg1
   %trunc0 = trunc nuw nsw i64 %v0 to i8
   %trunc1 = trunc nsw i64 %v1 to i8
+  %trunc64to8 = trunc i64 %v0 to i8
+  %trunc32to8 = trunc i32 %v2 to i8
+  %cmpSLT = icmp slt i64 %v0, %v1
+  %cmpSGT = icmp sgt i64 %v0, %v1
   ret void
 }
 )IR");
   llvm::Function *LLVMF = &*M->getFunction("foo");
+  auto &SE = getSE(*LLVMF);
+  const auto &DL = M->getDataLayout();
+
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
@@ -55,8 +86,10 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
   [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
-  [[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
-  [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld3 = cast<sandboxir::LoadInst>(&*It++);
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
   auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
@@ -66,8 +99,12 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
   auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
   auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
+  auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++);
+  auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
+  auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
+  auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality;
+  sandboxir::LegalityAnalysis Legality(SE, DL);
   const auto &Result = Legality.canVectorize({St0, St1});
   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
 
@@ -109,10 +146,52 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffWrapFlags);
   }
+  {
+    // Check DiffTypes for unary operands that have a different type.
+    const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffTypes);
+  }
+  {
+    // Check DiffOpcodes for CMPs with different predicates.
+    const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffOpcodes);
+  }
+  {
+    // Check NotConsecutive Ld0,Ld0b
+    const auto &Result = Legality.canVectorize({Ld0, Ld0b});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::NotConsecutive);
+  }
+  {
+    // Check NotConsecutive Ld0,Ld3
+    const auto &Result = Legality.canVectorize({Ld0, Ld3});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::NotConsecutive);
+  }
+  {
+    // Check Widen Ld0,Ld1
+    const auto &Result = Legality.canVectorize({Ld0, Ld1});
+    EXPECT_TRUE(isa<sandboxir::Widen>(Result));
+  }
 }
 
 #ifndef NDEBUG
 TEST_F(LegalityTest, LegalityResultDump) {
+  parseIR(C, R"IR(
+define void @foo() {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  auto &SE = getSE(*LLVMF);
+  const auto &DL = M->getDataLayout();
+
   auto Matches = [](const sandboxir::LegalityResult &Result,
                     const std::string &ExpectedStr) -> bool {
     std::string Buff;
@@ -120,7 +199,8 @@ TEST_F(LegalityTest, LegalityResultDump) {
     Result.print(OS);
     return Buff == ExpectedStr;
   };
-  sandboxir::LegalityAnalysis Legality;
+
+  sandboxir::LegalityAnalysis Legality(SE, DL);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index e0b08284964392..75f72ce23fbaac 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -7,15 +7,47 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Type.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
 struct VecUtilsTest : public testing::Test {
   LLVMContext C;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<TargetLibraryInfoImpl> TLII;
+  std::unique_ptr<TargetLibraryInfo> TLI;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<LoopInfo> LI;
+  std::unique_ptr<ScalarEvolution> SE;
+  ScalarEvolution &getSE(llvm::Function &LLVMF) {
+    DT = std::make_unique<DominatorTree>(LLVMF);
+    LI = std::make_unique<LoopInfo>(*DT);
+    AC = std::make_unique<AssumptionCache>(LLVMF);
+    TLII = std::make_unique<TargetLibraryInfoImpl>();
+    TLI = std::make_unique<TargetLibraryInfo>(*TLII);
+    SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
+    return *SE;
+  }
+  void parseIR(const char *IR) {
+    SMDiagnostic Diag;
+    M = parseAssemblyString(IR, Diag, C);
+    if (M == nullptr)
+      Diag.print("VecUtilsTest", errs());
+  }
 };
 
 TEST_F(VecUtilsTest, GetNumElements) {
@@ -35,3 +67,304 @@ TEST_F(VecUtilsTest, GetElementType) {
   auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
   EXPECT_EQ(sandboxir::VecUtils::getElementType(VTy), ElemTy);
 }
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_float) {
+  parseIR(R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr inbounds float, ptr %ptr, i64 0
+  %gep1 = getelementptr inbounds float, ptr %ptr, i64 1
+  %gep2 = getelementptr inbounds float, ptr %ptr, i64 2
+  %gep3 = getelementptr inbounds float, ptr %ptr, i64 3
+
+  %ld0 = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %ld2 = load float, ptr %gep2
+  %ld3 = load float, ptr %gep3
+
+  %v2ld0 = load <2 x float>, ptr %gep0
+  %v2ld1 = load <2 x float>, ptr %gep1
+  %v2ld2 = load <2 x float>, ptr %gep2
+  %v2ld3 = load <2 x float>, ptr %gep3
+
+  %v3ld0 = load <3 x float>, ptr %gep0
+  %v3ld1 = load <3 x float>, ptr %gep1
+  %v3ld2 = load <3 x float>, ptr %gep2
+  %v3ld3 = load <3 x float>, ptr %gep3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  const DataLayout &DL = M->getDataLayout();
+  auto &SE = getSE(LLVMF);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+
+  auto &BB = *F.begin();
+  auto It = std::next(BB.begin(), 4);
+  auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  auto *V2L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  auto *V3L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  // Scalar
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL));
+
+  // Check 2-wide loads
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL));
+  // Reversed order, or a gap between the two accesses, is not consecutive.
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L0, SE, DL));
+
+  // Check 3-wide loads
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL));
+
+  // Check mixes of vectors and scalar
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL));
+
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
+}
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_i8) {
+  parseIR(R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr inbounds i8, ptr %ptr, i64 0
+  %gep1 = getelementptr inbounds i8, ptr %ptr, i64 4
+  %gep2 = getelementptr inbounds i8, ptr %ptr, i64 8
+  %gep3 = getelementptr inbounds i8, ptr %ptr, i64 12
+
+  %ld0 = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %ld2 = load float, ptr %gep2
+  %ld3 = load float, ptr %gep3
+
+  %v2ld0 = load <2 x float>, ptr %gep0
+  %v2ld1 = load <2 x float>, ptr %gep1
+  %v2ld2 = load <2 x float>, ptr %gep2
+  %v2ld3 = load <2 x float>, ptr %gep3
+
+  %v3ld0 = load <3 x float>, ptr %gep0
+  %v3ld1 = load <3 x float>, ptr %gep1
+  %v3ld2 = load <3 x float>, ptr %gep2
+  %v3ld3 = load <3 x float>, ptr %gep3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  const DataLayout &DL = M->getDataLayout();
+  auto &SE = getSE(LLVMF);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = std::next(BB.begin(), 4);
+  auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  auto *V2L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  auto *V3L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  // Scalar
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL));
+
+  // Check 2-wide loads
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL));
+  // Reversed order, or a gap between the two accesses, is not consecutive.
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L0, SE, DL));
+
+  // Check 3-wide loads
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL));
+
+  // Check mixes of vectors and scalar
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL));
+
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
+}
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_i1) {
+  parseIR(R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr inbounds i1, ptr %ptr, i64 0
+  %gep1 = getelementptr inbounds i2, ptr %ptr, i64 4
+  %gep2 = getelementptr inbounds i3, ptr %ptr, i64 8
+  %gep3 = getelementptr inbounds i7, ptr %ptr, i64 12
+
+  %ld0 = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %ld2 = load float, ptr %gep2
+  %ld3 = load float, ptr %gep3
+
+  %v2ld0 = load <2 x float>, ptr %gep0
+  %v2ld1 = load <2 x float>, ptr %gep1
+  %v2ld2 = load <2 x float>, ptr %gep2
+  %v2ld3 = load <2 x float>, ptr %gep3
+
+  %v3ld0 = load <3 x float>, ptr %gep0
+  %v3ld1 = load <3 x float>, ptr %gep1
+  %v3ld2 = load <3 x float>, ptr %gep2
+  %v3ld3 = load <3 x float>, ptr %gep3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  const DataLayout &DL = M->getDataLayout();
+  auto &SE = getSE(LLVMF);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = std::next(BB.begin(), 4);
+  auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  auto *V2L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V2L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  auto *V3L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L2 = cast<sandboxir::LoadInst>(&*It++);
+  auto *V3L3 = cast<sandboxir::LoadInst>(&*It++);
+
+  // Scalar
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, L1, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L2, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L1, L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L2, L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L3, L1, SE, DL));
+
+  // Check 2-wide loads
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V2L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L1, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L3, SE, DL));
+  // Reversed order, or a gap between the two accesses, is not consecutive.
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L2, V2L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L3, V2L0, SE, DL));
+
+  // Check 3-wide loads
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, V3L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L1, V3L0, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L2, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L3, V3L2, SE, DL));
+
+  // Check mixes of vectors and scalar
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L0, V2L1, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(L1, V2L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, L2, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V3L0, L3, SE, DL));
+  EXPECT_TRUE(sandboxir::VecUtils::areConsecutive(V2L0, V3L2, SE, DL));
+
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V3L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(L0, V2L3, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L0, V3L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L1, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
+  EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
+}



More information about the llvm-commits mailing list