[Mlir-commits] [llvm] [mlir] Op definition (PR #93931)
Shuanglong Kan
llvmlistbot at llvm.org
Fri May 31 00:04:06 PDT 2024
https://github.com/ShlKan created https://github.com/llvm/llvm-project/pull/93931
None
>From 1037c321537642a7df35d33d4a10a6ad5d806c9c Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Thu, 21 Mar 2024 15:48:40 +0100
Subject: [PATCH 1/5] feat: initial
---
llvm/CMakeLists.txt | 2 +-
llvm/tools/CMakeLists.txt | 1 +
obs/CMakeLists.txt | 15 +++++++++++++++
obs/cmake/modules/AddOBS.cmake | 5 +++++
obs/cmake/modules/CMakeLists.txt | 0
obs/obs-ir/CMakeLists.txt | 5 +++++
obs/obs-ir/test.cpp | 5 +++++
7 files changed, 32 insertions(+), 1 deletion(-)
create mode 100644 obs/CMakeLists.txt
create mode 100644 obs/cmake/modules/AddOBS.cmake
create mode 100644 obs/cmake/modules/CMakeLists.txt
create mode 100644 obs/obs-ir/CMakeLists.txt
create mode 100644 obs/obs-ir/test.cpp
diff --git a/llvm/CMakeLists.txt b/llvm/CMakeLists.txt
index 6f5647d70d8bc..d3eaa7670eb15 100644
--- a/llvm/CMakeLists.txt
+++ b/llvm/CMakeLists.txt
@@ -114,7 +114,7 @@ endif()
# LLVM_EXTERNAL_${project}_SOURCE_DIR using LLVM_ALL_PROJECTS
# This allows an easy way of setting up a build directory for llvm and another
# one for llvm+clang+... using the same sources.
-set(LLVM_ALL_PROJECTS "bolt;clang;clang-tools-extra;compiler-rt;cross-project-tests;libc;libclc;lld;lldb;mlir;openmp;polly;pstl")
+set(LLVM_ALL_PROJECTS "bolt;clang;clang-tools-extra;compiler-rt;cross-project-tests;libc;libclc;lld;lldb;mlir;obs;openmp;polly;pstl")
# The flang project is not yet part of "all" projects (see C++ requirements)
set(LLVM_EXTRA_PROJECTS "flang")
# List of all known projects in the mono repo
diff --git a/llvm/tools/CMakeLists.txt b/llvm/tools/CMakeLists.txt
index c6116ac81d12b..ac1728854a765 100644
--- a/llvm/tools/CMakeLists.txt
+++ b/llvm/tools/CMakeLists.txt
@@ -41,6 +41,7 @@ add_llvm_external_project(clang)
add_llvm_external_project(lld)
add_llvm_external_project(lldb)
add_llvm_external_project(mlir)
+add_llvm_external_project(obs)
# Flang depends on mlir, so place it afterward
add_llvm_external_project(flang)
add_llvm_external_project(bolt)
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
new file mode 100644
index 0000000000000..6353f1c09cde4
--- /dev/null
+++ b/obs/CMakeLists.txt
@@ -0,0 +1,15 @@
+include(CMakeDependentOption)
+include(GNUInstallDirs)
+
+# Make sure that our source directory is on the current cmake module path so
+# that we can include cmake files from this directory.
+list(INSERT CMAKE_MODULE_PATH 0
+ "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules"
+ "${LLVM_COMMON_CMAKE_UTILS}/Modules"
+ )
+
+include(CMakeParseArguments)
+include(AddOBS)
+
+
+add_subdirectory(obs-ir)
\ No newline at end of file
diff --git a/obs/cmake/modules/AddOBS.cmake b/obs/cmake/modules/AddOBS.cmake
new file mode 100644
index 0000000000000..5c7420aad0312
--- /dev/null
+++ b/obs/cmake/modules/AddOBS.cmake
@@ -0,0 +1,5 @@
+# Declare an executable that belongs to the OBS project.
+# `name` and every remaining argument (${ARGN}) are forwarded unchanged to
+# add_llvm_executable(); the resulting target is then filed under the
+# "OBS executables" folder.
+macro(add_obs_executable name)
+ add_llvm_executable( ${name} ${ARGN} )
+ set_target_properties(${name} PROPERTIES FOLDER "OBS executables")
+endmacro(add_obs_executable)
+
diff --git a/obs/cmake/modules/CMakeLists.txt b/obs/cmake/modules/CMakeLists.txt
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/obs/obs-ir/CMakeLists.txt b/obs/obs-ir/CMakeLists.txt
new file mode 100644
index 0000000000000..f418db1ce4b20
--- /dev/null
+++ b/obs/obs-ir/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_LINK_COMPONENTS
+ Support
+ )
+
+add_obs_executable(obstest test.cpp)
\ No newline at end of file
diff --git a/obs/obs-ir/test.cpp b/obs/obs-ir/test.cpp
new file mode 100644
index 0000000000000..faf42d5289969
--- /dev/null
+++ b/obs/obs-ir/test.cpp
@@ -0,0 +1,5 @@
+#include <iostream>
+
+int main(int argc, char* argv[]) {
+ std::cout << "Hello World OBS" << std::endl;
+}
\ No newline at end of file
>From 494786e93565587ce6a055e6e87b12effc8e7056 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Mon, 25 Mar 2024 19:27:01 +0100
Subject: [PATCH 2/5] feat: test 1
---
obs/.clang-tidy | 26 +++
obs/CMakeLists.txt | 8 +-
obs/include/AST.h | 227 +++++++++++++++++++
obs/include/Lexer.h | 184 ++++++++++++++++
obs/include/Parser.h | 452 ++++++++++++++++++++++++++++++++++++++
obs/obs-ir/CMakeLists.txt | 5 -
obs/obs-ir/test.cpp | 56 ++++-
obs/parser/AST.cpp | 222 +++++++++++++++++++
8 files changed, 1170 insertions(+), 10 deletions(-)
create mode 100644 obs/.clang-tidy
create mode 100644 obs/include/AST.h
create mode 100644 obs/include/Lexer.h
create mode 100644 obs/include/Parser.h
delete mode 100644 obs/obs-ir/CMakeLists.txt
create mode 100644 obs/parser/AST.cpp
diff --git a/obs/.clang-tidy b/obs/.clang-tidy
new file mode 100644
index 0000000000000..9cece0de812b8
--- /dev/null
+++ b/obs/.clang-tidy
@@ -0,0 +1,26 @@
+Checks: '-*,clang-diagnostic-*,llvm-*,misc-*,-misc-const-correctness,-misc-unused-parameters,-misc-non-private-member-variables-in-classes,-misc-no-recursion,-misc-use-anonymous-namespace,readability-identifier-naming,-misc-include-cleaner'
+CheckOptions:
+ - key: readability-identifier-naming.ClassCase
+ value: CamelCase
+ - key: readability-identifier-naming.EnumCase
+ value: CamelCase
+ - key: readability-identifier-naming.FunctionCase
+ value: camelBack
+ # Exclude from scanning as this is an exported symbol used for fuzzing
+ # throughout the code base.
+ - key: readability-identifier-naming.FunctionIgnoredRegexp
+ value: "LLVMFuzzerTestOneInput"
+ - key: readability-identifier-naming.MemberCase
+ value: CamelCase
+ - key: readability-identifier-naming.ParameterCase
+ value: CamelCase
+ - key: readability-identifier-naming.UnionCase
+ value: CamelCase
+ - key: readability-identifier-naming.VariableCase
+ value: CamelCase
+ - key: readability-identifier-naming.IgnoreMainLikeFunctions
+ value: 1
+ - key: readability-redundant-member-init.IgnoreBaseInCopyConstructors
+ value: 1
+ - key: modernize-use-default-member-init.UseAssignment
+ value: 1
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index 6353f1c09cde4..af775cf2ebb5c 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_LINK_COMPONENTS
+ Support
+ )
+
include(CMakeDependentOption)
include(GNUInstallDirs)
@@ -11,5 +15,5 @@ list(INSERT CMAKE_MODULE_PATH 0
include(CMakeParseArguments)
include(AddOBS)
-
-add_subdirectory(obs-ir)
\ No newline at end of file
+include_directories(include/)
+add_obs_executable(obstest obs-ir/test.cpp parser/AST.cpp)
\ No newline at end of file
diff --git a/obs/include/AST.h b/obs/include/AST.h
new file mode 100644
index 0000000000000..46d3f876143b2
--- /dev/null
+++ b/obs/include/AST.h
@@ -0,0 +1,227 @@
+#ifndef OBS_AST_H
+#define OBS_AST_H
+
+#include "Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <utility>
+#include <vector>
+#include <optional>
+
+namespace obs {
+
+/// Type attached to a variable: the list of its dimension sizes.
+struct VarType {
+  std::vector<int64_t> shape;
+};
+
+/// Base class for all expression nodes of the AST.
+///
+/// Every node stores its kind, which the subclasses' `classof` functions use
+/// for LLVM-style RTTI, and the source location it was created with.
+class ExprAST {
+public:
+  /// Discriminator identifying the concrete subclass of a node.
+  enum ExprASTKind {
+    Expr_VarDecl,
+    Expr_Return,
+    Expr_Num,
+    Expr_Literal,
+    Expr_Var,
+    Expr_BinOp,
+    Expr_Call,
+    Expr_Print,
+  };
+
+  ExprAST(ExprASTKind kind, Location location)
+      : kind(kind), location(std::move(location)) {}
+
+  /// Virtual because nodes are owned and destroyed through
+  /// `std::unique_ptr<ExprAST>`; without it the derived destructors (and the
+  /// strings/vectors/children they own) would never run.
+  virtual ~ExprAST() = default;
+
+  /// Returns the kind this node was constructed with.
+  ExprASTKind getKind() const { return kind; }
+
+  /// Returns the source location this node was constructed with.
+  const Location &loc() const { return location; }
+
+private:
+  const ExprASTKind kind;
+  Location location;
+};
+
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// Expression node holding a single numeric value.
+class NumberExprAST : public ExprAST {
+public:
+  NumberExprAST(Location loc, double val)
+      : ExprAST(Expr_Num, loc), val(val) {}
+
+  /// Returns the numeric value stored in this node.
+  double getValue() { return val; }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
+
+private:
+  double val;
+};
+
+/// Expression node for a literal list of values together with the dimensions
+/// recorded for it.
+class LiteralExprAST : public ExprAST {
+public:
+  LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+                 std::vector<int64_t> dims)
+      : ExprAST(Expr_Literal, loc), values(std::move(values)),
+        dims(std::move(dims)) {}
+
+  /// Returns the element expressions of this literal.
+  llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
+
+  /// Returns the dimensions stored with this literal.
+  llvm::ArrayRef<int64_t> getDims() { return dims; }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) {
+    return c->getKind() == Expr_Literal;
+  }
+
+private:
+  std::vector<std::unique_ptr<ExprAST>> values;
+  std::vector<int64_t> dims;
+};
+
+/// Expression node referencing a variable by name.
+class VariableExprAST : public ExprAST {
+public:
+  VariableExprAST(Location loc, llvm::StringRef name)
+      : ExprAST(Expr_Var, std::move(loc)), name(name) {}
+
+  /// Returns the referenced variable's name.
+  llvm::StringRef getName() { return name; }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
+
+private:
+  std::string name;
+};
+
+/// Expression node for a variable declaration: a name, a type and the
+/// expression used to initialize the variable.
+class VarDeclExprAST : public ExprAST {
+public:
+  VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
+                 std::unique_ptr<ExprAST> initVal)
+      : ExprAST(Expr_VarDecl, loc), name(name), type(type),
+        initVal(std::move(initVal)) {}
+
+  /// Returns the declared variable's name.
+  llvm::StringRef getName() { return name; }
+
+  /// Returns a copy of the declared type.
+  VarType getType() { return type; }
+
+  /// Returns the (non-owning) initializer expression.
+  ExprAST *getInitVal() { return initVal.get(); }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) {
+    return c->getKind() == Expr_VarDecl;
+  }
+
+private:
+  std::string name;
+  VarType type;
+  std::unique_ptr<ExprAST> initVal;
+};
+
+/// Expression node for a return statement with an optional operand.
+class ReturnExprAST : public ExprAST {
+public:
+  ReturnExprAST(Location loc, std::optional<std::unique_ptr<ExprAST>> expr)
+      : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+  /// Returns the (non-owning) returned expression, or std::nullopt when the
+  /// node was built without an operand.
+  std::optional<ExprAST *> getExpr() {
+    if (!expr.has_value())
+      return std::nullopt;
+    return expr->get();
+  }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) {
+    return c->getKind() == Expr_Return;
+  }
+
+private:
+  std::optional<std::unique_ptr<ExprAST>> expr;
+};
+
+/// Expression node for a binary operator applied to two operands.
+class BinaryExprAST : public ExprAST {
+public:
+  BinaryExprAST(Location location, char op, std::unique_ptr<ExprAST> lhs,
+                std::unique_ptr<ExprAST> rhs)
+      : ExprAST(Expr_BinOp, location), op(op), lhs(std::move(lhs)),
+        rhs(std::move(rhs)) {}
+
+  /// Returns the operator character.
+  char getOp() { return op; }
+
+  /// Returns the (non-owning) left-hand operand.
+  ExprAST *getLHS() { return lhs.get(); }
+
+  /// Returns the (non-owning) right-hand operand.
+  ExprAST *getRHS() { return rhs.get(); }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) {
+    return c->getKind() == Expr_BinOp;
+  }
+
+private:
+  char op;
+  std::unique_ptr<ExprAST> lhs, rhs;
+};
+
+/// Expression node for a call: the callee's name and the argument list.
+class CallExprAST : public ExprAST {
+public:
+  CallExprAST(Location loc, std::string callee,
+              std::vector<std::unique_ptr<ExprAST>> args)
+      : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
+
+  /// Returns the name of the called function.
+  llvm::StringRef getCallee() { return callee; }
+
+  /// Returns the argument expressions.
+  llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
+
+private:
+  std::string callee;
+  std::vector<std::unique_ptr<ExprAST>> args;
+};
+
+/// Expression node for a call to the builtin print, holding its one argument.
+class PrintExprAST : public ExprAST {
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+      : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {}
+
+  /// Returns the (non-owning) expression being printed.
+  ExprAST *getArg() { return arg.get(); }
+
+  /// LLVM style RTTI.
+  static bool classof(const ExprAST *c) {
+    return c->getKind() == Expr_Print;
+  }
+
+private:
+  std::unique_ptr<ExprAST> arg;
+};
+
+/// The prototype of a function: its location, its name and its parameters.
+class PrototypeAST {
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : location(std::move(location)), name(name), args(std::move(args)) {}
+
+  /// Returns the location stored for this prototype.
+  const Location &loc() { return location; }
+
+  /// Returns the function name.
+  llvm::StringRef getName() const { return name; }
+
+  /// Returns the parameter list.
+  llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
+
+private:
+  Location location;
+  std::string name;
+  std::vector<std::unique_ptr<VariableExprAST>> args;
+};
+
+/// A function definition: a prototype plus the list of expressions forming
+/// its body.
+class FunctionAST {
+public:
+  FunctionAST(std::unique_ptr<PrototypeAST> proto,
+              std::unique_ptr<ExprASTList> body)
+      : proto(std::move(proto)), body(std::move(body)) {}
+
+  /// Returns the (non-owning) prototype.
+  PrototypeAST *getProto() { return proto.get(); }
+
+  /// Returns the (non-owning) body.
+  ExprASTList *getBody() { return body.get(); }
+
+private:
+  std::unique_ptr<PrototypeAST> proto;
+  std::unique_ptr<ExprASTList> body;
+};
+
+/// A list of functions that are processed together; iterable with
+/// begin()/end().
+class ModuleAST {
+public:
+  ModuleAST(std::vector<FunctionAST> functions)
+      : functions(std::move(functions)) {}
+
+  /// Iterator to the first function.
+  auto begin() { return functions.begin(); }
+
+  /// Iterator past the last function.
+  auto end() { return functions.end(); }
+
+private:
+  std::vector<FunctionAST> functions;
+};
+
+void dump(ModuleAST &);
+
+
+}
+
+
+#endif //OBS_AST_H
+
+
diff --git a/obs/include/Lexer.h b/obs/include/Lexer.h
new file mode 100644
index 0000000000000..c2b1c8cf54a0f
--- /dev/null
+++ b/obs/include/Lexer.h
@@ -0,0 +1,184 @@
+#ifndef OBS_LEXER_H
+#define OBS_LEXER_H
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace obs {
+
+/// A position in a source file: the file name plus a line and a column.
+struct Location {
+  std::shared_ptr<std::string> file; ///< Name of the file.
+  int line;                          ///< Line number.
+  int col;                           ///< Column number.
+};
+
+/// Tokens produced by the lexer. Single-character tokens use the character's
+/// own value; every other token uses a negative value.
+enum Token : int {
+  // Single-character punctuation.
+  tok_semicolon = ';',
+  tok_parenthese_open = '(',
+  tok_parenthese_close = ')',
+  tok_bracket_open = '{',
+  tok_bracket_close = '}',
+  tok_sbracket_open = '[',
+  tok_sbracket_close = ']',
+
+  // End of input.
+  tok_eof = -1,
+
+  // Keywords.
+  tok_return = -2,
+  tok_var = -3,
+  tok_def = -4,
+
+  // Identifiers and numbers.
+  tok_identifier = -5,
+  tok_number = -6,
+};
+
+/// Turns a stream of characters into a stream of tokens.
+///
+/// Input is pulled one line at a time through readNextLine(), which concrete
+/// subclasses implement. The lexer tracks the current line and column and
+/// records the position at which the most recent token started.
+class Lexer {
+public:
+  /// Create a lexer for `filename`. The name is only stored in the locations
+  /// handed out by getLastLocation().
+  Lexer(std::string filename)
+      : lastLocation(
+            {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+  virtual ~Lexer() = default;
+
+  /// Returns the token most recently produced by getNextToken().
+  Token getCurToken() { return curTok; }
+
+  /// Lexes the next token, makes it the current token and returns it.
+  Token getNextToken() { return curTok = getTok(); }
+
+  /// Moves to the next token; asserts that the current one is `tok`.
+  void consume(Token tok) {
+    assert(tok == curTok && "consume Token mismatch expectation");
+    getNextToken();
+  }
+
+  /// Returns the current identifier (requires the current token to be
+  /// tok_identifier).
+  llvm::StringRef getId() {
+    assert(curTok == tok_identifier);
+    return identifierStr;
+  }
+
+  /// Returns the current number (requires the current token to be
+  /// tok_number).
+  double getValue() {
+    assert(curTok == tok_number);
+    return numVal;
+  }
+
+  /// Lexes and returns the next token from the input.
+  Token getTok() {
+    // Skip any whitespace.
+    while (isspace(lastChar))
+      lastChar = Token(getNextChar());
+
+    // Record where the token that is about to be produced begins, so that
+    // getLastLocation() reports a real position instead of always 0:0.
+    lastLocation.line = curLineNum;
+    lastLocation.col = curCol;
+
+    // Identifier or keyword: [a-zA-Z][a-zA-Z0-9_]*
+    if (isalpha(lastChar)) {
+      identifierStr = static_cast<char>(lastChar);
+      while (isalnum(lastChar = Token(getNextChar())) || lastChar == '_')
+        identifierStr += static_cast<char>(lastChar);
+
+      if (identifierStr == "return")
+        return tok_return;
+      if (identifierStr == "def")
+        return tok_def;
+      if (identifierStr == "var")
+        return tok_var;
+      return tok_identifier;
+    }
+
+    // Number: [0-9.]+
+    if (isdigit(lastChar) || lastChar == '.') {
+      std::string numStr;
+      do {
+        numStr += lastChar;
+        lastChar = Token(getNextChar());
+      } while (isdigit(lastChar) || lastChar == '.');
+
+      numVal = strtod(numStr.c_str(), nullptr);
+      return tok_number;
+    }
+
+    // Comment: skip everything up to the end of the line.
+    if (lastChar == '#') {
+      do {
+        lastChar = Token(getNextChar());
+      } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
+
+      if (lastChar != EOF)
+        return getTok();
+    }
+
+    // Check for end of file. Don't eat the EOF.
+    if (lastChar == EOF)
+      return tok_eof;
+
+    // Otherwise, return the character itself as the token.
+    Token thisChar = Token(lastChar);
+    lastChar = Token(getNextChar());
+    return thisChar;
+  }
+
+  /// Returns the location at which the current token started.
+  Location getLastLocation() { return lastLocation; }
+
+  /// Returns the current line number in the input.
+  int getLine() { return curLineNum; }
+
+  /// Returns the current column in the input.
+  int getCol() { return curCol; }
+
+private:
+  /// Supplies the next line of input; an empty result signals end of input.
+  virtual llvm::StringRef readNextLine() = 0;
+
+  /// Returns the next character of the input, or EOF once the line buffer
+  /// is exhausted, updating the line/column counters along the way.
+  int getNextChar() {
+    // An empty buffer means the previous readNextLine() hit end of input.
+    if (curLineBuffer.empty())
+      return EOF;
+    ++curCol;
+    auto nextChar = curLineBuffer.front();
+    // StringRef::drop_front() returns the shortened view rather than
+    // modifying the receiver, so the result has to be assigned back.
+    curLineBuffer = curLineBuffer.drop_front();
+    if (curLineBuffer.empty())
+      curLineBuffer = readNextLine();
+    if (nextChar == '\n') {
+      ++curLineNum;
+      curCol = 0;
+    }
+    return nextChar;
+  }
+
+  /// Location of the beginning of the last token produced.
+  Location lastLocation;
+  /// The last token produced by getNextToken().
+  Token curTok = tok_eof;
+  /// One character of look-ahead read from the input.
+  Token lastChar = Token(' ');
+  /// Unconsumed remainder of the current line.
+  llvm::StringRef curLineBuffer = "\n";
+  /// Current column and line in the input.
+  int curCol = 0;
+  int curLineNum = 0;
+
+  /// Text of the current token when it is tok_identifier.
+  std::string identifierStr;
+  /// Value of the current token when it is tok_number.
+  double numVal = 0;
+};
+
+/// A lexer that reads its input from the memory range [begin, end).
+class LexerBuffer final : public Lexer {
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename)
+      : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+  const char *current, *end;
+
+  /// Returns the next line of the buffer, including its trailing '\n' when
+  /// present. An empty result is returned once the buffer is exhausted.
+  llvm::StringRef readNextLine() override {
+    const char *begin = current;
+    // `end` is one past the last character, so it is never dereferenced and
+    // `current` never moves beyond it.
+    while (current < end && *current && *current != '\n')
+      ++current;
+    // Keep the newline as part of the returned line.
+    if (current < end && *current)
+      ++current;
+    return llvm::StringRef(begin, static_cast<size_t>(current - begin));
+  }
+};
+
+}
+
+
+
+#endif //OBS_LEXER_H
\ No newline at end of file
diff --git a/obs/include/Parser.h b/obs/include/Parser.h
new file mode 100644
index 0000000000000..e9e619774c7fb
--- /dev/null
+++ b/obs/include/Parser.h
@@ -0,0 +1,452 @@
+#ifndef OBS_PARSER_H
+#define OBS_PARSER_H
+
+#include "AST.h"
+#include "Lexer.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+#include <optional>
+
+namespace obs {
+
+class Parser {
+public:
+ /// Create a Parser for the supplied lexer.
+ Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a whole module, i.e. a sequence of function definitions.
+  /// Returns nullptr (after reporting an error) when parsing stops before
+  /// the end of the input.
+  std::unique_ptr<ModuleAST> parseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Collect definitions until one fails to parse or EOF is reached.
+    std::vector<FunctionAST> functions;
+    for (;;) {
+      std::unique_ptr<FunctionAST> definition = parseDefinition();
+      if (!definition)
+        break;
+      functions.push_back(std::move(*definition));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+
+    if (lexer.getCurToken() == tok_eof)
+      return std::make_unique<ModuleAST>(std::move(functions));
+
+    // Stopping anywhere but EOF means a definition failed to parse.
+    return parseError<ModuleAST>("nothing", "at end of module");
+  }
+
+private:
+ Lexer &lexer;
+
+  /// Parse a return statement.
+  /// return :== return ; | return expr ;
+  ///
+  /// Returns nullptr when the returned expression fails to parse.
+  std::unique_ptr<ReturnExprAST> parseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    std::optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      // Check the pointer itself: once assigned, the optional is always
+      // engaged, so testing the optional would never detect a failed parse.
+      auto value = parseExpression();
+      if (!value)
+        return nullptr;
+      expr = std::move(value);
+    }
+    return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a literal number.
+  /// numberexpr ::= number
+  std::unique_ptr<ExprAST> parseNumberExpr() {
+    Location loc = lexer.getLastLocation();
+    // Read the value while the number is still the current token.
+    const double value = lexer.getValue();
+    lexer.consume(tok_number);
+    return std::make_unique<NumberExprAST>(std::move(loc), value);
+  }
+
+  /// Parse a literal array expression.
+  /// tensorLiteral ::= [ literalList ] | number
+  /// literalList ::= tensorLiteral | tensorLiteral, literalList
+  ///
+  /// Returns nullptr (after reporting an error) on malformed input or when
+  /// the nested lists do not all have the same dimensions.
+  std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(Token('['));
+
+    // Hold the list of values at this nesting level.
+    std::vector<std::unique_ptr<ExprAST>> values;
+    // Hold the dimensions for all the nesting inside this level.
+    std::vector<int64_t> dims;
+    do {
+      // We can have either another nested array or a number literal.
+      if (lexer.getCurToken() == '[') {
+        values.push_back(parseTensorLiteralExpr());
+        if (!values.back())
+          return nullptr; // parse error in the nested array.
+      } else {
+        if (lexer.getCurToken() != tok_number)
+          return parseError<ExprAST>("<num> or [", "in literal expression");
+        values.push_back(parseNumberExpr());
+      }
+
+      // End of this list on ']'
+      if (lexer.getCurToken() == ']')
+        break;
+      // Elements are separated by a comma.
+      if (lexer.getCurToken() != ',')
+        return parseError<ExprAST>("] or ,", "in literal expression");
+
+      lexer.getNextToken(); // eat ,
+    } while (true);
+    // The loop only exits after at least one element has been pushed.
+    lexer.getNextToken(); // eat ]
+
+    // Fill in the dimensions now. First the current nesting level:
+    dims.push_back(values.size());
+
+    // If there is any nested array, process all of them and ensure that
+    // dimensions are uniform.
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral =
+          llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expression");
+
+      // Append the nested dimensions to the current level
+      auto firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Sanity check that shape is uniform across all elements of the list.
+      for (auto &expr : values) {
+        // dyn_cast, not cast: a number mixed in with nested lists must be
+        // reported as a parse error instead of tripping cast's assertion.
+        auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+        if (exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+      }
+    }
+    return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                            std::move(dims));
+  }
+
+  /// Parse an expression enclosed in parentheses.
+  /// parenexpr ::= '(' expression ')'
+  std::unique_ptr<ExprAST> parseParenExpr() {
+    lexer.getNextToken(); // eat (.
+
+    auto inner = parseExpression();
+    if (!inner)
+      return nullptr;
+
+    if (lexer.getCurToken() == ')') {
+      lexer.consume(Token(')'));
+      return inner;
+    }
+    return parseError<ExprAST>(")", "to close expression with parentheses");
+  }
+  /// Parse a variable reference, a function call, or a builtin print call.
+  /// identifierexpr
+  ///   ::= identifier
+  ///   ::= identifier '(' expression ')'
+  std::unique_ptr<ExprAST> parseIdentifierExpr() {
+    std::string name(lexer.getId());
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.
+
+    // Without an opening parenthesis this is a plain variable reference.
+    if (lexer.getCurToken() != '(')
+      return std::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // Otherwise it is a call: collect the comma-separated arguments.
+    lexer.consume(Token('('));
+    std::vector<std::unique_ptr<ExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      for (;;) {
+        auto arg = parseExpression();
+        if (!arg)
+          return nullptr;
+        args.push_back(std::move(arg));
+
+        if (lexer.getCurToken() == ')')
+          break;
+        if (lexer.getCurToken() != ',')
+          return parseError<ExprAST>(", or )", "in argument list");
+        lexer.getNextToken();
+      }
+    }
+    lexer.consume(Token(')'));
+
+    // Anything not named "print" is a call to a user-defined function.
+    if (name != "print")
+      return std::make_unique<CallExprAST>(std::move(loc), name,
+                                           std::move(args));
+
+    // The builtin print takes exactly one argument.
+    if (args.size() != 1)
+      return parseError<ExprAST>("<single arg>", "as argument to print()");
+    return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
+  }
+  /// Parse a primary expression, dispatching on the current token.
+  /// primary
+  ///   ::= identifierexpr
+  ///   ::= numberexpr
+  ///   ::= parenexpr
+  ///   ::= tensorliteral
+  ///
+  /// ';' and '}' silently yield nullptr; any other unexpected token is
+  /// reported before nullptr is returned.
+  std::unique_ptr<ExprAST> parsePrimary() {
+    const Token tok = lexer.getCurToken();
+    switch (tok) {
+    case tok_identifier:
+      return parseIdentifierExpr();
+    case tok_number:
+      return parseNumberExpr();
+    case '(':
+      return parseParenExpr();
+    case '[':
+      return parseTensorLiteralExpr();
+    case ';':
+    case '}':
+      return nullptr;
+    default:
+      llvm::errs() << "unknown token '" << tok
+                   << "' when expecting an expression\n";
+      return nullptr;
+    }
+  }
+
+ /// Parse the sequence of "operator primary" pairs that follows `lhs`,
+ /// using operator precedences to decide how operands are grouped.
+ ///
+ /// `exprPrec` is the minimum precedence an operator must have to be
+ /// consumed by this call; `lhs` is the expression parsed so far.
+ /// Returns the combined expression, or nullptr when an operand is missing
+ /// or fails to parse.
+ std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+ std::unique_ptr<ExprAST> lhs) {
+ // If this is a binop, find its precedence.
+ while (true) {
+ // getTokPrecedence() yields -1 for anything that is not an operator.
+ int tokPrec = getTokPrecedence();
+
+ // If this is a binop that binds at least as tightly as the current binop,
+ // consume it, otherwise we are done.
+ if (tokPrec < exprPrec)
+ return lhs;
+
+ // Okay, we know this is a binop.
+ int binOp = lexer.getCurToken();
+ lexer.consume(Token(binOp));
+ auto loc = lexer.getLastLocation();
+
+ // Parse the primary expression after the binary operator.
+ auto rhs = parsePrimary();
+ if (!rhs)
+ return parseError<ExprAST>("expression", "to complete binary operator");
+
+ // If BinOp binds less tightly with rhs than the operator after rhs, let
+ // the pending operator take rhs as its lhs.
+ int nextPrec = getTokPrecedence();
+ if (tokPrec < nextPrec) {
+ // Only operators binding strictly tighter than binOp are folded into rhs.
+ rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+ if (!rhs)
+ return nullptr;
+ }
+
+ // Merge lhs/RHS.
+ lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
+ std::move(lhs), std::move(rhs));
+ }
+ }
+
+  /// Parse a full expression: a primary followed by any binary operators.
+  /// expression ::= primary binop rhs
+  std::unique_ptr<ExprAST> parseExpression() {
+    if (auto first = parsePrimary())
+      return parseBinOpRHS(0, std::move(first));
+    return nullptr;
+  }
+
+  /// Parse a type annotation.
+  /// type ::= < shape_list >
+  /// shape_list ::= num | num , shape_list
+  std::unique_ptr<VarType> parseType() {
+    if (lexer.getCurToken() != '<')
+      return parseError<VarType>("<", "to begin type");
+    lexer.getNextToken(); // eat <
+
+    auto type = std::make_unique<VarType>();
+
+    // Collect every number; a comma after a number is skipped.
+    while (lexer.getCurToken() == tok_number) {
+      const double dim = lexer.getValue();
+      type->shape.push_back(dim);
+      lexer.getNextToken();
+      if (lexer.getCurToken() == ',')
+        lexer.getNextToken();
+    }
+
+    if (lexer.getCurToken() == '>') {
+      lexer.getNextToken(); // eat >
+      return type;
+    }
+    return parseError<VarType>(">", "to end type");
+  }
+
+  /// Parse a variable declaration: the `var` keyword, an identifier, an
+  /// optional type, then `=` and the initializer expression.
+  /// decl ::= var identifier [ type ] = expr
+  ///
+  /// Returns nullptr (after reporting an error) when any part is missing or
+  /// fails to parse.
+  std::unique_ptr<VarDeclExprAST> parseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identifier",
+                                        "after 'var' declaration");
+    std::string id(lexer.getId());
+    lexer.getNextToken(); // eat id
+
+    std::unique_ptr<VarType> type; // Type is optional, it can be inferred
+    if (lexer.getCurToken() == '<') {
+      type = parseType();
+      if (!type)
+        return nullptr;
+    }
+    if (!type)
+      type = std::make_unique<VarType>();
+
+    // Report a missing '=' as a parse error; consume() would assert on it.
+    if (lexer.getCurToken() != '=')
+      return parseError<VarDeclExprAST>("=", "to initialize declaration");
+    lexer.consume(Token('='));
+
+    // Never build a declaration around a null initializer.
+    auto expr = parseExpression();
+    if (!expr)
+      return parseError<VarDeclExprAST>("expression",
+                                        "to initialize declaration");
+
+    return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                            std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expressions separated by semicolons and
+  /// wrapped in curly braces.
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  std::unique_ptr<ExprASTList> parseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    auto exprList = std::make_unique<ExprASTList>();
+
+    // Empty expressions are ignored: swallow runs of semicolons.
+    auto skipSemicolons = [this] {
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    };
+    skipSemicolons();
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      // Pick the statement parser from the leading token.
+      std::unique_ptr<ExprAST> statement;
+      if (lexer.getCurToken() == tok_var)
+        statement = parseDeclaration();
+      else if (lexer.getCurToken() == tok_return)
+        statement = parseReturn();
+      else
+        statement = parseExpression();
+
+      if (!statement)
+        return nullptr;
+      exprList->push_back(std::move(statement));
+
+      // Every element must be followed by at least one semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+      skipSemicolons();
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return exprList;
+  }
+
+  /// Parse a function prototype.
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  ///
+  /// Returns nullptr (after reporting an error) on malformed input.
+  std::unique_ptr<PrototypeAST> parsePrototype() {
+    auto loc = lexer.getLastLocation();
+
+    if (lexer.getCurToken() != tok_def)
+      return parseError<PrototypeAST>("def", "in prototype");
+    lexer.consume(tok_def);
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string fnName(lexer.getId());
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      // Validate the first parameter; getId() asserts on non-identifiers.
+      if (lexer.getCurToken() != tok_identifier)
+        return parseError<PrototypeAST>("identifier or )",
+                                        "in function parameter list");
+      do {
+        // The current token is an identifier here: checked above for the
+        // first parameter and after each comma for the following ones.
+        std::string name(lexer.getId());
+        auto argLoc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        args.push_back(
+            std::make_unique<VariableExprAST>(std::move(argLoc), name));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>(")", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return std::make_unique<PrototypeAST>(std::move(loc), fnName,
+                                          std::move(args));
+  }
+
+  /// Parse a function definition: a prototype introduced by the `def`
+  /// keyword followed by a block holding the body.
+  ///
+  /// definition ::= prototype block
+  std::unique_ptr<FunctionAST> parseDefinition() {
+    auto proto = parsePrototype();
+    if (!proto)
+      return nullptr;
+
+    auto block = parseBlock();
+    if (!block)
+      return nullptr;
+
+    return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
+  }
+
+  /// Returns the precedence of the current token when it is a binary
+  /// operator, and -1 otherwise. Higher values bind tighter.
+  int getTokPrecedence() {
+    const Token tok = lexer.getCurToken();
+    if (!isascii(tok))
+      return -1;
+
+    // 1 is lowest precedence.
+    switch (static_cast<char>(tok)) {
+    case '+':
+    case '-':
+      return 20;
+    case '*':
+      return 40;
+    default:
+      return -1;
+    }
+  }
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  ///
+  /// Always returns a null `std::unique_ptr<R>` so that callers can write
+  /// `return parseError<R>(...)`.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    // Keyword/identifier/number tokens are negative; isprint() is only
+    // defined for unsigned char values and EOF, so test the range first.
+    if (curToken >= 0 && curToken <= 127 && isprint(curToken))
+      llvm::errs() << " '" << static_cast<char>(curToken) << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+}
+
+
+
+
+#endif //OBS_PARSER_H
\ No newline at end of file
diff --git a/obs/obs-ir/CMakeLists.txt b/obs/obs-ir/CMakeLists.txt
deleted file mode 100644
index f418db1ce4b20..0000000000000
--- a/obs/obs-ir/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-set(LLVM_LINK_COMPONENTS
- Support
- )
-
-add_obs_executable(obstest test.cpp)
\ No newline at end of file
diff --git a/obs/obs-ir/test.cpp b/obs/obs-ir/test.cpp
index faf42d5289969..829797eaf4cf5 100644
--- a/obs/obs-ir/test.cpp
+++ b/obs/obs-ir/test.cpp
@@ -1,5 +1,55 @@
+#include "AST.h"
+#include "Lexer.h"
+#include "Parser.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+#include <string>
+#include <system_error>
#include <iostream>
-int main(int argc, char* argv[]) {
- std::cout << "Hello World OBS" << std::endl;
-}
\ No newline at end of file
+using namespace obs;
+namespace cl = llvm::cl;
+
+static cl::opt<std::string> inputFilename(cl::Positional,
+ cl::desc("<input toy file>"),
+ cl::init("-"),
+ cl::value_desc("filename"));
+
+namespace {
+enum Action { None, DumpAST };
+}
+
+static cl::opt<enum Action> emitAction("emit", cl::desc("Select the kind of output desired"),
+ cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
+
+/// Reads `filename` (or stdin) and parses its contents into a module AST.
+/// Returns nullptr when the input cannot be opened or fails to parse.
+std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
+  auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  const std::error_code ec = fileOrErr.getError();
+  if (ec) {
+    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+    return nullptr;
+  }
+
+  // Lex straight out of the memory buffer, which outlives the parser.
+  llvm::StringRef contents = fileOrErr.get()->getBuffer();
+  LexerBuffer lexer(contents.begin(), contents.end(), std::string(filename));
+  Parser parser(lexer);
+  return parser.parseModule();
+}
+
+/// Entry point: parses the command line, parses the input file and performs
+/// the action selected with -emit. Returns 1 when the input cannot be
+/// parsed, 0 otherwise.
+int main(int argc, char **argv) {
+  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+  auto moduleAST = parseInputFile(inputFilename);
+  // parseInputFile() has already reported the problem; bail out instead of
+  // dereferencing a null module below.
+  if (!moduleAST)
+    return 1;
+
+  switch (emitAction) {
+  case Action::DumpAST:
+    dump(*moduleAST);
+    return 0;
+  default:
+    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+  }
+  return 0;
+}
+
diff --git a/obs/parser/AST.cpp b/obs/parser/AST.cpp
new file mode 100644
index 0000000000000..bb4f245b97903
--- /dev/null
+++ b/obs/parser/AST.cpp
@@ -0,0 +1,222 @@
+#include "AST.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+
+namespace obs {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST
+struct Indent {
+ Indent(int &level) : level(level) { ++level; }
+ ~Indent() { --level; }
+ int &level;
+};
+
+/// Helper class that implements the AST tree traversal and prints the nodes
+/// along the way. The only data member is the current indentation level.
+class ASTDumper {
+public:
+ void dump(ModuleAST *node);
+
+private:
+ void dump(const VarType &type);
+ void dump(VarDeclExprAST *varDecl);
+ void dump(ExprAST *expr);
+ void dump(ExprASTList *exprList);
+ void dump(NumberExprAST *num);
+ void dump(LiteralExprAST *node);
+ void dump(VariableExprAST *node);
+ void dump(ReturnExprAST *node);
+ void dump(BinaryExprAST *node);
+ void dump(CallExprAST *node);
+ void dump(PrintExprAST *node);
+ void dump(PrototypeAST *node);
+ void dump(FunctionAST *node);
+
+ // Actually print spaces matching the current indentation level
+ void indent() {
+ for (int i = 0; i < curIndent; i++)
+ llvm::errs() << " ";
+ }
+ int curIndent = 0;
+};
+
+/// Return a formatted string for the location of any node
+template <typename T>
+static std::string loc(T *node) {
+ const auto &loc = node->loc();
+ return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
+ llvm::Twine(loc.col))
+ .str();
+}
+
+// Helper Macro to bump the indentation level and print the leading spaces for
+// the current indentations
+#define INDENT() \
+ Indent level_(curIndent); \
+ indent();
+
+/// Dispatch a generic expression to the appropriate subclass using RTTI
+void ASTDumper::dump(ExprAST *expr) {
+ llvm::TypeSwitch<ExprAST *>(expr)
+ .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
+ PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
+ [&](auto *node) { this->dump(node); })
+ .Default([&](ExprAST *) {
+ // No match, fallback to a generic message
+ INDENT();
+ llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+ });
+}
+
+/// Print a variable declaration: the variable name, the type, and then
+/// recurse into the initializer value.
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+ INDENT();
+ llvm::errs() << "VarDecl " << varDecl->getName();
+ dump(varDecl->getType());
+ llvm::errs() << " " << loc(varDecl) << "\n";
+ dump(varDecl->getInitVal());
+}
+
+/// A "block", or a list of expression
+void ASTDumper::dump(ExprASTList *exprList) {
+ INDENT();
+ llvm::errs() << "Block {\n";
+ for (auto &expr : *exprList)
+ dump(expr.get());
+ indent();
+ llvm::errs() << "} // Block\n";
+}
+
+/// A literal number, just print the value.
+void ASTDumper::dump(NumberExprAST *num) {
+ INDENT();
+ llvm::errs() << num->getValue() << " " << loc(num) << "\n";
+}
+
+/// Helper to print recursively a literal. This handles nested array like:
+/// [ [ 1, 2 ], [ 3, 4 ] ]
+/// We print out such array with the dimensions spelled out at every level:
+/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *litOrNum) {
+ // Inside a literal expression we can have either a number or another literal
+ if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
+ llvm::errs() << num->getValue();
+ return;
+ }
+ auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
+
+ // Print the dimension for this literal first
+ llvm::errs() << "<";
+ llvm::interleaveComma(literal->getDims(), llvm::errs());
+ llvm::errs() << ">";
+
+ // Now print the content, recursing on every element of the list
+ llvm::errs() << "[ ";
+ llvm::interleaveComma(literal->getValues(), llvm::errs(),
+ [&](auto &elt) { printLitHelper(elt.get()); });
+ llvm::errs() << "]";
+}
+
+/// Print a literal, see the recursive helper above for the implementation.
+void ASTDumper::dump(LiteralExprAST *node) {
+ INDENT();
+ llvm::errs() << "Literal: ";
+ printLitHelper(node);
+ llvm::errs() << " " << loc(node) << "\n";
+}
+
+/// Print a variable reference (just a name).
+void ASTDumper::dump(VariableExprAST *node) {
+ INDENT();
+ llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
+}
+
+/// Print a return statement: the "Return" line, followed by either the dump of
+/// its operand expression or, when there is none, a "(void)" line.
+void ASTDumper::dump(ReturnExprAST *node) {
+ INDENT();
+ llvm::errs() << "Return\n";
+ // With an operand, dumping it is the last thing to do; `return <void call>`
+ // simply leaves the function afterwards.
+ if (node->getExpr().has_value())
+ return dump(*node->getExpr());
+ // INDENT() declares a local named level_, so the second use needs a scope of
+ // its own; it also makes "(void)" print one level deeper than "Return".
+ {
+ INDENT();
+ llvm::errs() << "(void)\n";
+ }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and RHS.
+void ASTDumper::dump(BinaryExprAST *node) {
+ INDENT();
+ llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+ dump(node->getLHS());
+ dump(node->getRHS());
+}
+
+/// Print a call expression, first the callee name and the list of args by
+/// recursing into each individual argument.
+void ASTDumper::dump(CallExprAST *node) {
+ INDENT();
+ llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+ for (auto &arg : node->getArgs())
+ dump(arg.get());
+ indent();
+ llvm::errs() << "]\n";
+}
+
+/// Print a builtin print call, first the builtin name and then the argument.
+void ASTDumper::dump(PrintExprAST *node) {
+ INDENT();
+ llvm::errs() << "Print [ " << loc(node) << "\n";
+ dump(node->getArg());
+ indent();
+ llvm::errs() << "]\n";
+}
+
+/// Print type: only the shape is printed in between '<' and '>'
+void ASTDumper::dump(const VarType &type) {
+ llvm::errs() << "<";
+ llvm::interleaveComma(type.shape, llvm::errs());
+ llvm::errs() << ">";
+}
+
+/// Print a function prototype, first the function name, and then the list of
+/// parameters names.
+void ASTDumper::dump(PrototypeAST *node) {
+ INDENT();
+ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n";
+ indent();
+ llvm::errs() << "Params: [";
+ llvm::interleaveComma(node->getArgs(), llvm::errs(),
+ [](auto &arg) { llvm::errs() << arg->getName(); });
+ llvm::errs() << "]\n";
+}
+
+/// Print a function, first the prototype and then the body.
+void ASTDumper::dump(FunctionAST *node) {
+ INDENT();
+ llvm::errs() << "Function \n";
+ dump(node->getProto());
+ dump(node->getBody());
+}
+
+/// Print a module, actually loop over the functions and print them in sequence.
+void ASTDumper::dump(ModuleAST *node) {
+ INDENT();
+ llvm::errs() << "Module:\n";
+ for (auto &f : *node)
+ dump(&f);
+}
+}
+
+namespace obs {
+
+// Public API
+void dump(ModuleAST &module) { ASTDumper().dump(&module); }
+
+} // namespace obs
>From 5804ff8e8f188548a2a8013a1baf250e1e2d7cd5 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Sat, 30 Mar 2024 21:29:23 +0100
Subject: [PATCH 3/5] feat: practise toy - 1
---
obs/.clang-tidy | 4 +-
obs/CMakeLists.txt | 28 ++++-
obs/include/CMakeLists.txt | 6 +
obs/include/Dialect.h | 17 +++
obs/include/Lexer.h | 2 +-
obs/include/Ops.td | 203 +++++++++++++++++++++++++++++++
obs/obs-ir/Dialect.cpp | 165 +++++++++++++++++++++++++
obs/obs-ir/{test.cpp => obs.cpp} | 13 +-
8 files changed, 432 insertions(+), 6 deletions(-)
create mode 100644 obs/include/CMakeLists.txt
create mode 100644 obs/include/Dialect.h
create mode 100644 obs/include/Ops.td
create mode 100644 obs/obs-ir/Dialect.cpp
rename obs/obs-ir/{test.cpp => obs.cpp} (90%)
diff --git a/obs/.clang-tidy b/obs/.clang-tidy
index 9cece0de812b8..4b7b8e9479f29 100644
--- a/obs/.clang-tidy
+++ b/obs/.clang-tidy
@@ -13,11 +13,11 @@ CheckOptions:
- key: readability-identifier-naming.MemberCase
value: CamelCase
- key: readability-identifier-naming.ParameterCase
- value: CamelCase
+ value: camelCase
- key: readability-identifier-naming.UnionCase
value: CamelCase
- key: readability-identifier-naming.VariableCase
- value: CamelCase
+ value: camelCase
- key: readability-identifier-naming.IgnoreMainLikeFunctions
value: 1
- key: readability-redundant-member-init.IgnoreBaseInCopyConstructors
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index af775cf2ebb5c..5b7a50c53125d 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -5,6 +5,16 @@ set(LLVM_LINK_COMPONENTS
include(CMakeDependentOption)
include(GNUInstallDirs)
+set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --src-root
+set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --includedir
+set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)
+include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
+include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
+
+
+include_directories(SYSTEM ${CMAKE_CURRENT_BINARY_DIR}/include)
+add_subdirectory(include)
+
# Make sure that our source directory is on the current cmake module path so
# that we can include cmake files from this directory.
list(INSERT CMAKE_MODULE_PATH 0
@@ -16,4 +26,20 @@ include(CMakeParseArguments)
include(AddOBS)
include_directories(include/)
-add_obs_executable(obstest obs-ir/test.cpp parser/AST.cpp)
\ No newline at end of file
+add_obs_executable(obstest
+ obs-ir/obs.cpp
+ obs-ir/Dialect.cpp
+ parser/AST.cpp
+ DEPENDS
+ OBSOpGen
+ )
+
+target_link_libraries(obstest
+ PRIVATE
+ MLIRAnalysis
+ MLIRFunctionInterfaces
+ MLIRIR
+ MLIRParser
+ MLIRSideEffectInterfaces
+ MLIRTransforms
+)
\ No newline at end of file
diff --git a/obs/include/CMakeLists.txt b/obs/include/CMakeLists.txt
new file mode 100644
index 0000000000000..83793c91c12e2
--- /dev/null
+++ b/obs/include/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls)
+mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
+add_public_tablegen_target(OBSOpGen)
\ No newline at end of file
diff --git a/obs/include/Dialect.h b/obs/include/Dialect.h
new file mode 100644
index 0000000000000..a298ed2f086a0
--- /dev/null
+++ b/obs/include/Dialect.h
@@ -0,0 +1,17 @@
+// Public header of the OBS dialect: MLIR support headers followed by the
+// tablegen-generated dialect and operation declarations.
+//
+// The guard is named after this header; a guard in MLIR's own MLIR_* macro
+// namespace could silently hide this file if an MLIR header used the same name.
+#ifndef OBS_DIALECT_H
+#define OBS_DIALECT_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+// Dialect class declaration (generated with -gen-dialect-decls).
+#include "Dialect.h.inc"
+
+// Operation class declarations (generated from Ops.td with -gen-op-decls).
+#define GET_OP_CLASSES
+#include "Ops.h.inc"
+
+
+#endif // OBS_DIALECT_H
diff --git a/obs/include/Lexer.h b/obs/include/Lexer.h
index c2b1c8cf54a0f..bf47f17dfa92c 100644
--- a/obs/include/Lexer.h
+++ b/obs/include/Lexer.h
@@ -136,7 +136,7 @@ class Lexer {
}
++curCol;
auto nextChar = curLineBuffer.front();
- curLineBuffer.drop_front();
+ curLineBuffer = curLineBuffer.drop_front();
if (curLineBuffer.empty()) {
curLineBuffer = readNextLine();
}
diff --git a/obs/include/Ops.td b/obs/include/Ops.td
new file mode 100644
index 0000000000000..96e9d02b72160
--- /dev/null
+++ b/obs/include/Ops.td
@@ -0,0 +1,203 @@
+
+#ifndef OBS_OPS
+#define OBS_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def OBS_Dialect: Dialect {
+ //The namespace of the Dialect.
+ let name = "toy";
+
+ let summary = "A high-level dialect for analyzing and optimzing the Language";
+
+ let description = [{
+ The Toy language is a tensor-based language that allows you to define functions, perform some math computation, and
+ print results. This dialect provides a representation of the language that is amenable to analysis and optimization.
+ }];
+
+ let cppNamespace = "::mlir::obs";
+}
+
+class OBS_op<string mnemonic, list<Trait> traits = []> : Op<OBS_Dialect, mnemonic, traits>;
+
+def ConstantOp : OBS_op<"constant", [Pure]> {
+
+ let summary = "constant operation" ;
+
+ let description = [{
+ Constant operation turns a literal into an SSA value. The data is attached to the operation as an attribute.
+ }];
+
+
+ let arguments = (ins F64ElementsAttr:$value);
+ let results = (outs F64Tensor);
+
+ let hasCustomAssemblyFormat = 1;
+
+ let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "DenseElementsAttr":$value), [{build($_builder, $_state, value.getType(), value);}] >,
+ OpBuilder<(ins "double":$value)>
+ ];
+
+}
+
+def AddOp : OBS_op<"add"> {
+ let summary = "element-wise addition operation";
+ let description = [{
+ The "add" operation performs element-wise addition between two tensors.
+ The shape of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ let hasCustomAssemblyFormat = 1;
+
+ let builders = [
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+ ];
+}
+
+def FuncOp : OBS_op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
+ let summary = "user defined function operation";
+ let description = [{
+ The "toy.func" operation represents a user defined function. These are callable SSA-region operations
+ that contain toy computations.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sysm_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
+ );
+ let regions = (region AnyRegion:$body);
+
+ let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+ >];
+
+ let extraClassDeclaration = [{
+ ArrayRef<Type> getArgumentTypes() {
+ return getFunctionType().getInputs();
+ }
+
+ ArrayRef<Type> getResultTypes() {
+ return getFunctionType().getResults();
+ }
+
+ Region *getCallableRegion() {
+ return &getBody();
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let skipDefaultBuilders = 1;
+}
+
+def GenericCallOp: OBS_op<"generic_call"> {
+ let summary = "generic call operation";
+ let description = [{
+ Generic calls represent calls to a user defined function that needs to be specialized for the shape
+ of its arguments.
+ }];
+
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+
+ let results = (outs F64Tensor);
+
+ let assemblyFormat = [{
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
+ ];
+}
+
+def MulOp : OBS_op<"mul"> {
+ let summary = "element-wise multiplication operation";
+
+ let description = [{
+ The "mul" operation performs element-wise multiplication between two
+ tensors. The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ let hasCustomAssemblyFormat = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+ ];
+}
+
+def PrintOp: OBS_op<"print"> {
+ let summary = "print operation";
+ let description = [{
+ The print builtin operation prints a given input tensor, and produces no results.
+ }];
+ let arguments = (ins F64Tensor:$input);
+ let assemblyFormat = "$input attr-dict `:` type($input)";
+}
+
+def ReshapeOp: OBS_op<"reshape"> {
+ let summary = "tensor reshape operation";
+ let description = [{
+ The reshape operation transforms its input tensor into a new tensor with the same number of
+ elements but different shapes.
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs StaticShapeTensorOf<[F64]>);
+
+ let assemblyFormat = [{
+ `(` $input `:` type($input) `)` attr-dict `to` type(results)
+ }];
+}
+
+def TransposeOp: OBS_op<"transpose"> {
+ let summary = "transpose operation";
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor);
+
+ let assemblyFormat = [{
+ `(` $input `:` type($input) `)` attr-dict `to` type(results)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value":$input)>
+ ];
+
+ let hasVerifier = 1;
+}
+
+def ReturnOp : OBS_op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
+ let summary = "return operation";
+ let description = [{
+ The "return" operation represents a return operation within a function;
+ }];
+
+ let arguments = (ins Variadic<F64Tensor>:$input);
+ let assemblyFormat = "($input^ `:` type($input))? attr-dict ";
+
+ let builders = [
+ OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]>
+ ];
+
+ let extraClassDeclaration = [{
+ bool hasOperand() {
+ return getNumOperands() != 0;
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+
+#endif //OBS_OPS
diff --git a/obs/obs-ir/Dialect.cpp b/obs/obs-ir/Dialect.cpp
new file mode 100644
index 0000000000000..b33174b8355cd
--- /dev/null
+++ b/obs/obs-ir/Dialect.cpp
@@ -0,0 +1,165 @@
+
+#include "Dialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <string>
+
+using namespace mlir;
+using namespace mlir::obs;
+
+#include "Dialect.cpp.inc"
+
+void OBSDialect::initialize() {
+ addOperations<
+ #define GET_OP_LIST
+ #include "Ops.cpp.inc"
+ >();
+}
+
+// Shared custom-assembly parser for two-operand ops (used by AddOp::parse).
+// Parses two operands, an optional attribute dictionary and `: <type>`, where
+// <type> is either a function type (operand types -> result types) or a single
+// type that is applied to both operands and to the result.
+static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, mlir::OperationState &result) {
+
+ SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
+ // Remember where the operands start; passed to resolveOperands below.
+ SMLoc operandsLoc = parser.getCurrentLocation();
+ Type type;
+
+ if (parser.parseOperandList(operands, 2) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type))
+ return mlir::failure();
+
+ // Long form: the type after the colon is a function type, so operand and
+ // result types are taken from its inputs and results.
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
+ if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
+ result.operands))
+ return mlir::failure();
+ result.addTypes(funcType.getResults());
+ return mlir::success();
+ }
+ // Short form: one type shared by every operand and by the result.
+ if (parser.resolveOperands(operands, type, result.operands))
+ return mlir::failure();
+
+ result.addTypes(type);
+ return mlir::success();
+}
+
+/// Shared custom-assembly printer for two-operand ops: prints the operands,
+/// the attribute dictionary and `: <type>`.
+///
+/// When every operand type equals the result type only that single type is
+/// printed; otherwise the full functional type is printed.
+static void printBinary(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
+  printer << " " << op->getOperands();
+  printer.printOptionalAttrDict(op->getAttrs());
+  printer << " : ";
+
+  Type resultType = *op->result_type_begin();
+  if (llvm::all_of(op->getOperandTypes(),
+                   [=](Type type) { return type == resultType; })) {
+    printer << resultType;
+    // Stop here: falling through would append the functional type as well,
+    // leaving two types after the colon where parseBinaryOp reads only one.
+    return;
+  }
+
+  printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
+}
+
+void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, double value) {
+ auto dataType = RankedTensorType::get({}, builder.getF64Type());
+ auto dataAttribute = DenseElementsAttr::get(dataType, value);
+ ConstantOp::build(builder, state, dataType, dataAttribute);
+}
+
+mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) {
+ mlir::DenseElementsAttr value;
+
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseAttribute(value, "value", result.attributes))
+ return failure();
+
+ result.addTypes(value.getType());
+ return success();
+}
+
+void ConstantOp::print(mlir::OpAsmPrinter &printer) {
+ printer << " ";
+ printer.printOptionalAttrDict((*this)->getAttrs(), {"value"});
+ printer << getValue();
+}
+
+/// Verify that the result type agrees with the attached `value` attribute.
+/// An unranked result type is always accepted; a ranked one must have the same
+/// rank and the same extent in every dimension as the attribute's type.
+mlir::LogicalResult ConstantOp::verify() {
+  auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
+  if (!resultType)
+    return success();
+
+  auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
+  // Compare ranks first: the loop below indexes both shapes up to the
+  // attribute's rank and would read past the end of a shorter result shape.
+  if (attrType.getRank() != resultType.getRank()) {
+    return emitOpError("return type must match the one of the attached value "
+                       "attribute: ")
+           << attrType.getRank() << " != " << resultType.getRank();
+  }
+
+  for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
+    if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+      // Trailing space keeps the dimension number apart from the message text.
+      return emitOpError("return type shape mismatches its attribute at dimension ")
+             << dim << ": " << attrType.getShape()[dim]
+             << " != " << resultType.getShape()[dim];
+    }
+  }
+  return mlir::success();
+}
+
+void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ mlir::Value lhs, mlir::Value rhs) {
+ state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+ state.addOperands({lhs, rhs});
+}
+
+mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) {
+ return parseBinaryOp(parser, result);
+}
+
+void AddOp::print(mlir::OpAsmPrinter &p) {
+ printBinary(p, * this);
+}
+
+void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ StringRef callee, ArrayRef<mlir::Value> arguments) {
+ state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+ state.addOperands(arguments);
+ state.addAttribute("callee", mlir::SymbolRefAttr::get(builder.getContext(), callee));
+}
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ llvm::StringRef name, mlir::FunctionType type,
+ llvm::ArrayRef<mlir::NamedAttribute> attrs ) {
+ buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ auto buildFuncType =
+ [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+ llvm::ArrayRef<mlir::Type> results,
+ mlir::function_interface_impl::VariadicFlag,
+ std::string &) { return builder.getFunctionType(argTypes, results); };
+ return mlir::function_interface_impl::
+ parseFunctionOp(parser, result, false,
+ getFunctionTypeAttrName(result.name),
+ buildFuncType,
+ getArgAttrsAttrName(result.name),
+ getResAttrsAttrName(result.name));
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+ mlir::function_interface_impl::
+ printFunctionOp(p, *this, false,
+ getFunctionTypeAttrName(),
+ getArgAttrsAttrName(),
+ getResAttrsAttrName());
+}
+
+
+#define GET_OP_CLASSES
+#include "Ops.cpp.inc"
+
+
+
+
diff --git a/obs/obs-ir/test.cpp b/obs/obs-ir/obs.cpp
similarity index 90%
rename from obs/obs-ir/test.cpp
rename to obs/obs-ir/obs.cpp
index 829797eaf4cf5..379054e0f4ea7 100644
--- a/obs/obs-ir/test.cpp
+++ b/obs/obs-ir/obs.cpp
@@ -2,15 +2,18 @@
#include "Lexer.h"
#include "Parser.h"
+#include "Dialect.h"
+#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
+#include <iostream>
#include <memory>
#include <string>
#include <system_error>
-#include <iostream>
+
using namespace obs;
namespace cl = llvm::cl;
@@ -22,7 +25,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace {
enum Action { None, DumpAST };
-}
+} // namespace
static cl::opt<enum Action> emitAction("emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
@@ -39,6 +42,12 @@ std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
+int dumpMLIR() {
+ mlir::MLIRContext context;
+ context.getOrLoadDialect<mlir::obs::OBSDialect>();
+ return 0;
+}
+
int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
auto moduleAST = parseInputFile(inputFilename);
>From fe9af3502a7728b0bc7e1e3923bf7cd8241fa6b9 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Mon, 1 Apr 2024 15:12:58 +0200
Subject: [PATCH 4/5] initial: finish toy example chapter 2
---
obs/CMakeLists.txt | 1 +
obs/include/Lexer.h | 3 +
obs/include/MLIRGen.h | 21 +++
obs/include/Ops.td | 8 +-
obs/obs-ir/Dialect.cpp | 60 ++++++++
obs/obs-ir/MLIRGen.cpp | 334 +++++++++++++++++++++++++++++++++++++++++
obs/obs-ir/obs.cpp | 93 ++++++++++--
7 files changed, 505 insertions(+), 15 deletions(-)
create mode 100644 obs/include/MLIRGen.h
create mode 100644 obs/obs-ir/MLIRGen.cpp
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index 5b7a50c53125d..05dcf074a73a6 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -29,6 +29,7 @@ include_directories(include/)
add_obs_executable(obstest
obs-ir/obs.cpp
obs-ir/Dialect.cpp
+ obs-ir/MLIRGen.cpp
parser/AST.cpp
DEPENDS
OBSOpGen
diff --git a/obs/include/Lexer.h b/obs/include/Lexer.h
index bf47f17dfa92c..6abbb6ec53886 100644
--- a/obs/include/Lexer.h
+++ b/obs/include/Lexer.h
@@ -65,6 +65,9 @@ class Lexer {
lastChar = Token(getNextChar());
}
+ lastLocation.line = curLineNum;
+ lastLocation.col = curCol;
+
if (isalpha(lastChar)) {
identifierStr = (char)lastChar;
while(isalnum(lastChar = Token(getNextChar())) || lastChar == '_') {
diff --git a/obs/include/MLIRGen.h b/obs/include/MLIRGen.h
new file mode 100644
index 0000000000000..b6c856fac5eec
--- /dev/null
+++ b/obs/include/MLIRGen.h
@@ -0,0 +1,21 @@
+
+#ifndef OBS_MLIRGEN_H
+#define OBS_MLIRGEN_H
+
+#include <memory>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/MLIRContext.h>
+
+namespace mlir{
+class MLIRContext;
+template <typename OpTy>
+class OwningOpRef;
+class ModuleOp;
+} // namespace mlir
+
+namespace obs {
+class ModuleAST;
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
+} //namespace obs
+
+#endif //OBS_MLIRGEN_H
\ No newline at end of file
diff --git a/obs/include/Ops.td b/obs/include/Ops.td
index 96e9d02b72160..9463de72f2614 100644
--- a/obs/include/Ops.td
+++ b/obs/include/Ops.td
@@ -9,7 +9,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def OBS_Dialect: Dialect {
//The namespace of the Dialect.
- let name = "toy";
+ let name = "obs";
let summary = "A high-level dialect for analyzing and optimzing the Language";
@@ -66,12 +66,12 @@ def AddOp : OBS_op<"add"> {
def FuncOp : OBS_op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
let summary = "user defined function operation";
let description = [{
- The "toy.func" operation represents a user defined function. These are callable SSA-region operations
- that contain toy computations.
+ The "obs.func" operation represents a user defined function. These are callable SSA-region operations
+ that contain obs computations.
}];
let arguments = (ins
- SymbolNameAttr:$sysm_name,
+ SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
diff --git a/obs/obs-ir/Dialect.cpp b/obs/obs-ir/Dialect.cpp
index b33174b8355cd..c5a4ef2e5c310 100644
--- a/obs/obs-ir/Dialect.cpp
+++ b/obs/obs-ir/Dialect.cpp
@@ -156,6 +156,66 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
getResAttrsAttrName());
}
+void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ mlir::Value lhs, mlir::Value rhs) {
+ state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+ state.addOperands({lhs, rhs});
+}
+
+mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseBinaryOp(parser, result);
+}
+
+void MulOp::print(mlir::OpAsmPrinter &p) {
+ printBinary(p, * this);
+}
+
+/// Verify that a return matches the signature of the enclosing function: at
+/// most one operand, as many operands as the function has results, and an
+/// operand type equal to the declared result type unless either of the two is
+/// an unranked tensor.
+mlir::LogicalResult ReturnOp::verify() {
+  auto function = cast<FuncOp>((*this).getParentOp());
+  if (getNumOperands() > 1) {
+    return emitOpError() << "expects at most 1 return operand";
+  }
+  const auto &results = function.getFunctionType().getResults();
+  if (getNumOperands() != results.size()) {
+    // Space before "(" added so both halves of the message read the same way.
+    return emitOpError() << "does not return the same number of values ("
+                         << getNumOperands() << ") as the enclosing function ("
+                         << results.size() << ")";
+  }
+
+  // No operand and no declared result: nothing left to compare.
+  if (!hasOperand()) {
+    return mlir::success();
+  }
+
+  auto inputType = *operand_type_begin();
+  auto resultType = results.front();
+
+  // An unranked tensor on either side is accepted as compatible.
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType)) {
+    return mlir::success();
+  }
+  return emitError() << "type of return operand (" << inputType
+                     << ") doesn't match function result type (" << resultType
+                     << ")";
+}
+
+void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value value) {
+ state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+ state.addOperands(value);
+}
+
+/// Verify a transpose: when both the operand and the result are ranked
+/// tensors, the result shape must be the operand shape reversed. Nothing is
+/// checked while either type is not a ranked tensor.
+mlir::LogicalResult TransposeOp::verify() {
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!inputType || !resultType)
+    return mlir::success();
+
+  auto inputShape = inputType.getShape();
+  auto resultShape = resultType.getShape();
+  // Four-iterator overload: shapes of different rank compare unequal instead
+  // of the comparison walking past the end of the shorter one.
+  if (!std::equal(inputShape.begin(), inputShape.end(), resultShape.rbegin(),
+                  resultShape.rend())) {
+    return emitError() << "expected result shape to be a transpose of the input";
+  }
+  return mlir::success();
+}
#define GET_OP_CLASSES
#include "Ops.cpp.inc"
diff --git a/obs/obs-ir/MLIRGen.cpp b/obs/obs-ir/MLIRGen.cpp
new file mode 100644
index 0000000000000..388d7e5b5fdb0
--- /dev/null
+++ b/obs/obs-ir/MLIRGen.cpp
@@ -0,0 +1,334 @@
+
+#include "MLIRGen.h"
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "AST.h"
+#include "Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Verifier.h"
+#include "Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
+#include <cassert>
+#include <cstdint>
+#include <functional>
+#include <mlir/IR/BuiltinAttributes.h>
+#include <mlir/IR/Location.h>
+#include <mlir/IR/OwningOpRef.h>
+#include <numeric>
+#include <optional>
+#include <vector>
+
+using namespace mlir::obs;
+using namespace obs;
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+class MLIRGenImpl {
+
+public:
+ MLIRGenImpl(mlir::MLIRContext &context) : builder( &context) {}
+
+ mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
+ // Emit every function of the AST into the freshly created module.
+ for (FunctionAST &f : moduleAST) {
+ mlirGen(f); // NOTE(review): the returned FuncOp (null on failure) is ignored here.
+ }
+ // Verify the module once all functions were emitted; return null if it is invalid.
+ if (failed(mlir::verify(theModule))) {
+ theModule->emitError("module verification error");
+ return nullptr;
+ }
+ return theModule;
+ }
+
+private:
+ mlir::ModuleOp theModule;
+ mlir::OpBuilder builder;
+ llvm::ScopedHashTable<StringRef, mlir::Value> symbolTable;
+
+ mlir::Location loc(const Location &loc) {
+ return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, loc.col); // AST file/line/col -> MLIR location.
+ }
+
+ mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
+   // Binds `var` to `value`; fails when the symbol table already knows `var`.
+   const bool isNewName = symbolTable.count(var) == 0;
+   if (isNewName)
+     symbolTable.insert(var, value);
+   return mlir::success(isNewName);
+ }
+
+ mlir::Type getType(ArrayRef<int64_t> shape) {
+   // Empty shape: unranked tensor; otherwise a ranked one. Elements are f64.
+   if (!shape.empty())
+     return mlir::RankedTensorType::get(shape, builder.getF64Type());
+   return mlir::UnrankedTensorType::get(builder.getF64Type());
+ }
+
+ mlir::Type getType(const VarType &type) {
+ return getType(type.shape); // Forwards the declared shape to the ArrayRef overload.
+ }
+
+ mlir::obs::FuncOp mlirGen(PrototypeAST &proto) {
+   // Each argument gets the type of a default-constructed VarType; no results yet.
+   const auto argCount = proto.getArgs().size();
+   llvm::SmallVector<mlir::Type, 4> argTypes(argCount, getType(VarType{}));
+   auto signature = builder.getFunctionType(argTypes, std::nullopt);
+
+   return builder.create<mlir::obs::FuncOp>(loc(proto.loc()), proto.getName(), signature);
+ }
+
+ void collectData(ExprAST &expr, std::vector<double> &data) {
+   // Flattens a (possibly nested) literal into `data`, element by element.
+   if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
+     // Return only after the loop: returning inside it kept just the first element.
+     for (auto &value : lit->getValues())
+       collectData(*value, data);
+     return;
+   }
+   assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+   data.push_back(cast<NumberExprAST>(expr).getValue());
+ }
+
+ mlir::Value mlirGen(LiteralExprAST &lit ) {
+ auto type = getType(lit.getDims());
+ std::vector<double> data;
+ data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, std::multiplies<int>())); // Product of the dimensions.
+ collectData(lit, data);
+ // The attribute type is always a ranked f64 tensor with the literal's dimensions.
+ mlir::Type elementType = builder.getF64Type();
+ auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
+ // Wrap the flattened values into a dense elements attribute.
+ auto dataAttribute =
+ mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
+ // Emit the constant carrying that attribute.
+ return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
+ }
+
+ mlir::Value mlirGen(ExprAST &expr) {
+ switch (expr.getKind()) { // Dispatch on the dynamic kind of the expression.
+ case obs::ExprAST::Expr_BinOp:
+ return mlirGen(cast<BinaryExprAST>(expr));
+ case obs::ExprAST::Expr_Var:
+ return mlirGen(cast<VariableExprAST>(expr));
+ case obs::ExprAST::Expr_Literal:
+ return mlirGen(cast<LiteralExprAST>(expr));
+ case obs::ExprAST::Expr_Call:
+ return mlirGen(cast<CallExprAST>(expr));
+ case obs::ExprAST::Expr_Num:
+ return mlirGen(cast<NumberExprAST>(expr));
+ default: // Any other kind is diagnosed and yields a null value.
+ mlir::emitError(loc(expr.loc()))
+ << "MLIR codegen encounter an unhandled expr kind '"
+ << Twine(expr.getKind()) << "'";
+ return nullptr;
+ }
+ }
+
+ mlir::LogicalResult mlirGen(PrintExprAST &call) {
+   // Lower the printed expression first; without a value there is nothing to print.
+   auto printed = mlirGen(*call.getArg());
+   if (printed) {
+     builder.create<PrintOp>(loc(call.loc()), printed);
+     return mlir::success();
+   }
+   return mlir::failure();
+ }
+
+ mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
+   // Emits the expressions of a block in order, inside a fresh variable scope.
+   ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
+   for (auto &expr : blockAST) {
+     if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+       if (!mlirGen(*vardecl))
+         return mlir::failure();
+       continue;
+     }
+     if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
+       return mlirGen(*ret);
+     if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+       // A print that fails to lower must fail the block, not report success.
+       if (mlir::failed(mlirGen(*print)))
+         return mlir::failure();
+       continue;
+     }
+     // Generic expression dispatch codegen.
+     if (!mlirGen(*expr))
+       return mlir::failure();
+   }
+   return mlir::success();
+ }
+
+ mlir::Value mlirGen(NumberExprAST &num) {
+ return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); // Number literal -> ConstantOp built from its value.
+ }
+
+ mlir::obs::FuncOp mlirGen(FunctionAST &funcAST) {
+ llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
+ // Create the function at the end of the module from its prototype.
+ builder.setInsertionPointToEnd(theModule.getBody());
+ mlir::obs::FuncOp function = mlirGen(*funcAST.getProto());
+ if (! function) {
+ return nullptr;
+ }
+
+ mlir::Block &entryBlock = function.front();
+ // Bind each prototype argument name to the matching entry-block argument.
+ auto protoArgs = funcAST.getProto()->getArgs();
+
+ for (const auto nameValue : llvm::zip(protoArgs, entryBlock.getArguments())) {
+ if (failed(declare(std::get<0>(nameValue)->getName(), std::get<1>(nameValue)))) {
+ return nullptr; // NOTE(review): unlike the body-failure path below, the function is not erased here.
+ }
+ }
+ // Emit the body starting at the beginning of the entry block.
+ builder.setInsertionPointToStart( &entryBlock );
+
+ if (mlir::failed(mlirGen(*funcAST.getBody()))) {
+ function->erase();
+ return nullptr;
+ }
+ // Add an empty return if the body does not end with one; if the final return carries a value, give the function one result type.
+ ReturnOp returnOp;
+ if (!entryBlock.empty()) {
+ returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+ }
+ if (!returnOp) {
+ builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+ } else if (returnOp.hasOperand()) {
+ function.setType(builder.getFunctionType(function.getFunctionType().getInputs(), getType(VarType{})));
+ }
+ return function;
+ }
+
+ mlir::Value mlirGen(BinaryExprAST &binop) {
+   // Operands are lowered left to right; a failure on either side aborts.
+   mlir::Value left = mlirGen(*binop.getLHS());
+   if (!left)
+     return nullptr;
+   mlir::Value right = mlirGen(*binop.getRHS());
+   if (!right)
+     return nullptr;
+
+   auto opLoc = loc(binop.loc());
+   const auto opcode = binop.getOp();
+   // Only '+' and '*' map to an operation; anything else is diagnosed below.
+   if (opcode == '+')
+     return builder.create<AddOp>(opLoc, left, right);
+   if (opcode == '*')
+     return builder.create<MulOp>(opLoc, left, right);
+
+   emitError(opLoc, "invalid binary operator '") << binop.getOp() << "'";
+   return nullptr;
+ }
+
+ mlir::Value mlirGen(VariableExprAST &expr) {
+   // Resolve the name through the symbol table; a miss is diagnosed and yields null.
+   mlir::Value found = symbolTable.lookup(expr.getName());
+   if (!found)
+     mlir::emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'";
+   return found;
+ }
+
+ mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
+   auto returnLoc = loc(ret.loc());
+   // A bare `return` keeps the value null and produces an operand-less ReturnOp.
+   mlir::Value returned = nullptr;
+   if (ret.getExpr().has_value()) {
+     returned = mlirGen(**ret.getExpr());
+     if (!returned)
+       return mlir::failure();
+   }
+   builder.create<ReturnOp>(returnLoc, returned ? ArrayRef(returned) : ArrayRef<mlir::Value>());
+   return mlir::success();
+ }
+
+ mlir::Value mlirGen(CallExprAST &call) {
+ llvm::StringRef callee = call.getCallee();
+ auto location = loc(call.loc());
+ // Lower every argument first; the first failing one aborts the call.
+ SmallVector<mlir::Value, 4> operands;
+
+ for (auto &expr : call.getArgs()) {
+ auto arg = mlirGen(*expr);
+ if (!arg) {
+ return nullptr;
+ }
+ operands.push_back(arg);
+ }
+ // "transpose" is lowered to TransposeOp; every other callee becomes a GenericCallOp.
+ if (callee == "transpose") {
+ if (call.getArgs().size() != 1) { // Also taken for zero arguments, despite the message wording.
+ mlir::emitError(location, "MLIR codegen encountered an error: obs.transpose "
+ "does not accept multiple arguments.");
+ return nullptr;
+ }
+ return builder.create<TransposeOp>(location, operands[0]);
+ }
+ return builder.create<GenericCallOp>(location, callee, operands);
+ }
+
+ mlir::Value mlirGen(VarDeclExprAST &vardecl) {
+ auto *init = vardecl.getInitVal();
+ // A declaration without an initializer is diagnosed and yields null.
+ if (!init) {
+ mlir::emitError(loc(vardecl.loc()), "missing initializer in variable declaration");
+ return nullptr;
+ }
+
+ mlir::Value value = mlirGen(*init);
+ if (!value) {
+ return nullptr;
+ }
+ // When the declaration has an explicit shape, reshape the initializer to it.
+ if (!vardecl.getType().shape.empty()) {
+ value = builder.create<ReshapeOp>(loc(vardecl.loc()), getType(vardecl.getType()), value);
+ }
+ // Register the variable; this fails if the name is already in the symbol table.
+ if (failed(declare(vardecl.getName(), value))) {
+ return nullptr;
+ }
+ return value;
+ }
+
+};
+
+} //namespace
+
+namespace obs {
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST) {
+ return MLIRGenImpl(context).mlirGen(moduleAST); // Public entry point: runs a temporary MLIRGenImpl over the AST.
+}
+} //namespace obs
+
+
+
+
+
+
+
+
diff --git a/obs/obs-ir/obs.cpp b/obs/obs-ir/obs.cpp
index 379054e0f4ea7..ef5ed4d96bcd1 100644
--- a/obs/obs-ir/obs.cpp
+++ b/obs/obs-ir/obs.cpp
@@ -3,12 +3,19 @@
#include "Parser.h"
#include "Dialect.h"
+#include "MLIRGen.h"
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser/Parser.h"
+
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/SourceMgr.h"
#include <iostream>
#include <memory>
#include <string>
@@ -19,16 +26,26 @@ using namespace obs;
namespace cl = llvm::cl;
static cl::opt<std::string> inputFilename(cl::Positional,
- cl::desc("<input toy file>"),
+ cl::desc("<input obs file>"),
cl::init("-"),
cl::value_desc("filename"));
namespace {
-enum Action { None, DumpAST };
+ enum InputType { OBS, MLIR };
+} //namespace
+static cl::opt<enum InputType> inputType( // -x: how the input file is interpreted.
+    "x", cl::init(OBS), cl::desc("Decide the kind of input desired"),
+    cl::values(clEnumValN(OBS, "obs", "load the input file as an OBS source.")),
+    cl::values(clEnumValN(MLIR, "mlir",
+                          "load the input file as an MLIR file")));
+
+namespace {
+enum Action { None, DumpAST, DumpMLIR };
} // namespace
static cl::opt<enum Action> emitAction("emit", cl::desc("Select the kind of output desired"),
- cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
+ cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
+ cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
@@ -45,20 +62,74 @@ std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
int dumpMLIR() {
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::obs::OBSDialect>();
+ // Unless the input is declared (-x mlir) or named (*.mlir) as MLIR, parse it as OBS source and generate MLIR from the AST.
+ if (inputType != InputType::MLIR &&
+ !llvm::StringRef(inputFilename).ends_with(".mlir")) {
+ auto moduleAST = parseInputFile(inputFilename);
+ if (!moduleAST)
+ return 6; // NOTE(review): each failure path of this function uses a different exit code (6, 1, -1, 3).
+ mlir::OwningOpRef<mlir::ModuleOp> module = mlirGen(context, *moduleAST);
+ if (!module)
+ return 1;
+
+ module->dump();
+ return 0;
+ }
+
+ // Otherwise, the input is '.mlir'.
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+ llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+ if (std::error_code ec = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+ return -1;
+ }
+
+ // Parse the input mlir.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+ mlir::OwningOpRef<mlir::ModuleOp> module =
+ mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
+ if (!module) {
+ llvm::errs() << "Error can't load file " << inputFilename << "\n";
+ return 3;
+ }
+ // Dump the parsed module.
+ module->dump();
return 0;
+
+
}
-int main(int argc, char **argv) {
- cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+int dumpAST() {
+ if (inputType == InputType::MLIR) {
+ llvm::errs() << "Can't dump a OBS AST when the input is MLIR\n";
+ return 5;
+ }
+ // Parse the input into an AST; bail out if none was produced.
auto moduleAST = parseInputFile(inputFilename);
+ if (!moduleAST)
+ return 1;
+
+ dump(*moduleAST);
+ return 0;
+}
+
+int main(int argc, char **argv) {
+ // Register any command line options.
+ mlir::registerAsmPrinterCLOptions();
+ mlir::registerMLIRContextCLOptions();
+ cl::ParseCommandLineOptions(argc, argv, "obs compiler\n");
- switch(emitAction) {
- case Action::DumpAST:
- dump(*moduleAST);
- return 0;
- default:
- llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+ switch (emitAction) {
+ case Action::DumpAST:
+ return dumpAST();
+ case Action::DumpMLIR:
+ return dumpMLIR();
+ default:
+ llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
}
+ // Reached only when no action was selected; the tool still exits with 0.
return 0;
+
}
>From e5caa51de291b64d2cb2f0cd661893df80e643e6 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Fri, 31 May 2024 08:03:09 +0100
Subject: [PATCH 5/5] feat: add OBS operations
---
mlir/test/Examples/Toy/Ch2/codegen.toy | 31 -----
obs/.clang-format | 1 +
obs/CMakeLists.txt | 51 ++++++--
obs/codegen/CodeGen.cpp | 55 +++++++++
obs/codegen/CodeGenAction.cpp | 30 +++++
obs/codegen/OBSGen.cpp | 27 ++++
obs/include/CodeGen.h | 41 +++++++
obs/include/CodeGenAction.h | 43 +++++++
obs/include/Dialect.h | 50 ++++++++
obs/include/OBSGen.h | 2 +
obs/include/Ops.td | 125 ++++++++++++++++++-
obs/obs-ir/Dialect.cpp | 163 +++++++++++++++++++++++++
obs/obs-ir/MLIRGen.cpp | 50 +++++++-
obs/test/codegen.toy | 5 +
obs/test/test1.cpp | 9 ++
15 files changed, 636 insertions(+), 47 deletions(-)
delete mode 100644 mlir/test/Examples/Toy/Ch2/codegen.toy
create mode 100644 obs/.clang-format
create mode 100644 obs/codegen/CodeGen.cpp
create mode 100644 obs/codegen/CodeGenAction.cpp
create mode 100644 obs/codegen/OBSGen.cpp
create mode 100644 obs/include/CodeGen.h
create mode 100644 obs/include/CodeGenAction.h
create mode 100644 obs/include/OBSGen.h
create mode 100644 obs/test/codegen.toy
create mode 100644 obs/test/test1.cpp
diff --git a/mlir/test/Examples/Toy/Ch2/codegen.toy b/mlir/test/Examples/Toy/Ch2/codegen.toy
deleted file mode 100644
index 12178d6afb309..0000000000000
--- a/mlir/test/Examples/Toy/Ch2/codegen.toy
+++ /dev/null
@@ -1,31 +0,0 @@
-# RUN: toyc-ch2 %s -emit=mlir 2>&1 | FileCheck %s
-
-# User defined generic function that operates on unknown shaped arguments
-def multiply_transpose(a, b) {
- return transpose(a) * transpose(b);
-}
-
-def main() {
- var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
- var b<2, 3> = [1, 2, 3, 4, 5, 6];
- var c = multiply_transpose(a, b);
- var d = multiply_transpose(b, a);
- print(d);
-}
-
-# CHECK-LABEL: toy.func @multiply_transpose(
-# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
-# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
-# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
-# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
-# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
-
-# CHECK-LABEL: toy.func @main()
-# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
-# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
-# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
-# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64>
-# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
-# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
-# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64>
-# CHECK-NEXT: toy.return
diff --git a/obs/.clang-format b/obs/.clang-format
new file mode 100644
index 0000000000000..468ceea01f1a0
--- /dev/null
+++ b/obs/.clang-format
@@ -0,0 +1 @@
+BasedOnStyle: LLVM
\ No newline at end of file
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index 05dcf074a73a6..5dd9548f6a956 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -5,9 +5,12 @@ set(LLVM_LINK_COMPONENTS
include(CMakeDependentOption)
include(GNUInstallDirs)
-set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --src-root
-set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --includedir
-set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)
+set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include/ ) # --includedir
+set(CLANG_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../clang/include/ ) # --includedir
+set(CLANG_INCLUDE_DIR_INC ${CMAKE_BINARY_DIR}/tools/clang/include/ ) # --includedir
+set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include/)
+include_directories(SYSTEM ${CLANG_INCLUDE_DIR})
+include_directories(SYSTEM ${CLANG_INCLUDE_DIR_INC})
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
@@ -26,21 +29,47 @@ include(CMakeParseArguments)
include(AddOBS)
include_directories(include/)
-add_obs_executable(obstest
- obs-ir/obs.cpp
- obs-ir/Dialect.cpp
- obs-ir/MLIRGen.cpp
- parser/AST.cpp
+#add_obs_executable(obstest
+# obs-ir/obs.cpp
+# obs-ir/Dialect.cpp
+# obs-ir/MLIRGen.cpp
+# parser/AST.cpp
+# DEPENDS
+# OBSOpGen
+#)
+
+#target_link_libraries(obstest
+# PRIVATE
+# MLIRAnalysis
+# MLIRFunctionInterfaces
+# MLIRIR
+# MLIRParser
+# MLIRSideEffectInterfaces
+# MLIRTransforms
+#)
+
+
+add_obs_executable(codegen
+ codegen/OBSGen.cpp
+ codegen/CodeGenAction.cpp
+ codegen/CodeGen.cpp
DEPENDS
OBSOpGen
- )
+)
-target_link_libraries(obstest
+target_link_libraries(codegen
PRIVATE
+ clangAST
+ clangBasic
+ clangFrontend
+ clangSerialization
+ clangTooling
MLIRAnalysis
MLIRFunctionInterfaces
MLIRIR
MLIRParser
MLIRSideEffectInterfaces
MLIRTransforms
-)
\ No newline at end of file
+)
+
+
diff --git a/obs/codegen/CodeGen.cpp b/obs/codegen/CodeGen.cpp
new file mode 100644
index 0000000000000..d50de0431cc3b
--- /dev/null
+++ b/obs/codegen/CodeGen.cpp
@@ -0,0 +1,55 @@
+
+#include "MLIRGen.h"
+
+#include "AST.h"
+#include "Dialect.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cassert>
+#include <cstdint>
+#include <functional>
+
+#include <mlir/IR/BuiltinAttributes.h>
+#include <mlir/IR/Location.h>
+#include <mlir/IR/OwningOpRef.h>
+
+#include <iostream>
+#include <numeric>
+#include <optional>
+#include <vector>
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+#include "CodeGen.h"
+
+namespace mlir {
+namespace obs {
+
+bool MLIRGenImpl::VisitFunctionDecl(clang::FunctionDecl *funcDecl) {
+  llvm::outs() << "VisitFunctionDecl: "; // Debug trace of each visited function.
+  funcDecl->getDeclName().dump();
+  llvm::outs() << "\n";
+  return true; // Keep traversing: returning false aborts the whole AST walk.
+}
+
+} // namespace obs
+} // namespace mlir
diff --git a/obs/codegen/CodeGenAction.cpp b/obs/codegen/CodeGenAction.cpp
new file mode 100644
index 0000000000000..94fa18dcd188d
--- /dev/null
+++ b/obs/codegen/CodeGenAction.cpp
@@ -0,0 +1,30 @@
+
+
+#include "CodeGenAction.h"
+#include "CodeGen.h"
+
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <clang/AST/ASTContext.h>
+#include <clang/AST/Decl.h>
+#include <clang/AST/DeclGroup.h>
+
+#include <iostream>
+#include <mlir/IR/MLIRContext.h>
+#include <ostream>
+
+using namespace clang;
+
+namespace mlir {
+namespace obs {
+
+void CodeGenConsumer::HandleTranslationUnit(ASTContext &context) {
+ llvm::outs() << "Enter HandleTranslationUnit\n";
+ MLIRContext codegenContext; // MLIR context local to this call.
+ MLIRGenImpl mlirGen(codegenContext);
+ mlirGen.TraverseDecl(context.getTranslationUnitDecl()); // Start the visitor at the translation unit.
+}
+
+} // namespace obs
+} // namespace mlir
\ No newline at end of file
diff --git a/obs/codegen/OBSGen.cpp b/obs/codegen/OBSGen.cpp
new file mode 100644
index 0000000000000..7cd866aa124ca
--- /dev/null
+++ b/obs/codegen/OBSGen.cpp
@@ -0,0 +1,27 @@
+
+#include "llvm/Support/CommandLine.h"
+#include <iostream>
+#include <vector>
+
+#include "OBSGen.h"
+
+// `main` function translates a C program into an OBS.
+int main(int argc, const char **argv) {
+  // Parse the common clang-tooling command line (sources, compilation database).
+  llvm::cl::OptionCategory CodeGenCategory("OBS code generation");
+  auto OptionsParser =
+      clang::tooling::CommonOptionsParser::create(argc, argv, CodeGenCategory);
+  if (!OptionsParser) {
+    // Fail gracefully for unsupported options.
+    std::cout << "error ----" << std::endl;
+    llvm::errs() << OptionsParser.takeError() << "error";
+    return 1;
+  }
+
+  auto sources = OptionsParser->getSourcePathList();
+  clang::tooling::ClangTool Tool(OptionsParser->getCompilations(), sources);
+  // Keep the factory on the stack (the `new`-ed one was never freed) and return
+  // the tool's status so that a failed run yields a non-zero exit code.
+  mlir::obs::CodeGenFrontendActionFactory factory;
+  return Tool.run(&factory);
+}
diff --git a/obs/include/CodeGen.h b/obs/include/CodeGen.h
new file mode 100644
index 0000000000000..d36a8f5ff4e9b
--- /dev/null
+++ b/obs/include/CodeGen.h
@@ -0,0 +1,41 @@
+
+#ifndef OBS_CODEGEN_H // No leading underscore: "_" + uppercase names are reserved.
+#define OBS_CODEGEN_H
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/SourceManager.h"
+#include <clang/AST/ASTContext.h>
+#include <clang/AST/Decl.h>
+#include <memory>
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Verifier.h"
+
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/MLIRContext.h>
+#include <mlir/IR/OwningOpRef.h>
+
+namespace mlir {
+namespace obs {
+
+class MLIRGenImpl : public clang::RecursiveASTVisitor<MLIRGenImpl> { // Clang AST visitor holding an MLIR builder.
+public:
+ MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
+ // Visitor hook for function declarations; defined out of line.
+ bool VisitFunctionDecl(clang::FunctionDecl *funcDecl);
+
+private:
+ mlir::ModuleOp theModule; // Not assigned by the constructor.
+ mlir::OpBuilder builder; // Builder bound to the context given to the constructor.
+};
+
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
+ clang::TranslationUnitDecl &decl);
+
+} // namespace obs
+} // namespace mlir
+
+#endif
\ No newline at end of file
diff --git a/obs/include/CodeGenAction.h b/obs/include/CodeGenAction.h
new file mode 100644
index 0000000000000..226a24ee78253
--- /dev/null
+++ b/obs/include/CodeGenAction.h
@@ -0,0 +1,43 @@
+#ifndef OBS_CODEGENACTION_H
+#define OBS_CODEGENACTION_H
+#include "clang/AST/AST.h"
+#include "clang/AST/ASTConsumer.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/PrecompiledPreamble.h"
+#include "clang/Tooling/CommonOptionsParser.h"
+#include "clang/Tooling/Tooling.h"
+#include <iostream>
+#include <memory>
+
+namespace mlir {
+namespace obs {
+// AST consumer; its HandleTranslationUnit is defined out of line.
+class CodeGenConsumer : public clang::ASTConsumer {
+public:
+  CodeGenConsumer() {}
+  void HandleTranslationUnit(clang::ASTContext &context) override;
+};
+// Frontend action that creates a fresh CodeGenConsumer as its AST consumer.
+class CodeGenFrontendAction : public clang::ASTFrontendAction {
+protected:
+  std::unique_ptr<clang::ASTConsumer>
+  CreateASTConsumer(clang::CompilerInstance &ci, clang::StringRef) override {
+    return std::make_unique<CodeGenConsumer>();
+  }
+};
+// Factory that returns a new CodeGenFrontendAction from every create() call.
+class CodeGenFrontendActionFactory
+    : public clang::tooling::FrontendActionFactory {
+public:
+  CodeGenFrontendActionFactory() {}
+
+  std::unique_ptr<clang::FrontendAction> create() override {
+    return std::make_unique<CodeGenFrontendAction>();
+  }
+};
+} // namespace obs
+} // namespace mlir
+#endif // OBS_CODEGENACTION_H
\ No newline at end of file
diff --git a/obs/include/Dialect.h b/obs/include/Dialect.h
index a298ed2f086a0..e87c75f08d864 100644
--- a/obs/include/Dialect.h
+++ b/obs/include/Dialect.h
@@ -9,9 +9,59 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "Dialect.h.inc"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include <mlir/IR/BuiltinTypes.h>
+#include <mlir/IR/Types.h>
+
+namespace mlir {
+namespace obs {
+namespace user_types {
+ struct OwnTypeStorage;
+ struct RefTypeStorage;
+}//user_types
+}//obs
+}//mlir
#define GET_OP_CLASSES
#include "Ops.h.inc"
+
+namespace mlir {
+namespace obs {
+
+class OwnType : public mlir::Type::TypeBase<OwnType, mlir::Type, user_types::OwnTypeStorage> {
+public:
+ using Base::Base;
+
+ // Create an instance of `OwnType` from a resource name and its dimensions.
+ static OwnType get(mlir::MLIRContext *ctx, StringRef resName, ArrayRef<unsigned int> dims);
+
+ // Return the owned resource name.
+ StringRef getResName() ;
+
+ // Return the dims of the owned resource.
+ ArrayRef<unsigned int> getDims();
+ // Name string of this type.
+ static constexpr llvm::StringLiteral name = "obs.OWN";
+};
+
+class RefType : public mlir::Type::TypeBase<RefType, mlir::Type, user_types::RefTypeStorage> {
+public:
+ using Base::Base;
+
+ // Create an instance of `RefType` from the owner type(s) it refers to.
+ static RefType get(ArrayRef<mlir::Type> ownerType);
+
+ // Return the owner type(s) this reference was created from.
+ ArrayRef<mlir::Type> getOwnerType() ;
+ // Name string of this type.
+ static constexpr llvm::StringLiteral name = "obs.REF";
+};
+
+} //obs
+} //mlir
+
+
#endif //MLIR_DIALECT_TRAITS_H
diff --git a/obs/include/OBSGen.h b/obs/include/OBSGen.h
new file mode 100644
index 0000000000000..740bfcdd7ba1e
--- /dev/null
+++ b/obs/include/OBSGen.h
@@ -0,0 +1,2 @@
+
+#include "CodeGenAction.h"
\ No newline at end of file
diff --git a/obs/include/Ops.td b/obs/include/Ops.td
index 9463de72f2614..40c8e6cb330ff 100644
--- a/obs/include/Ops.td
+++ b/obs/include/Ops.td
@@ -11,18 +11,137 @@ def OBS_Dialect: Dialect {
//The namespace of the Dialect.
let name = "obs";
- let summary = "A high-level dialect for analyzing and optimzing the Language";
+ let summary = "A dialect for representing abstract memory operations.";
let description = [{
- The Toy language is a tensor-based language that allows you to define functions, perform some math computation, and
- print results. This dialect provides a reprentation of the language that is amenable to analysis and optimization.
+ The OBS dialect is an abstraction over memory operations.
+ It is used to reason about Ownership and Borrowing in a programming language.
+ It contains the following operations:
+ (1) Own(type, dim);
+ (2) Ref(x);
+ (3) Read(x);
+ (4) Write(x);
+ (5) Delete(x);
+ (6) Function.
}];
let cppNamespace = "::mlir::obs";
+
+ let useDefaultTypePrinterParser = 1;
}
class OBS_op<string mnemonic, list<Trait> traits = []> : Op<OBS_Dialect, mnemonic, traits>;
+def OBS_OwnType :
+ DialectType<OBS_Dialect, CPred<"::llvm::isa<OwnType>($_self)">, "OBS own type">;
+
+def OBS_RefType :
+ DialectType<OBS_Dialect, CPred<"::llvm::isa<RefType>($_self)">, "OBS ref type">;
+
+def OBS_Type : AnyTypeOf<[OBS_OwnType, OBS_RefType]>;
+
+def OwnOp : OBS_op<"Own", []> {
+
+ let summary = "Allocate a memory resource and bind to an owner." ;
+
+ let description = [{
+ This operation creates a memory resource and binds it to an owner.
+ }];
+
+
+ let arguments = (ins StrAttr:$type, F64Tensor:$dim);
+ let results = (outs OBS_OwnType);
+
+ let assemblyFormat = [{
+ `(` attr-dict `,` $dim `:` type($dim) `)` `:` type(results)
+ }];
+
+
+ let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "StringRef":$type, "Value":$dim )>
+ ];
+}
+
+def RefOp : OBS_op<"Ref", []> {
+
+ let summary = "Create a reference to a resource." ;
+
+ let description = [{
+ This operation creates a reference to a resource.
+ }];
+
+ let arguments = (ins OBS_Type:$ownType);
+ let results = (outs OBS_RefType);
+
+ let assemblyFormat = [{
+ `(` attr-dict $ownType `:` type($ownType) `)` `:` type(results)
+ }];
+
+
+ let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "Value":$ownType )>
+ ];
+}
+
+def ReadOp : OBS_op<"Read", []> {
+
+ let summary = "Read the resource through a variable" ;
+
+ let description = [{
+ This operation creates a read operation.
+ }];
+
+ let arguments = (ins OBS_Type:$ownType);
+
+ let assemblyFormat = [{
+ `(` attr-dict $ownType `:` type($ownType) `)`
+ }];
+
+
+ let hasVerifier = 1;
+}
+
+def WriteOp : OBS_op<"Write", []> {
+
+ let summary = "Write the resource through a variable" ;
+
+ let description = [{
+ This operation creates a write operation.
+ }];
+
+ let arguments = (ins OBS_Type:$ownType);
+
+ let assemblyFormat = [{
+ `(` attr-dict $ownType `:` type($ownType) `)`
+ }];
+
+
+ let hasVerifier = 1;
+}
+
+def DeleteOp : OBS_op<"Delete", []> {
+
+ let summary = "Delete a resource through its owner." ;
+
+ let description = [{
+ This operation creates a delete operation.
+ }];
+
+ let arguments = (ins OBS_Type:$ownType);
+
+ let assemblyFormat = [{
+ `(` attr-dict $ownType `:` type($ownType) `)`
+ }];
+
+
+ let hasVerifier = 1;
+}
+
+
def ConstantOp : OBS_op<"constant", [Pure]> {
let summary = "constant operation" ;
diff --git a/obs/obs-ir/Dialect.cpp b/obs/obs-ir/Dialect.cpp
index c5a4ef2e5c310..d115c4a887efa 100644
--- a/obs/obs-ir/Dialect.cpp
+++ b/obs/obs-ir/Dialect.cpp
@@ -11,11 +11,21 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/VersionTuple.h"
#include <algorithm>
+#include <mlir/IR/DialectImplementation.h>
+#include <mlir/IR/MLIRContext.h>
+#include <mlir/IR/TypeRange.h>
+#include <mlir/IR/TypeSupport.h>
+#include <mlir/IR/Types.h>
+#include <mlir/Support/TypeID.h>
#include <string>
+#include <utility>
+#include <vector>
using namespace mlir;
using namespace mlir::obs;
@@ -27,8 +37,12 @@ void OBSDialect::initialize() {
#define GET_OP_LIST
#include "Ops.cpp.inc"
>();
+
+ addTypes<OwnType>();
+ addTypes<RefType>();
}
+
static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, mlir::OperationState &result) {
SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
@@ -67,6 +81,7 @@ static void printBinary(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
}
+
void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, double value) {
auto dataType = RankedTensorType::get({}, builder.getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
@@ -90,6 +105,34 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << getValue();
}
+mlir::LogicalResult OwnOp::verify() {
+  // "vector" is currently the only accepted resource type. Any other value is
+  // rejected with a diagnostic: a verifier must not fail silently.
+  if (getType() != "vector")
+    return emitOpError() << "expects resource type \"vector\", got \"" << getType() << "\"";
+  return success();
+}
+
+mlir::LogicalResult RefOp::verify() {
+ // TODO: implement the verification; every RefOp is accepted for now.
+ return success();
+}
+
+mlir::LogicalResult WriteOp::verify() {
+ // TODO: implement the verification; every WriteOp is accepted for now.
+ return success();
+}
+
+mlir::LogicalResult ReadOp::verify() {
+ // TODO: implement the verification; every ReadOp is accepted for now.
+ return success();
+}
+
+mlir::LogicalResult DeleteOp::verify() {
+ // TODO: implement the verification; every DeleteOp is accepted for now.
+ return success();
+}
+
mlir::LogicalResult ConstantOp::verify() {
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
@@ -217,6 +260,126 @@ mlir::LogicalResult TransposeOp::verify() {
return mlir::success();
}
+namespace mlir {
+namespace obs {
+
+namespace user_types {
+/*
+ * Uniqued storage backing OwnType: a resource name plus a list of dimensions.
+ **/
+struct OwnTypeStorage : public mlir::TypeStorage {
+  // Uniquing key: `first` is the resource name, `second` is the dimensions.
+  using KeyTy = std::pair<StringRef, ArrayRef<unsigned int>>;
+
+  OwnTypeStorage(StringRef name, ArrayRef<unsigned int> shape): resName(name), dims(shape) {}
+
+  static KeyTy getKey(StringRef resName, ArrayRef<unsigned int> dims) {
+    return KeyTy(resName, dims);
+  }
+
+  // Hashes the whole key (name and dimensions).
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  // True when `key` carries the same resource name and the same dimensions.
+  bool operator==(const KeyTy &key) const {
+    return resName == key.first && dims == key.second;
+  }
+
+  // Copies the key's name and dimensions into the allocator, then
+  // placement-constructs the storage in allocator-provided memory.
+  static OwnTypeStorage *construct(mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
+    const StringRef ownedName = allocator.copyInto(key.first);
+    const ArrayRef<unsigned int> ownedDims = allocator.copyInto(key.second);
+    return new (allocator.allocate<OwnTypeStorage>()) OwnTypeStorage(ownedName, ownedDims);
+  }
+
+  StringRef resName;
+  ArrayRef<unsigned int> dims;
+};
+
+/*
+ * Uniqued storage backing RefType: the type(s) of the referenced owner.
+ **/
+struct RefTypeStorage : public mlir::TypeStorage {
+  // The key is an array of types, although it usually holds one element.
+  using KeyTy = ArrayRef<mlir::Type>;
+
+  RefTypeStorage(ArrayRef<mlir::Type> owner): ownerType(owner) {}
+
+  static KeyTy getKey(ArrayRef<mlir::Type> ownerType) {
+    return KeyTy(ownerType);
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return ownerType == key;
+  }
+
+  // Copies the key's type list into the allocator before building the
+  // storage object in allocator-provided memory.
+  static RefTypeStorage *construct(mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
+    const ArrayRef<mlir::Type> ownedTypes = allocator.copyInto(key);
+    return new (allocator.allocate<RefTypeStorage>()) RefTypeStorage(ownedTypes);
+  }
+
+  // The resource type it refers to.
+  ArrayRef<mlir::Type> ownerType;
+};
+
+} //user_types
+} //obs
+} //mlir
+
+// Returns the OwnType in `ctx` for the given resource name and dimensions.
+OwnType OwnType::get(mlir::MLIRContext *ctx, StringRef resName,
+                     ArrayRef<unsigned int> dims) {
+  return Base::get(ctx, resName, dims);
+}
+
+// Resource name recorded in the type's storage.
+StringRef OwnType::getResName() { return getImpl()->resName; }
+
+// Dimensions recorded in the type's storage.
+ArrayRef<unsigned int> OwnType::getDims() { return getImpl()->dims; }
+
+// Builds the RefType in the context of the first owner type; `ownType` must be non-empty.
+RefType RefType::get(ArrayRef<mlir::Type> ownType) {
+  mlir::MLIRContext *ctx = ownType.front().getContext();
+  return Base::get(ctx, ownType);
+}
+
+ArrayRef<mlir::Type> RefType::getOwnerType() { return getImpl()->ownerType; }
+
+
+// Prints the textual form of an OBS dialect type:
+//   OwnType -> "Own( <name>, [<d0>, <d1>, ... ] )"
+//   RefType -> "Ref( <owner types> )"
+// Types that are neither are left unprinted.
+void OBSDialect::printType(::mlir::Type type,
+                           ::mlir::DialectAsmPrinter &printer) const {
+  if (auto ownType = llvm::dyn_cast<OwnType>(type)) {
+    printer << "Own( " << ownType.getResName() << ", [";
+    llvm::interleaveComma(ownType.getDims(), printer);
+    printer << " ] )";
+  } else if (auto refType = llvm::dyn_cast<RefType>(type)) {
+    // dyn_cast, not cast: cast<> asserts when `type` is not a RefType.
+    printer << "Ref( " << refType.getOwnerType() << " )";
+  }
+}
+
+// TODO: complete parseType; a null Type is returned on every path for now.
+mlir::Type OBSDialect::parseType(mlir::DialectAsmParser &parser) const {
+  if (succeeded(parser.parseKeyword("Own")))
+    (void)parser.parseLess(); // only `Own` followed by `<` is consumed so far
+  return Type();
+}
+
+
#define GET_OP_CLASSES
#include "Ops.cpp.inc"
diff --git a/obs/obs-ir/MLIRGen.cpp b/obs/obs-ir/MLIRGen.cpp
index 388d7e5b5fdb0..286bdf63cb245 100644
--- a/obs/obs-ir/MLIRGen.cpp
+++ b/obs/obs-ir/MLIRGen.cpp
@@ -21,13 +21,18 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>
#include <functional>
+#include <mlir-c/BuiltinTypes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/OwningOpRef.h>
+#include <mlir/IR/Types.h>
+#include <mlir/Support/LLVM.h>
#include <numeric>
#include <optional>
#include <vector>
@@ -150,12 +155,23 @@ class MLIRGenImpl {
}
mlir::LogicalResult mlirGen(PrintExprAST &call) {
+
+ auto arg = mlirGen(*call.getArg());
+ if (!arg) {
+ return mlir::failure();
+ }
+
+ auto ownType = mlir::obs::OwnType::get(theModule.getContext(), "vector", {1,2});
+ auto refType = mlir::obs::RefType::get({ownType});
+ builder.create<RefOp>(loc(call.loc()), refType , arg);
+
+ /*
auto arg = mlirGen(*call.getArg());
if (!arg) {
return mlir::failure();
}
- builder.create<PrintOp>(loc(call.loc()), arg);
+ builder.create<PrintOp>(loc(call.loc()), arg); */
return mlir::success();
}
@@ -248,6 +264,7 @@ class MLIRGenImpl {
mlir::Value mlirGen(VariableExprAST &expr) {
if (auto variable = symbolTable.lookup(expr.getName())) {
+ builder.create<ReadOp>(loc(expr.loc()), variable);
return variable;
}
mlir::emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'";
@@ -293,6 +310,35 @@ class MLIRGenImpl {
}
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
+
+ auto *init = vardecl.getInitVal();
+
+ if (!init) {
+ mlir::emitError(loc(vardecl.loc()), "missing initializer in variable declaration");
+ return nullptr;
+ }
+
+
+
+ mlir::Value value = mlirGen(*init);
+
+ StringRef type = "vector";
+
+ auto ownType = mlir::obs::OwnType::get(theModule.getContext(), "vector", {1,2});
+
+ std::vector<double> data = {1, 2};
+
+
+ builder.getI32VectorAttr({1, 2});
+ value = builder.create<OwnOp>(loc(vardecl.loc()), ownType , type, value);
+
+ if (failed(declare(vardecl.getName(), value))) {
+ return nullptr;
+ }
+
+ return value;
+
+/*
auto *init = vardecl.getInitVal();
if (!init) {
@@ -312,7 +358,7 @@ class MLIRGenImpl {
if (failed(declare(vardecl.getName(), value))) {
return nullptr;
}
- return value;
+ return value; */
}
};
diff --git a/obs/test/codegen.toy b/obs/test/codegen.toy
new file mode 100644
index 0000000000000..05c2d9d1b9f04
--- /dev/null
+++ b/obs/test/codegen.toy
@@ -0,0 +1,5 @@
+
+def main() {
+ var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+ print(a);
+}
diff --git a/obs/test/test1.cpp b/obs/test/test1.cpp
new file mode 100644
index 0000000000000..3d6f2a47fb915
--- /dev/null
+++ b/obs/test/test1.cpp
@@ -0,0 +1,9 @@
+
+void foo() {
+
+}
+
+int main(){
+ int x;
+ return 1;
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list