[llvm] [SandboxVec] Legality boilerplate (PR #108650)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 17 16:45:08 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/108650

>From 0066b7bb062530ed8f6cdf08102b251f6b470c0a Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 11 Sep 2024 18:26:55 -0700
Subject: [PATCH] [SandboxVec] Legality boilerplate

This patch adds the basic API for the Legality component of the vectorizer.
It also adds some very basic code in the bottom-up vectorizer that uses the API.
---
 .../Vectorize/SandboxVectorizer/Legality.h    | 62 +++++++++++++++++++
 .../SandboxVectorizer/Passes/BottomUpVec.h    |  7 +++
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 53 +++++++++++++++-
 .../SandboxVectorizer/CMakeLists.txt          |  1 +
 .../SandboxVectorizer/LegalityTest.cpp        | 56 +++++++++++++++++
 5 files changed, 178 insertions(+), 1 deletion(-)
 create mode 100644 llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
 create mode 100644 llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
new file mode 100644
index 00000000000000..78c1c0e4c04649
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -0,0 +1,83 @@
+//===- Legality.h -----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Legality checks for the Sandbox Vectorizer.
+//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include <memory>
+
+namespace llvm::sandboxir {
+
+class LegalityAnalysis;
+
+enum class LegalityResultID {
+  Widen, ///< Vectorize by combining scalars to a vector.
+};
+
+/// The legality outcome is represented by a class rather than an enum class
+/// because in some cases the legality checks are expensive and look for a
+/// particular instruction that can be passed along to the vectorizer to avoid
+/// repeating the same expensive computation.
+class LegalityResult {
+protected:
+  LegalityResultID ID;
+  /// Only Legality can create LegalityResults.
+  explicit LegalityResult(LegalityResultID ID) : ID(ID) {}
+  friend class LegalityAnalysis;
+
+public:
+  /// Virtual because LegalityAnalysis owns results through base pointers.
+  virtual ~LegalityResult() = default;
+  LegalityResultID getSubclassID() const { return ID; }
+};
+
+/// The bundle can be vectorized by widening its scalars into a vector.
+class Widen final : public LegalityResult {
+  friend class LegalityAnalysis;
+  Widen() : LegalityResult(LegalityResultID::Widen) {}
+
+public:
+  static bool classof(const LegalityResult *From) {
+    return From->getSubclassID() == LegalityResultID::Widen;
+  }
+};
+
+/// Performs the legality analysis and returns a LegalityResult object.
+class LegalityAnalysis {
+  /// Owns every result handed out by canVectorize(). Results are returned by
+  /// reference so that subclasses carrying extra data are not sliced; they
+  /// remain valid for as long as this LegalityAnalysis is alive.
+  SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
+
+  /// Creates a \p ResultT owned by ResultPool and returns a reference to it.
+  template <typename ResultT, typename... ArgsT>
+  const LegalityResult &createLegalityResult(ArgsT... Args) {
+    // Not std::make_unique: the result constructors are private.
+    ResultPool.push_back(std::unique_ptr<ResultT>(new ResultT(Args...)));
+    return *ResultPool.back();
+  }
+
+public:
+  LegalityAnalysis() = default;
+  /// \Returns the legality of vectorizing \p Bndl. The returned reference
+  /// remains valid for as long as this LegalityAnalysis is alive.
+  const LegalityResult &canVectorize(ArrayRef<Value *> Bndl) {
+    // TODO: For now everything is legal.
+    return createLegalityResult<Widen>();
+  }
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 5b3d1a50aa1ec0..99582e3e0e0233 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -12,11 +12,18 @@
 #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H
 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/SandboxIR/Pass.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
 
 namespace llvm::sandboxir {
 
 class BottomUpVec final : public FunctionPass {
+  bool Change = false;
+  LegalityAnalysis Legality;
+  void vectorizeRec(ArrayRef<Value *> Bndl);
+  void tryVectorize(ArrayRef<Value *> Seeds);
 
 public:
   BottomUpVec() : FunctionPass("bottom-up-vec") {}
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index c4870b70fd52da..0c44d05f0474d5 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -7,7 +7,87 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace llvm::sandboxir;
 
-bool BottomUpVec::runOnFunction(Function &F) { return false; }
+namespace llvm::sandboxir {
+// TODO: This is a temporary function that returns some seeds.
+//       Replace this with SeedCollector's function when it lands.
+/// \Returns the store instructions of \p BB in program order.
+static SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) {
+  SmallVector<Value *, 4> Seeds;
+  for (auto &I : BB)
+    if (auto *SI = dyn_cast<StoreInst>(&I))
+      Seeds.push_back(SI);
+  return Seeds;
+}
+
+/// \Returns the \p OpIdx'th operand of every instruction in \p Bndl, in lane
+/// order. Each value in \p Bndl must be an instruction with more than \p OpIdx
+/// operands.
+static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
+                                          unsigned OpIdx) {
+  SmallVector<Value *, 4> Operands;
+  Operands.reserve(Bndl.size());
+  for (Value *BndlV : Bndl) {
+    auto *BndlI = cast<Instruction>(BndlV);
+    Operands.push_back(BndlI->getOperand(OpIdx));
+  }
+  return Operands;
+}
+
+/// \Returns true if the vectorizer can recurse into the operands of \p Bndl:
+/// the bundle must be non-empty and hold only non-PHI instructions that all
+/// have the same number of operands. Arguments and constants have no operands
+/// to walk, and PHIs are rejected because inside a loop their operands can
+/// lead back to the PHI itself, so the recursion would never terminate.
+static bool canWalkOperands(ArrayRef<Value *> Bndl) {
+  if (Bndl.empty())
+    return false;
+  auto *I0 = dyn_cast<Instruction>(Bndl[0]);
+  if (I0 == nullptr)
+    return false;
+  unsigned NumOps = I0->getNumOperands();
+  return all_of(Bndl, [NumOps](Value *V) {
+    auto *I = dyn_cast<Instruction>(V);
+    return I != nullptr && !isa<PHINode>(I) && I->getNumOperands() == NumOps;
+  });
+}
+
+} // namespace llvm::sandboxir
+
+void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+  // Bind by reference: copying a LegalityResult would slice off subclass data.
+  const auto &LegalityRes = Legality.canVectorize(Bndl);
+  switch (LegalityRes.getSubclassID()) {
+  case LegalityResultID::Widen: {
+    // TODO: Legality currently accepts every bundle, including bundles of
+    // arguments/constants or mismatched instructions, so stop the operand walk
+    // here instead of asserting in cast<Instruction> or recursing forever.
+    if (!canWalkOperands(Bndl))
+      break;
+    auto *I = cast<Instruction>(Bndl[0]);
+    for (auto OpIdx : seq<unsigned>(I->getNumOperands()))
+      vectorizeRec(getOperand(Bndl, OpIdx));
+    break;
+  }
+  }
+}
+
+void BottomUpVec::tryVectorize(ArrayRef<Value *> Seeds) { vectorizeRec(Seeds); }
+
+bool BottomUpVec::runOnFunction(Function &F) {
+  Change = false;
+  // TODO: Start from innermost BBs first
+  for (auto &BB : F) {
+    // TODO: Replace with proper SeedCollector function.
+    auto Seeds = collectSeeds(BB);
+    // TODO: Slice Seeds into smaller chunks.
+    if (Seeds.size() >= 2)
+      tryVectorize(Seeds);
+  }
+  return Change;
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
index 488c9c2344b56c..2c7bf7d7e87541 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -9,4 +9,5 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(SandboxVectorizerTests
   DependencyGraphTest.cpp
+  LegalityTest.cpp
   )
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
new file mode 100644
index 00000000000000..a136be41ae363b
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -0,0 +1,58 @@
+//===- LegalityTest.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct LegalityTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M)
+      Err.print("LegalityTest", errs());
+  }
+};
+
+// Two stores to consecutive addresses form a bundle that is legal to widen.
+TEST_F(LegalityTest, Legality) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr float, ptr %ptr, i32 0
+  %gep1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  store float %ld0, ptr %gep0
+  store float %ld1, ptr %gep1
+  ret void
+}
+)IR");
+  ASSERT_TRUE(M);
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  [[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+  [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *St0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *St1 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::LegalityAnalysis Legality;
+  const auto &Result = Legality.canVectorize({St0, St1});
+  EXPECT_TRUE(isa<sandboxir::Widen>(Result));
+}



More information about the llvm-commits mailing list