[Mlir-commits] [llvm] [mlir] Op definition (PR #93931)

Shuanglong Kan llvmlistbot at llvm.org
Fri May 31 00:04:06 PDT 2024


https://github.com/ShlKan created https://github.com/llvm/llvm-project/pull/93931

None

>From 1037c321537642a7df35d33d4a10a6ad5d806c9c Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Thu, 21 Mar 2024 15:48:40 +0100
Subject: [PATCH 1/5] feat: initial

---
 llvm/CMakeLists.txt              |  2 +-
 llvm/tools/CMakeLists.txt        |  1 +
 obs/CMakeLists.txt               | 15 +++++++++++++++
 obs/cmake/modules/AddOBS.cmake   |  5 +++++
 obs/cmake/modules/CMakeLists.txt |  0
 obs/obs-ir/CMakeLists.txt        |  5 +++++
 obs/obs-ir/test.cpp              |  5 +++++
 7 files changed, 32 insertions(+), 1 deletion(-)
 create mode 100644 obs/CMakeLists.txt
 create mode 100644 obs/cmake/modules/AddOBS.cmake
 create mode 100644 obs/cmake/modules/CMakeLists.txt
 create mode 100644 obs/obs-ir/CMakeLists.txt
 create mode 100644 obs/obs-ir/test.cpp

diff --git a/llvm/CMakeLists.txt b/llvm/CMakeLists.txt
index 6f5647d70d8bc..d3eaa7670eb15 100644
--- a/llvm/CMakeLists.txt
+++ b/llvm/CMakeLists.txt
@@ -114,7 +114,7 @@ endif()
 # LLVM_EXTERNAL_${project}_SOURCE_DIR using LLVM_ALL_PROJECTS
 # This allows an easy way of setting up a build directory for llvm and another
 # one for llvm+clang+... using the same sources.
-set(LLVM_ALL_PROJECTS "bolt;clang;clang-tools-extra;compiler-rt;cross-project-tests;libc;libclc;lld;lldb;mlir;openmp;polly;pstl")
+set(LLVM_ALL_PROJECTS "bolt;clang;clang-tools-extra;compiler-rt;cross-project-tests;libc;libclc;lld;lldb;mlir;obs;openmp;polly;pstl")
 # The flang project is not yet part of "all" projects (see C++ requirements)
 set(LLVM_EXTRA_PROJECTS "flang")
 # List of all known projects in the mono repo
diff --git a/llvm/tools/CMakeLists.txt b/llvm/tools/CMakeLists.txt
index c6116ac81d12b..ac1728854a765 100644
--- a/llvm/tools/CMakeLists.txt
+++ b/llvm/tools/CMakeLists.txt
@@ -41,6 +41,7 @@ add_llvm_external_project(clang)
 add_llvm_external_project(lld)
 add_llvm_external_project(lldb)
 add_llvm_external_project(mlir)
+add_llvm_external_project(obs)
 # Flang depends on mlir, so place it afterward
 add_llvm_external_project(flang)
 add_llvm_external_project(bolt)
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
new file mode 100644
index 0000000000000..6353f1c09cde4
--- /dev/null
+++ b/obs/CMakeLists.txt
@@ -0,0 +1,15 @@
+include(CMakeDependentOption)
+include(GNUInstallDirs)
+
+# Make sure that our source directory is on the current cmake module path so
+# that we can include cmake files from this directory.
+list(INSERT CMAKE_MODULE_PATH 0
+  "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules"
+  "${LLVM_COMMON_CMAKE_UTILS}/Modules"
+  )
+
+include(CMakeParseArguments)
+include(AddOBS)
+
+
+add_subdirectory(obs-ir)
\ No newline at end of file
diff --git a/obs/cmake/modules/AddOBS.cmake b/obs/cmake/modules/AddOBS.cmake
new file mode 100644
index 0000000000000..5c7420aad0312
--- /dev/null
+++ b/obs/cmake/modules/AddOBS.cmake
@@ -0,0 +1,5 @@
+macro(add_obs_executable name) # add_llvm_executable() that also files the target under the "OBS executables" IDE folder
+  add_llvm_executable( ${name} ${ARGN} )
+  set_target_properties(${name} PROPERTIES FOLDER "OBS executables")
+endmacro(add_obs_executable)
+
diff --git a/obs/cmake/modules/CMakeLists.txt b/obs/cmake/modules/CMakeLists.txt
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/obs/obs-ir/CMakeLists.txt b/obs/obs-ir/CMakeLists.txt
new file mode 100644
index 0000000000000..f418db1ce4b20
--- /dev/null
+++ b/obs/obs-ir/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+add_obs_executable(obstest test.cpp)
\ No newline at end of file
diff --git a/obs/obs-ir/test.cpp b/obs/obs-ir/test.cpp
new file mode 100644
index 0000000000000..faf42d5289969
--- /dev/null
+++ b/obs/obs-ir/test.cpp
@@ -0,0 +1,5 @@
+#include <iostream>
+
+int main(int argc, char* argv[]) { // Placeholder entry point of the obstest executable: prints a greeting.
+    std::cout << "Hello World OBS" << std::endl;
+}
\ No newline at end of file

>From 494786e93565587ce6a055e6e87b12effc8e7056 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Mon, 25 Mar 2024 19:27:01 +0100
Subject: [PATCH 2/5] feat: test 1

---
 obs/.clang-tidy           |  26 +++
 obs/CMakeLists.txt        |   8 +-
 obs/include/AST.h         | 227 +++++++++++++++++++
 obs/include/Lexer.h       | 184 ++++++++++++++++
 obs/include/Parser.h      | 452 ++++++++++++++++++++++++++++++++++++++
 obs/obs-ir/CMakeLists.txt |   5 -
 obs/obs-ir/test.cpp       |  56 ++++-
 obs/parser/AST.cpp        | 222 +++++++++++++++++++
 8 files changed, 1170 insertions(+), 10 deletions(-)
 create mode 100644 obs/.clang-tidy
 create mode 100644 obs/include/AST.h
 create mode 100644 obs/include/Lexer.h
 create mode 100644 obs/include/Parser.h
 delete mode 100644 obs/obs-ir/CMakeLists.txt
 create mode 100644 obs/parser/AST.cpp

diff --git a/obs/.clang-tidy b/obs/.clang-tidy
new file mode 100644
index 0000000000000..9cece0de812b8
--- /dev/null
+++ b/obs/.clang-tidy
@@ -0,0 +1,26 @@
+Checks: '-*,clang-diagnostic-*,llvm-*,misc-*,-misc-const-correctness,-misc-unused-parameters,-misc-non-private-member-variables-in-classes,-misc-no-recursion,-misc-use-anonymous-namespace,readability-identifier-naming,-misc-include-cleaner'
+CheckOptions:
+  - key:             readability-identifier-naming.ClassCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.EnumCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.FunctionCase
+    value:           camelBack
+  # Exclude from scanning as this is an exported symbol used for fuzzing
+  # throughout the code base.
+  - key:             readability-identifier-naming.FunctionIgnoredRegexp
+    value:           "LLVMFuzzerTestOneInput"
+  - key:             readability-identifier-naming.MemberCase
+    value:           camelBack # obs/ uses camelBack members (curTok, lastChar, ...)
+  - key:             readability-identifier-naming.ParameterCase
+    value:           camelBack # obs/ uses camelBack parameters (initVal, exprPrec, ...)
+  - key:             readability-identifier-naming.UnionCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.VariableCase
+    value:           camelBack # obs/ uses camelBack locals (numStr, firstDims, ...)
+  - key:             readability-identifier-naming.IgnoreMainLikeFunctions
+    value:           1
+  - key:             readability-redundant-member-init.IgnoreBaseInCopyConstructors
+    value:           1
+  - key:             modernize-use-default-member-init.UseAssignment
+    value:           1
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index 6353f1c09cde4..af775cf2ebb5c 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+  
 include(CMakeDependentOption)
 include(GNUInstallDirs)
 
@@ -11,5 +15,5 @@ list(INSERT CMAKE_MODULE_PATH 0
 include(CMakeParseArguments)
 include(AddOBS)
 
-
-add_subdirectory(obs-ir)
\ No newline at end of file
+include_directories(include/)
+add_obs_executable(obstest obs-ir/test.cpp parser/AST.cpp)
\ No newline at end of file
diff --git a/obs/include/AST.h b/obs/include/AST.h
new file mode 100644
index 0000000000000..46d3f876143b2
--- /dev/null
+++ b/obs/include/AST.h
@@ -0,0 +1,227 @@
+#ifndef OBS_AST_H
+#define OBS_AST_H
+
+#include "Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <utility>
+#include <vector>
+#include <optional>
+
+namespace obs {
+
+struct VarType { // Declared shape of a variable: one extent per dimension, as parsed from <d0, d1, ...>.
+    std::vector<int64_t> shape ;
+};
+
+class ExprAST { // Base class of all expression nodes; subclasses are told apart via getKind() (LLVM RTTI).
+public:
+    enum ExprASTKind { // discriminator checked by each subclass's classof()
+        Expr_VarDecl,
+        Expr_Return,
+        Expr_Num,
+        Expr_Literal,
+        Expr_Var,
+        Expr_BinOp,
+        Expr_Call,
+        Expr_Print,
+    };
+    ExprAST(ExprASTKind kind, Location location): kind(kind), location(std::move(location)) {}
+    /// The concrete node kind, fixed at construction.
+    ExprASTKind getKind() const { return kind; }
+    /// Source location this node was created with.
+    const Location &loc() const {
+        return location;
+    }
+    // Virtual: nodes are owned and deleted through std::unique_ptr<ExprAST>.
+    virtual ~ExprAST() = default;
+
+private:
+    const ExprASTKind kind;
+    Location location;
+};
+
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>; // owning list of expressions, e.g. a function body
+
+class NumberExprAST: public ExprAST { // Numeric literal; the value is stored as a double.
+    double val;
+public:
+    NumberExprAST(Location loc, double val): ExprAST(Expr_Num, loc), val(val) { }
+    double getValue() {
+        return val;
+    }
+    static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } // LLVM-style RTTI
+};
+
+class LiteralExprAST : public ExprAST { // Tensor literal such as [[1, 2], [3, 4]].
+    std::vector<std::unique_ptr<ExprAST>> values; // elements at this nesting level (numbers or nested literals)
+    std::vector<int64_t> dims; // shape, outermost extent first (as filled in by the parser)
+public:
+    LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, std::vector<int64_t> dims ):
+    ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) { }
+
+    llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() {
+        return values;
+    }
+
+    llvm::ArrayRef<int64_t> getDims() {
+        return dims;
+    }
+
+    static bool classof(const ExprAST * c) { // LLVM-style RTTI
+        return c->getKind() == Expr_Literal;
+    }
+
+};
+
+class VariableExprAST : public ExprAST { // Reference to a variable by name (also used for PrototypeAST parameters).
+    std::string name;
+
+public:
+    VariableExprAST(Location loc, llvm::StringRef name): ExprAST(Expr_Var, std::move(loc)), name(name) { }
+    llvm::StringRef getName() {
+        return name;
+    }
+
+    static bool classof(const ExprAST *c) { // LLVM-style RTTI
+        return c->getKind() == Expr_Var;
+    }
+};
+
+class VarDeclExprAST : public ExprAST { // `var` declaration: name, declared type, and initializer expression.
+    std::string name;
+    VarType type;
+    std::unique_ptr<ExprAST> initVal;
+
+public:
+    VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST>initVal ):
+    ExprAST(Expr_VarDecl, loc), name(name), type(type), initVal(std::move(initVal)) {}
+
+    llvm::StringRef getName() { return name; }
+    VarType getType() { return type; } // returns a copy of the declared type
+    ExprAST * getInitVal() { return initVal.get(); } // non-owning; this node keeps ownership
+
+    static bool classof(const ExprAST * c) { // LLVM-style RTTI
+        return c -> getKind() == Expr_VarDecl;
+    }
+};
+
+class ReturnExprAST : public ExprAST { // `return` statement with an optional operand.
+    std::optional<std::unique_ptr<ExprAST>> expr;
+public:
+    ReturnExprAST(Location loc, std::optional<std::unique_ptr<ExprAST>> expr): 
+        ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+    /// The returned expression (non-owning), or std::nullopt when constructed without one.
+    std::optional<ExprAST *> getExpr() {
+        if (expr.has_value()) {
+            return expr->get();
+        }
+        return std::nullopt;
+    }
+
+    static bool classof(const ExprAST * c) { // LLVM-style RTTI
+        return c->getKind() == Expr_Return;
+    }
+};
+
+class BinaryExprAST : public ExprAST { // Binary operation `lhs op rhs`.
+    char op; // operator character
+    std::unique_ptr<ExprAST> lhs, rhs;
+public:
+    char getOp() { return op; }
+    ExprAST * getLHS() { return lhs.get(); } // non-owning
+    ExprAST * getRHS() { return rhs.get(); } // non-owning
+    BinaryExprAST(Location location, char op, std::unique_ptr<ExprAST> lhs, std::unique_ptr<ExprAST> rhs):
+    ExprAST(Expr_BinOp, location), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {}
+
+    static bool classof(const ExprAST * c) { // LLVM-style RTTI
+        return c->getKind() == Expr_BinOp;
+    }
+};
+
+class CallExprAST : public ExprAST { // Call to a user-defined function: callee name plus argument expressions.
+    std::string callee;
+    std::vector<std::unique_ptr<ExprAST>> args;
+
+public:
+    CallExprAST(Location loc, std::string callee, std::vector<std::unique_ptr<ExprAST>> args):
+    ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
+
+    llvm::StringRef getCallee() {
+        return callee;
+    }
+    llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() {
+        return args;
+    }
+
+    static bool classof(const ExprAST *c) { // LLVM-style RTTI
+        return c->getKind() == Expr_Call;
+    }
+
+};
+
+/// Expression class for the builtin print() call; holds its single argument.
+class PrintExprAST : public ExprAST {
+  std::unique_ptr<ExprAST> arg;
+
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+      : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {}
+
+  ExprAST *getArg() { return arg.get(); }
+
+  /// LLVM-style RTTI.
+  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
+};
+
+class PrototypeAST { // Function prototype: source location, name, and parameter list.
+  Location location;
+  std::string name;
+  std::vector<std::unique_ptr<VariableExprAST>> args; // parameters, each a named variable
+
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : location(std::move(location)), name(name), args(std::move(args)) {}
+
+  const Location &loc() { return location; }
+  llvm::StringRef getName() const { return name; }
+  llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
+};
+
+/// This class represents a function definition: its prototype plus an owned list of body expressions.
+class FunctionAST {
+  std::unique_ptr<PrototypeAST> proto;
+  std::unique_ptr<ExprASTList> body;
+
+public:
+  FunctionAST(std::unique_ptr<PrototypeAST> proto,
+              std::unique_ptr<ExprASTList> body)
+      : proto(std::move(proto)), body(std::move(body)) {}
+  PrototypeAST *getProto() { return proto.get(); }
+  ExprASTList *getBody() { return body.get(); }
+};
+
+/// This class represents a list of functions to be processed together.
+class ModuleAST {
+  std::vector<FunctionAST> functions;
+
+public:
+  ModuleAST(std::vector<FunctionAST> functions)
+      : functions(std::move(functions)) {}
+  // Iterate over the functions in the order they were given to the constructor.
+  auto begin() { return functions.begin(); }
+  auto end() { return functions.end(); }
+};
+
+void dump(ModuleAST &);
+
+
+}
+
+
+#endif //OBS_AST_H
+
+
diff --git a/obs/include/Lexer.h b/obs/include/Lexer.h
new file mode 100644
index 0000000000000..c2b1c8cf54a0f
--- /dev/null
+++ b/obs/include/Lexer.h
@@ -0,0 +1,184 @@
+#ifndef OBS_LEXER_H
+#define OBS_LEXER_H
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace obs {
+
+  struct Location { // Position of a token in a source file.
+    std::shared_ptr<std::string> file; // file name; one string shared by all locations of a lexer
+    int line = 0; // line number (defaulted so a default-constructed Location is well-defined)
+    int col = 0;  // column number within the line
+  };
+  
+  enum Token : int { // single-character tokens use their character code; other kinds are negative
+    tok_semicolon = ';',
+    tok_parenthese_open = '(',
+    tok_parenthese_close = ')',
+    tok_bracket_open = '{',
+    tok_bracket_close = '}',
+    tok_sbracket_open = '[',
+    tok_sbracket_close = ']',
+    tok_eof = - 1,
+    tok_return = -2,
+    tok_var = -3,
+    tok_def = -4,
+    // primary expressions
+    tok_identifier = -5,
+    tok_number = -6,
+  };
+/// Base lexer: turns text, supplied in chunks by readNextLine(), into Tokens.
+class Lexer {
+public:
+  Lexer(std::string filename) : lastLocation({std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+  virtual ~Lexer() = default;
+  /// The current token, i.e. the one last produced by getNextToken().
+  Token getCurToken() {
+    return curTok;
+  }
+  /// Lex the next token, make it current, and return it.
+  Token getNextToken() {
+    return curTok = getTok();
+  }
+  /// Move past the current token, which must be `tok` (checked by assert).
+  void consume(Token tok) {
+    assert(tok == curTok && "consume Token mismatch expectation");
+    getNextToken();
+  }
+  /// Spelling of the current token; it must be tok_identifier.
+  llvm::StringRef getId() {
+    assert(curTok == tok_identifier);
+    return identifierStr;
+  }
+  /// Numeric value of the current token; it must be tok_number.
+  double getValue() {
+    assert(curTok == tok_number);
+    return numVal;
+  }
+
+  /// Lex one token from the input and record where it starts in lastLocation.
+  Token getTok() {
+    while (isspace(lastChar)) // skip whitespace between tokens
+      lastChar = Token(getNextChar());
+    lastLocation.line = curLineNum; // remember where this token starts so that
+    lastLocation.col = curCol;      // getLastLocation() can report it
+    if (isalpha(lastChar)) { // identifier or keyword
+      identifierStr = static_cast<char>(lastChar);
+      while (isalnum(lastChar = Token(getNextChar())) || lastChar == '_') {
+        identifierStr += static_cast<char>(lastChar);
+      }
+      if (identifierStr == "return") {
+        return tok_return;
+      }
+      if (identifierStr == "def") {
+        return tok_def;
+      }
+      if (identifierStr == "var") {
+        return tok_var;
+      }
+      return tok_identifier;
+    }
+
+    // Number: [0-9.]+
+    if (isdigit(lastChar) || lastChar == '.') {
+      std::string numStr;
+      do {
+        numStr += static_cast<char>(lastChar);
+        lastChar = Token(getNextChar());
+      } while (isdigit(lastChar) || lastChar == '.');
+
+      numVal = strtod(numStr.c_str(), nullptr);
+      return tok_number;
+    }
+
+    if (lastChar == '#') { // comment: skip to the end of the line
+      do {
+        lastChar = Token(getNextChar());
+      } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
+
+      if (lastChar != EOF) {
+        return getTok();
+      }
+    }
+    if (lastChar == EOF) {
+      return tok_eof;
+    }
+
+    Token thisChar = Token(lastChar); // any other character is a token by itself
+    lastChar = Token(getNextChar());
+    return thisChar;
+  }
+  /// Start location of the most recently lexed token.
+  Location getLastLocation() {
+    return lastLocation;
+  }
+  /// Line number of the read position.
+  int getLine() {
+    return curLineNum;
+  }
+  /// Column of the read position within the current line.
+  int getCol() {
+    return curCol;
+  }
+
+
+
+private:
+  /// Supply the next chunk of input; an empty result marks the end of input.
+  virtual llvm::StringRef readNextLine() = 0;
+  /// Return the next input character (0-255) or EOF, updating line/column counters.
+  int getNextChar() {
+    if (curLineBuffer.empty()) {
+      return EOF;
+    }
+    ++curCol;
+    unsigned char nextChar = curLineBuffer.front(); // unsigned: a 0xFF byte must not read as EOF
+    curLineBuffer = curLineBuffer.drop_front(); // drop_front() returns a new StringRef
+    if (curLineBuffer.empty()) {
+      curLineBuffer = readNextLine();
+    }
+    if (nextChar == '\n') {
+      ++curLineNum;
+      curCol = 0;
+    }
+    return nextChar;
+  }
+  // Private member variables.
+  Location lastLocation;
+  Token curTok = tok_eof;
+  Token lastChar = Token(' '); // one character of lookahead
+  llvm::StringRef curLineBuffer = "\n"; // primed so the first read triggers readNextLine()
+  int curCol = 0;
+  int curLineNum = 0;
+
+  std::string identifierStr;
+  double numVal = 0;
+};
+
+class LexerBuffer final : public Lexer { // Lexer over the character range [begin, end).
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename) : Lexer(std::move(filename)), current(begin), end(end) {}
+private:
+  const char *current, *end; // read cursor and one-past-the-end of the buffer
+
+  /// Return the next line including its '\n'; empty once `end` or a NUL character is reached.
+  llvm::StringRef readNextLine() override {
+    auto *begin = current;
+    // `<`, not `<=`: `end` is one past the last character and must not be dereferenced.
+    while (current < end && *current && *current != '\n')
+      ++current;
+    if (current < end && *current)
+      ++current; // include the '\n' terminator in the returned line
+    llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+    return result;
+  }
+};
+
+}
+
+
+
+#endif //OBS_LEXER_H
\ No newline at end of file
diff --git a/obs/include/Parser.h b/obs/include/Parser.h
new file mode 100644
index 0000000000000..e9e619774c7fb
--- /dev/null
+++ b/obs/include/Parser.h
@@ -0,0 +1,452 @@
+#ifndef OBS_PARSER_H
+#define OBS_PARSER_H
+
+#include "AST.h"
+#include "Lexer.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+#include <optional>
+
+namespace obs {
+                                                                                                                                                        
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.                                                                                                                         
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.                                                                                                
+  std::unique_ptr<ModuleAST> parseModule() {
+    lexer.getNextToken(); // prime the lexer      
+        // Parse functions one at a time and accumulate in this vector.                                                                                                   
+    std::vector<FunctionAST> functions;
+    while (auto f = parseDefinition()) {
+      functions.push_back(std::move(*f));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+    // If we didn't reach EOF, there was an error during parsing                                                                                                      
+    if (lexer.getCurToken() != tok_eof)
+      return parseError<ModuleAST>("nothing", "at end of module");
+
+    return std::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+  Lexer &lexer;
+
+  /// Parse a return statement.                                                                                                                                       
+  /// return :== return ; | return expr ;                                                                                                                             
+  std::unique_ptr<ReturnExprAST> parseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // `return` takes an optional argument; it is absent when the next token is ';'.
+    std::optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      expr = parseExpression();
+      if (!*expr) // test the pointer: the optional itself is always engaged here
+        return nullptr;
+    }
+    return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+/// Parse a literal number.                                                                                                                                         
+  /// numberexpr ::= number                                                                                                                                           
+  std::unique_ptr<ExprAST> parseNumberExpr() {
+    auto loc = lexer.getLastLocation();
+    auto result =
+        std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
+    lexer.consume(tok_number);
+    return std::move(result);
+  }
+
+  /// Parse a literal array expression.                                                                                                                               
+  /// tensorLiteral ::= [ literalList ] | number                                                                                                                      
+  /// literalList ::= tensorLiteral | tensorLiteral, literalList                                                                                                      
+  std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(Token('['));
+
+    // Hold the list of values at this nesting level.
+    std::vector<std::unique_ptr<ExprAST>> values;
+    // Hold the dimensions for all the nesting inside this level.
+    std::vector<int64_t> dims;
+    do {
+      // We can have either another nested array or a number literal.
+      if (lexer.getCurToken() == '[') {
+        values.push_back(parseTensorLiteralExpr());
+        if (!values.back())
+          return nullptr; // parse error in the nested array.
+      } else {
+        if (lexer.getCurToken() != tok_number)
+          return parseError<ExprAST>("<num> or [", "in literal expression");
+        values.push_back(parseNumberExpr());
+      }
+
+      // End of this list on ']'.
+      if (lexer.getCurToken() == ']')
+        break;
+      // Elements are separated by a comma.
+      if (lexer.getCurToken() != ',')
+        return parseError<ExprAST>("] or ,", "in literal expression");
+
+      lexer.getNextToken(); // eat ,
+    } while (true);
+    if (values.empty())
+      return parseError<ExprAST>("<something>", "to fill literal expression");
+    lexer.getNextToken(); // eat ]
+
+    // Fill in the dimensions now. First the current nesting level:
+    dims.push_back(values.size());
+
+    // If there is any nested array, process all of them and ensure that
+    // dimensions are uniform.
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expression");
+
+      // Append the nested dimensions to the current level.
+      auto firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Sanity check that shape is uniform across all elements of the list.
+      // dyn_cast (not cast): a mix such as [[1], 2] must be reported as a
+      // parse error instead of tripping cast<>'s assertion.
+      for (auto &expr : values) {
+        auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral || exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+      }
+    }
+    return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                            std::move(dims));
+  }
+
+  /// parenexpr ::= '(' expression ')'                                                                                                                                
+  std::unique_ptr<ExprAST> parseParenExpr() {
+    lexer.getNextToken(); // eat (.                                                                                                                                   
+    auto v = parseExpression();
+    if (!v)
+      return nullptr;
+
+    if (lexer.getCurToken() != ')')
+      return parseError<ExprAST>(")", "to close expression with parentheses");
+    lexer.consume(Token(')'));
+    return v;
+  }
+  /// identifierexpr                                                                                                                                                  
+  ///   ::= identifier                                                                                                                                                
+  ///   ::= identifier '(' expression ')'                                                                                                                             
+  std::unique_ptr<ExprAST> parseIdentifierExpr() {
+    std::string name(lexer.getId());
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.                                                                                                                          
+
+    if (lexer.getCurToken() != '(') // Simple variable ref.                                                                                                           
+      return std::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // This is a function call.                                                                                                                                       
+    lexer.consume(Token('('));
+    std::vector<std::unique_ptr<ExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      while (true) {
+        if (auto arg = parseExpression())
+          args.push_back(std::move(arg));
+        else
+          return nullptr;
+        if (lexer.getCurToken() == ')')
+          break;
+
+        if (lexer.getCurToken() != ',')
+          return parseError<ExprAST>(", or )", "in argument list");
+        lexer.getNextToken();
+      }
+    }
+    lexer.consume(Token(')'));
+
+    // It can be a builtin call to print                                                                                                                              
+    if (name == "print") {
+      if (args.size() != 1)
+        return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+      return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
+    }
+
+    // Call to a user-defined function                                                                                                                                
+    return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
+  }
+  /// primary                                                                                                                                                         
+  ///   ::= identifierexpr                                                                                                                                            
+  ///   ::= numberexpr                                                                                                                                                
+  ///   ::= parenexpr                                                                                                                                                 
+  ///   ::= tensorliteral                                                                                                                                             
+  std::unique_ptr<ExprAST> parsePrimary() {
+    switch (lexer.getCurToken()) {
+    default:
+      llvm::errs() << "unknown token '" << lexer.getCurToken()
+                   << "' when expecting an expression\n";
+      return nullptr;
+    case tok_identifier:
+      return parseIdentifierExpr();
+    case tok_number:
+      return parseNumberExpr();
+    case '(':
+      return parseParenExpr();
+    case '[':
+      return parseTensorLiteralExpr();
+    case ';':
+      return nullptr;
+    case '}':
+      return nullptr;
+    }
+  }
+
+  /// binoprhs ::= (binop primary)*
+  /// Precedence climbing: repeatedly folds `lhs binop primary` into a
+  /// BinaryExprAST while the pending operator binds at least as tightly as
+  /// `exprPrec`. Returns nullptr if an operand fails to parse.
+  std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
+                                         std::unique_ptr<ExprAST> lhs) {
+    for (;;) {
+      const int opPrec = getTokPrecedence();
+
+      // Not a binary operator (precedence -1), or one too weak for this
+      // level: hand what we have back to the caller.
+      if (opPrec < exprPrec)
+        return lhs;
+
+      const int opTok = lexer.getCurToken();
+      lexer.consume(Token(opTok));
+      auto opLoc = lexer.getLastLocation();
+
+      auto rhsExpr = parsePrimary();
+      if (!rhsExpr)
+        return parseError<ExprAST>("expression", "to complete binary operator");
+
+      // When the operator after rhs binds tighter than this one, it takes rhs
+      // as its own lhs before we combine.
+      if (getTokPrecedence() > opPrec) {
+        rhsExpr = parseBinOpRHS(opPrec + 1, std::move(rhsExpr));
+        if (!rhsExpr)
+          return nullptr;
+      }
+
+      lhs = std::make_unique<BinaryExprAST>(std::move(opLoc), opTok,
+                                            std::move(lhs), std::move(rhsExpr));
+    }
+  }
+
+  /// expression ::= primary binop rhs
+  /// Parses one primary expression followed by any trailing binary operators.
+  /// Returns nullptr if no primary expression could be parsed.
+  std::unique_ptr<ExprAST> parseExpression() {
+    if (auto lhs = parsePrimary())
+      return parseBinOpRHS(0, std::move(lhs));
+    return nullptr;
+  }
+
+  /// type ::= < shape_list >
+  /// shape_list ::= num | num , shape_list
+  /// Parses a shape annotation such as `<2, 3>` into a VarType. The loop is
+  /// lenient: a comma after each dimension is accepted but not required.
+  std::unique_ptr<VarType> parseType() {
+    if (lexer.getCurToken() != '<')
+      return parseError<VarType>("<", "to begin type");
+    lexer.getNextToken(); // eat <
+
+    auto result = std::make_unique<VarType>();
+    while (lexer.getCurToken() == tok_number) {
+      result->shape.push_back(lexer.getValue());
+      lexer.getNextToken(); // eat the dimension
+      // The separating comma is optional.
+      if (lexer.getCurToken() == ',')
+        lexer.getNextToken();
+    }
+
+    if (lexer.getCurToken() == '>') {
+      lexer.getNextToken(); // eat >
+      return result;
+    }
+    return parseError<VarType>(">", "to end type");
+  }
+
+  /// decl ::= var identifier [ type ] = expr
+  /// Parses a variable declaration. The type annotation is optional; when it
+  /// is absent an empty VarType is attached so it can be inferred later.
+  /// Returns nullptr (after reporting) on any syntax error.
+  std::unique_ptr<VarDeclExprAST> parseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identifier",
+                                        "after 'var' declaration");
+    std::string id(lexer.getId());
+    lexer.getNextToken(); // eat id
+
+    // Type is optional, it can be inferred.
+    std::unique_ptr<VarType> type;
+    if (lexer.getCurToken() == '<') {
+      type = parseType();
+      if (!type)
+        return nullptr;
+    } else {
+      type = std::make_unique<VarType>();
+    }
+
+    // Report a missing '=' instead of handing an unexpected token to consume().
+    if (lexer.getCurToken() != '=')
+      return parseError<VarDeclExprAST>("=", "in variable declaration");
+    lexer.consume(Token('='));
+
+    auto expr = parseExpression();
+    // Never build a declaration with a null initializer: the AST dumper
+    // recurses into the initializer unconditionally.
+    if (!expr)
+      return nullptr;
+    return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                            std::move(*type), std::move(expr));
+  }
+
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  /// Parses a brace-delimited list of statements separated by ';'. Runs of
+  /// empty statements (repeated semicolons) are ignored.
+  std::unique_ptr<ExprASTList> parseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    // Swallow a run of ';' so empty statements are ignored.
+    auto skipSemicolons = [this] {
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    };
+
+    auto stmts = std::make_unique<ExprASTList>();
+    skipSemicolons();
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      std::unique_ptr<ExprAST> stmt;
+      switch (lexer.getCurToken()) {
+      case tok_var: // variable declaration
+        stmt = parseDeclaration();
+        break;
+      case tok_return: // return statement
+        stmt = parseReturn();
+        break;
+      default: // general expression
+        stmt = parseExpression();
+        break;
+      }
+      if (!stmt)
+        return nullptr;
+      stmts->push_back(std::move(stmt));
+
+      // Statements must be separated by at least one semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+      skipSemicolons();
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+    lexer.consume(Token('}'));
+    return stmts;
+  }
+
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  /// Parses a function signature. Parameters are bare identifiers; each one
+  /// becomes a VariableExprAST carrying its own source location.
+  std::unique_ptr<PrototypeAST> parsePrototype() {
+    auto loc = lexer.getLastLocation();
+
+    if (lexer.getCurToken() != tok_def)
+      return parseError<PrototypeAST>("def", "in prototype");
+    lexer.consume(tok_def);
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string fnName(lexer.getId());
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      // The loop only re-checks for an identifier after each ','. Check the
+      // first parameter here so input like `def f(1)` is reported as a parse
+      // error instead of being consumed as a parameter name.
+      if (lexer.getCurToken() != tok_identifier)
+        return parseError<PrototypeAST>("identifier",
+                                        "in function parameter list");
+      do {
+        std::string name(lexer.getId());
+        auto argLoc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        args.push_back(
+            std::make_unique<VariableExprAST>(std::move(argLoc), name));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>(")", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return std::make_unique<PrototypeAST>(std::move(loc), fnName,
+                                          std::move(args));
+  }
+
+  /// definition ::= prototype block
+  /// Parses a function definition: a prototype introduced by `def`, followed
+  /// by a block holding the body. Returns nullptr if either part fails.
+  std::unique_ptr<FunctionAST> parseDefinition() {
+    auto proto = parsePrototype();
+    if (!proto)
+      return nullptr;
+
+    auto body = parseBlock();
+    if (!body)
+      return nullptr;
+    return std::make_unique<FunctionAST>(std::move(proto), std::move(body));
+  }
+
+  /// Precedence of the pending binary operator token, or -1 when the current
+  /// token is not a binary operator. A larger value binds more tightly.
+  int getTokPrecedence() {
+    const auto tok = lexer.getCurToken();
+    if (!isascii(tok))
+      return -1;
+
+    const char op = static_cast<char>(tok);
+    // Additive operators share one level; '*' binds tighter.
+    if (op == '+' || op == '-')
+      return 20;
+    if (op == '*')
+      return 40;
+    return -1;
+  }
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  /// Always returns nullptr so callers can write `return parseError<T>(...)`.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    // Named tokens (tok_eof, tok_var, ...) need not be character codes; in the
+    // MLIR Toy lexer they are negative. isprint() is only defined for
+    // unsigned char values and EOF, so check the range before calling it.
+    if (curToken >= 0 && curToken <= 255 && isprint(curToken))
+      llvm::errs() << " '" << static_cast<char>(curToken) << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+}
+
+
+
+
+#endif //OBS_PARSER_H
\ No newline at end of file
diff --git a/obs/obs-ir/CMakeLists.txt b/obs/obs-ir/CMakeLists.txt
deleted file mode 100644
index f418db1ce4b20..0000000000000
--- a/obs/obs-ir/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-set(LLVM_LINK_COMPONENTS
-  Support
-  )
-
-add_obs_executable(obstest test.cpp)
\ No newline at end of file
diff --git a/obs/obs-ir/test.cpp b/obs/obs-ir/test.cpp
index faf42d5289969..829797eaf4cf5 100644
--- a/obs/obs-ir/test.cpp
+++ b/obs/obs-ir/test.cpp
@@ -1,5 +1,55 @@
+#include "AST.h"
+#include "Lexer.h"
+#include "Parser.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+#include <string>
+#include <system_error>
 #include <iostream>
 
-int main(int argc, char* argv[]) { 
-    std::cout << "Hello World OBS" << std::endl;
-}
\ No newline at end of file
+using namespace obs;
+namespace cl = llvm::cl;
+
+/// Positional input path; "-" (the default) makes the driver read stdin.
+static cl::opt<std::string> inputFilename(cl::Positional,
+                                          cl::desc("<input toy file>"),
+                                          cl::init("-"),
+                                          cl::value_desc("filename"));
+
+namespace {
+/// What the driver should emit once the input has been parsed.
+enum Action { None, DumpAST };
+} // namespace
+
+/// `-emit=<action>` selects the output; without it, main only reports that no
+/// action was specified.
+static cl::opt<enum Action>
+    emitAction("emit", cl::desc("Select the kind of output desired"),
+               cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
+
+/// Reads `filename` ("-" means stdin) and parses it into a ModuleAST.
+/// Returns nullptr if the file cannot be opened; otherwise returns whatever
+/// Parser::parseModule produces (presumably nullptr on a parse error).
+std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
+  auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (!fileOrErr) {
+    llvm::errs() << "Could not open input file: "
+                 << fileOrErr.getError().message() << "\n";
+    return nullptr;
+  }
+
+  // LexerBuffer is handed raw begin/end pointers into the buffer owned by
+  // fileOrErr, so fileOrErr must stay alive until parsing has finished.
+  llvm::StringRef contents = (*fileOrErr)->getBuffer();
+  LexerBuffer lexer(contents.begin(), contents.end(), std::string(filename));
+  Parser parser(lexer);
+  return parser.parseModule();
+}
+
+/// Driver entry point: parse the input file, then perform the action chosen
+/// with `-emit`. Returns 1 if the input could not be read or parsed.
+int main(int argc, char **argv) {
+  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+  auto moduleAST = parseInputFile(inputFilename);
+  // Errors are diagnosed while reading/parsing; bail out instead of
+  // dereferencing a null module below.
+  if (!moduleAST)
+    return 1;
+
+  switch (emitAction) {
+  case Action::DumpAST:
+    dump(*moduleAST);
+    return 0;
+  default:
+    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+  }
+  return 0;
+}
+
diff --git a/obs/parser/AST.cpp b/obs/parser/AST.cpp
new file mode 100644
index 0000000000000..bb4f245b97903
--- /dev/null
+++ b/obs/parser/AST.cpp
@@ -0,0 +1,222 @@
+#include "AST.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+
+namespace obs {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST: the constructor bumps `level`, the destructor restores it.
+struct Indent {
+  explicit Indent(int &level) : level(level) { ++level; }
+  ~Indent() { --level; }
+  // A copy would decrement the shared level a second time when destroyed.
+  Indent(const Indent &) = delete;
+  Indent &operator=(const Indent &) = delete;
+  int &level;
+};
+
+/// Walks the AST and prints every node to llvm::errs(), one node per line,
+/// indented by nesting depth. The only state is the current indentation level.
+class ASTDumper {
+public:
+  /// Entry point: print a whole module.
+  void dump(ModuleAST *node);
+
+private:
+  // Top-level structure.
+  void dump(FunctionAST *node);
+  void dump(PrototypeAST *node);
+  void dump(ExprASTList *exprList);
+  void dump(const VarType &type);
+
+  // Expressions; dump(ExprAST *) dispatches to the concrete overloads.
+  void dump(ExprAST *expr);
+  void dump(VarDeclExprAST *varDecl);
+  void dump(NumberExprAST *num);
+  void dump(LiteralExprAST *node);
+  void dump(VariableExprAST *node);
+  void dump(ReturnExprAST *node);
+  void dump(BinaryExprAST *node);
+  void dump(CallExprAST *node);
+  void dump(PrintExprAST *node);
+
+  /// Emit two spaces for each level of the current indentation.
+  void indent() {
+    for (int remaining = curIndent; remaining > 0; --remaining)
+      llvm::errs() << "  ";
+  }
+
+  /// Nesting depth, adjusted through the Indent RAII helper.
+  int curIndent = 0;
+};
+
+/// Format a node's source location as "@file:line:col".
+template <typename T>
+static std::string loc(T *node) {
+  const auto &where = node->loc();
+  std::string text;
+  llvm::raw_string_ostream os(text);
+  os << '@' << *where.file << ':' << where.line << ':' << where.col;
+  return os.str();
+}
+
+// Helper Macro to bump the indentation level and print the leading spaces for
+// the current indentations. It deliberately expands to a plain declaration in
+// the caller's scope (so it must not be wrapped in do/while): the Indent object
+// `level_` lives until that scope ends and then restores the level.
+#define INDENT()                                                               \
+  Indent level_(curIndent);                                                    \
+  indent();
+
+/// Dispatch a generic expression to the overload for its concrete subclass
+/// (LLVM-style RTTI via TypeSwitch); unknown kinds print a placeholder line.
+void ASTDumper::dump(ExprAST *expr) {
+  auto dispatch = [&](auto *node) { this->dump(node); };
+  auto fallback = [&](ExprAST *) {
+    INDENT();
+    llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+  };
+
+  llvm::TypeSwitch<ExprAST *>(expr)
+      .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
+            PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
+          dispatch)
+      .Default(fallback);
+}
+
+/// Print "VarDecl <name><shape> @loc", then the initializer one level deeper.
+void ASTDumper::dump(VarDeclExprAST *decl) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "VarDecl " << decl->getName();
+  dump(decl->getType());
+  os << " " << loc(decl) << "\n";
+  dump(decl->getInitVal());
+}
+
+/// A block: print "Block {", each contained expression one level deeper, then
+/// "} // Block" at the block's own indentation.
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "Block {\n";
+  for (const auto &stmt : *exprList)
+    dump(stmt.get());
+  indent();
+  os << "} // Block\n";
+}
+
+/// A number literal: print its value followed by its location.
+void ASTDumper::dump(NumberExprAST *number) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << number->getValue() << " " << loc(number) << "\n";
+}
+
+/// Recursively print a literal, spelling out its dimensions at every nesting
+/// level as "<dims>[ elt, elt]". For example `[[1, 2], [3, 4]]` comes out as
+/// `<2, 2>[ <2>[ 1, 2], <2>[ 3, 4]]` (modulo how numbers are formatted).
+void printLitHelper(ExprAST *litOrNum) {
+  llvm::raw_ostream &os = llvm::errs();
+
+  // Leaf: a plain number.
+  if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
+    os << num->getValue();
+    return;
+  }
+
+  // Otherwise a nested literal: dimensions first, then each element.
+  auto *lit = llvm::cast<LiteralExprAST>(litOrNum);
+  os << "<";
+  llvm::interleaveComma(lit->getDims(), os);
+  os << ">[ ";
+  llvm::interleaveComma(lit->getValues(), os,
+                        [](auto &elt) { printLitHelper(elt.get()); });
+  os << "]";
+}
+
+/// Print "Literal: " and the literal's contents (via printLitHelper), then
+/// its location.
+void ASTDumper::dump(LiteralExprAST *literal) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "Literal: ";
+  printLitHelper(literal);
+  os << " " << loc(literal) << "\n";
+}
+
+/// A variable reference: print "var: <name>" and its location.
+void ASTDumper::dump(VariableExprAST *var) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "var: " << var->getName() << " " << loc(var) << "\n";
+}
+
+/// Return statement: print "Return", then either the returned expression or
+/// "(void)" one level deeper.
+void ASTDumper::dump(ReturnExprAST *ret) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  auto value = ret->getExpr();
+  if (!value.has_value()) {
+    INDENT();
+    llvm::errs() << "(void)\n";
+    return;
+  }
+  dump(*value);
+}
+
+/// A binary operation: print the operator and location, then both operands
+/// (left first) one level deeper.
+void ASTDumper::dump(BinaryExprAST *binOp) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "BinOp: " << binOp->getOp() << " " << loc(binOp) << "\n";
+  dump(binOp->getLHS());
+  dump(binOp->getRHS());
+}
+
+/// A call: print "Call '<callee>' [ @loc", each argument one level deeper,
+/// then a closing "]" at the call's indentation.
+void ASTDumper::dump(CallExprAST *call) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "Call '" << call->getCallee() << "' [ " << loc(call) << "\n";
+  for (const auto &argument : call->getArgs())
+    dump(argument.get());
+  indent();
+  os << "]\n";
+}
+
+/// The builtin print: "Print [ @loc", its argument one level deeper, then "]".
+void ASTDumper::dump(PrintExprAST *print) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "Print [ " << loc(print) << "\n";
+  dump(print->getArg());
+  indent();
+  os << "]\n";
+}
+
+/// A type: only its shape is printed, as "<d0, d1, ...>" with no newline.
+void ASTDumper::dump(const VarType &type) {
+  llvm::raw_ostream &os = llvm::errs();
+  os << "<";
+  llvm::interleaveComma(type.shape, os);
+  os << ">";
+}
+
+/// A prototype: "Proto '<name>' @loc", then "Params: [a, b, ...]" on the next
+/// line at the same indentation.
+void ASTDumper::dump(PrototypeAST *proto) {
+  INDENT();
+  llvm::raw_ostream &os = llvm::errs();
+  os << "Proto '" << proto->getName() << "' " << loc(proto) << "\n";
+  indent();
+  os << "Params: [";
+  llvm::interleaveComma(proto->getArgs(), os,
+                        [&os](const auto &param) { os << param->getName(); });
+  os << "]\n";
+}
+
+/// A function: print "Function", then its prototype and body one level deeper.
+void ASTDumper::dump(FunctionAST *fn) {
+  INDENT();
+  llvm::errs() << "Function \n";
+  auto *proto = fn->getProto();
+  auto *body = fn->getBody();
+  dump(proto);
+  dump(body);
+}
+
+/// A module: print "Module:" and then every function in order, one level
+/// deeper.
+void ASTDumper::dump(ModuleAST *module) {
+  INDENT();
+  llvm::errs() << "Module:\n";
+  for (auto &function : *module)
+    dump(&function);
+}
+}
+
+namespace obs {
+
+/// Public API: print `module` to llvm::errs() as an indented tree.
+void dump(ModuleAST &module) {
+  ASTDumper dumper;
+  dumper.dump(&module);
+}
+
+} // namespace obs

>From 5804ff8e8f188548a2a8013a1baf250e1e2d7cd5 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Sat, 30 Mar 2024 21:29:23 +0100
Subject: [PATCH 3/5] feat: practise toy - 1

---
 obs/.clang-tidy                  |   4 +-
 obs/CMakeLists.txt               |  28 ++++-
 obs/include/CMakeLists.txt       |   6 +
 obs/include/Dialect.h            |  17 +++
 obs/include/Lexer.h              |   2 +-
 obs/include/Ops.td               | 203 +++++++++++++++++++++++++++++++
 obs/obs-ir/Dialect.cpp           | 165 +++++++++++++++++++++++++
 obs/obs-ir/{test.cpp => obs.cpp} |  13 +-
 8 files changed, 432 insertions(+), 6 deletions(-)
 create mode 100644 obs/include/CMakeLists.txt
 create mode 100644 obs/include/Dialect.h
 create mode 100644 obs/include/Ops.td
 create mode 100644 obs/obs-ir/Dialect.cpp
 rename obs/obs-ir/{test.cpp => obs.cpp} (90%)

diff --git a/obs/.clang-tidy b/obs/.clang-tidy
index 9cece0de812b8..4b7b8e9479f29 100644
--- a/obs/.clang-tidy
+++ b/obs/.clang-tidy
@@ -13,11 +13,11 @@ CheckOptions:
   - key:             readability-identifier-naming.MemberCase
     value:           CamelCase
   - key:             readability-identifier-naming.ParameterCase
-    value:           CamelCase
+    value:           camelCase
   - key:             readability-identifier-naming.UnionCase
     value:           CamelCase
   - key:             readability-identifier-naming.VariableCase
-    value:           CamelCase
+    value:           camelCase
   - key:             readability-identifier-naming.IgnoreMainLikeFunctions
     value:           1
   - key:             readability-redundant-member-init.IgnoreBaseInCopyConstructors
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index af775cf2ebb5c..5b7a50c53125d 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -5,6 +5,16 @@ set(LLVM_LINK_COMPONENTS
 include(CMakeDependentOption)
 include(GNUInstallDirs)
 
+# Use the in-tree MLIR headers and MLIR's tablegen output from the same build
+# tree (tools/mlir/include). Order matters: these are search-path additions.
+# NOTE(review): MLIR_MAIN_SRC_DIR points at mlir/include, same as
+# MLIR_INCLUDE_DIR; confirm this is what --src-root should be.
+set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --src-root
+set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --includedir
+set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)
+include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
+include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
+
+# include/ runs mlir-tblgen over Ops.td; the generated .inc files land in the
+# matching binary directory, so put that on the include path too.
+include_directories(SYSTEM ${CMAKE_CURRENT_BINARY_DIR}/include)
+add_subdirectory(include)
+
 # Make sure that our source directory is on the current cmake module path so
 # that we can include cmake files from this directory.
 list(INSERT CMAKE_MODULE_PATH 0
@@ -16,4 +26,20 @@ include(CMakeParseArguments)
 include(AddOBS)
 
 include_directories(include/)
-add_obs_executable(obstest obs-ir/test.cpp parser/AST.cpp)
\ No newline at end of file
+# The obstest driver: parser/AST dumper plus the MLIR dialect implementation.
+# DEPENDS on OBSOpGen so the tablegen'd Ops/Dialect .inc files exist before
+# compiling (assuming add_obs_executable forwards DEPENDS like
+# add_llvm_executable does).
+add_obs_executable(obstest 
+    obs-ir/obs.cpp 
+    obs-ir/Dialect.cpp
+    parser/AST.cpp
+  DEPENDS
+    OBSOpGen
+  )
+
+# MLIR libraries needed by the dialect code.
+target_link_libraries(obstest
+  PRIVATE
+  MLIRAnalysis
+  MLIRFunctionInterfaces
+  MLIRIR
+  MLIRParser
+  MLIRSideEffectInterfaces
+  MLIRTransforms
+)
\ No newline at end of file
diff --git a/obs/include/CMakeLists.txt b/obs/include/CMakeLists.txt
new file mode 100644
index 0000000000000..83793c91c12e2
--- /dev/null
+++ b/obs/include/CMakeLists.txt
@@ -0,0 +1,6 @@
+# Run mlir-tblgen over Ops.td. The dialect class and the op classes are each
+# emitted as declarations (.h.inc) and definitions (.cpp.inc).
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
+mlir_tablegen(Ops.h.inc -gen-op-decls)
+mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+# Target other targets DEPEND on to get the generated files built first.
+add_public_tablegen_target(OBSOpGen)
\ No newline at end of file
diff --git a/obs/include/Dialect.h b/obs/include/Dialect.h
new file mode 100644
index 0000000000000..a298ed2f086a0
--- /dev/null
+++ b/obs/include/Dialect.h
@@ -0,0 +1,17 @@
+#ifndef OBS_DIALECT_H
+#define OBS_DIALECT_H
+
+// Public header for the OBS MLIR dialect: the MLIR interfaces the generated
+// classes rely on, then the tablegen output from Ops.td -- first the dialect
+// class (Dialect.h.inc), then the op classes (Ops.h.inc).
+//
+// The include guard must be project-specific: MLIR_DIALECT_TRAITS_H is already
+// used by mlir/Dialect/Traits.h, and sharing it would make whichever header is
+// included second silently expand to nothing.
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "Dialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "Ops.h.inc"
+
+#endif // OBS_DIALECT_H
diff --git a/obs/include/Lexer.h b/obs/include/Lexer.h
index c2b1c8cf54a0f..bf47f17dfa92c 100644
--- a/obs/include/Lexer.h
+++ b/obs/include/Lexer.h
@@ -136,7 +136,7 @@ class Lexer {
     }
     ++curCol;
     auto nextChar = curLineBuffer.front();
-    curLineBuffer.drop_front();
+    curLineBuffer = curLineBuffer.drop_front();
     if (curLineBuffer.empty()) {
       curLineBuffer = readNextLine();
     }
diff --git a/obs/include/Ops.td b/obs/include/Ops.td
new file mode 100644
index 0000000000000..96e9d02b72160
--- /dev/null
+++ b/obs/include/Ops.td
@@ -0,0 +1,203 @@
+
+#ifndef OBS_OPS
+#define OBS_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def OBS_Dialect: Dialect {
+    // The namespace of the Dialect; ops are spelled `toy.<mnemonic>` in the IR.
+    let name = "toy";
+
+    let summary = "A high-level dialect for analyzing and optimizing the Language";
+
+    let description = [{
+        The Toy language is a tensor-based language that allows you to define functions, perform some math computation, and
+        print results. This dialect provides a representation of the language that is amenable to analysis and optimization.
+    }];
+
+    // The generated C++ classes (OBSDialect and the ops) live in this namespace.
+    let cppNamespace = "::mlir::obs";
+}
+
+// Base class for every op in this dialect: binds OBS_Dialect and forwards the
+// mnemonic and trait list to MLIR's generic Op class.
+class OBS_op<string mnemonic, list<Trait> traits = []> :
+    Op<OBS_Dialect, mnemonic, traits>;
+
+def ConstantOp : OBS_op<"constant", [Pure]> {
+    let summary = "constant operation";
+    let description = [{
+        Constant operation turns a literal into an SSA value. The data is attached to the operation as an attribute.
+    }];
+
+    // Signature: a dense f64 elements attribute in, an f64 tensor out.
+    let arguments = (ins F64ElementsAttr:$value);
+    let results = (outs F64Tensor);
+
+    // Convenience builders: from a DenseElementsAttr (the result type is taken
+    // from the attribute) or from a single double (body written in C++).
+    let builders = [
+        OpBuilder<(ins "DenseElementsAttr":$value), [{build($_builder, $_state, value.getType(), value);}] >,
+        OpBuilder<(ins "double":$value)>
+    ];
+
+    // The parser/printer and verify() must be provided in C++.
+    let hasCustomAssemblyFormat = 1;
+    let hasVerifier = 1;
+}
+
+def AddOp : OBS_op<"add"> {
+    let summary = "element-wise addition operation";
+    let description = [{
+        The "add" operation performs element-wise addition between two tensors.
+        The shape of the tensor operands are expected to match.
+    }];
+
+    // Two f64 tensor operands, one f64 tensor result.
+    let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+    let results = (outs F64Tensor);
+
+    // Builder from the two operands; its body is written in C++.
+    let builders = [
+        OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+    ];
+
+    // Parsed and printed by hand-written C++.
+    let hasCustomAssemblyFormat = 1;
+}
+
+def FuncOp : OBS_op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
+    let summary = "user defined function operation";
+    let description = [{
+        The "toy.func" operation represents a user defined function. These are callable SSA-region operations
+        that contain toy computations.
+    }];
+
+    let arguments = (ins
+        // Must be named `sym_name`: SymbolOpInterface (inherited through
+        // FunctionOpInterface) reads and verifies the symbol under that
+        // attribute name (SymbolTable::getSymbolAttrName()).
+        SymbolNameAttr:$sym_name,
+        TypeAttrOf<FunctionType>:$function_type,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
+    );
+    let regions = (region AnyRegion:$body);
+
+    // Only this custom builder is generated (skipDefaultBuilders below); its
+    // body is written in C++.
+    let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
+                                CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+                    >];
+
+    // Hooks needed by FunctionOpInterface / callable-region queries.
+    let extraClassDeclaration = [{
+        ArrayRef<Type> getArgumentTypes() {
+            return getFunctionType().getInputs();
+        }
+
+        ArrayRef<Type> getResultTypes() {
+            return getFunctionType().getResults();
+        }
+
+        Region *getCallableRegion() {
+            return &getBody();
+        }
+    }];
+
+    let hasCustomAssemblyFormat = 1;
+    let skipDefaultBuilders = 1;
+}
+
+def GenericCallOp: OBS_op<"generic_call"> {
+    let summary = "generic call operation";
+    let description = [{
+        Generic calls represent calls to a user defined function that needs to be specialized for the shape
+        of its arguments.
+    }];
+
+    // The callee symbol plus any number of f64 tensor arguments; one f64
+    // tensor result.
+    let arguments = (ins
+        FlatSymbolRefAttr:$callee,
+        Variadic<F64Tensor>:$inputs
+    );
+    let results = (outs F64Tensor);
+
+    // Convenience builder from a callee name and arguments (body in C++).
+    let builders = [
+        OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
+    ];
+
+    // Printed as: @callee(%args) attr-dict : (argTypes) -> resultType
+    let assemblyFormat = [{
+        $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+    }];
+}
+
+def MulOp : OBS_op<"mul"> {
+    let summary = "element-wise multiplication operation";
+    let description = [{
+        The "mul" operation performs element-wise multiplication between two
+        tensors. The shapes of the tensor operands are expected to match.
+    }];
+
+    // Two f64 tensor operands, one f64 tensor result.
+    let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+    let results = (outs F64Tensor);
+
+    // Builder from the two operands; its body is written in C++.
+    let builders = [
+        OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+    ];
+
+    // Parsed and printed by hand-written C++.
+    let hasCustomAssemblyFormat = 1;
+}
+
+def PrintOp: OBS_op<"print"> {
+    let summary = "print operation";
+    let description = [{
+        The print builtin operation prints a given input tensor, and produces no results.
+    }];
+
+    // One f64 tensor operand and no results.
+    let arguments = (ins F64Tensor:$input);
+    let assemblyFormat = "$input attr-dict `:` type($input)";
+}
+
+def ReshapeOp: OBS_op<"reshape"> {
+    let summary = "tensor reshape operation";
+    let description = [{
+        Reshape operation is transforming its input tensor into a new tensor with the same number of
+        elements but a different shape.
+    }];
+
+    // Any f64 tensor in; the result must have a static shape.
+    let arguments = (ins F64Tensor:$input);
+    let results = (outs StaticShapeTensorOf<[F64]>);
+
+    let assemblyFormat = [{
+        `(` $input `:` type($input) `)` attr-dict `to` type(results)
+    }];
+}
+
+def TransposeOp: OBS_op<"transpose"> {
+    let summary = "transpose operation";
+    let description = [{
+        The "transpose" operation produces the transpose of its input tensor.
+    }];
+
+    // One f64 tensor in, one f64 tensor out.
+    let arguments = (ins F64Tensor:$input);
+    let results = (outs F64Tensor);
+
+    let assemblyFormat = [{
+        `(` $input `:` type($input) `)` attr-dict `to` type(results)
+    }];
+
+    // Builder from the input value; its body is written in C++.
+    let builders = [
+        OpBuilder<(ins "Value":$input)>
+    ];
+
+    // verify() must be provided in C++.
+    let hasVerifier = 1;
+}
+
+def ReturnOp : OBS_op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
+    let summary = "return operation";
+    let description = [{
+        The "return" operation represents a return operation within a function.
+    }];
+
+    // Zero or more f64 tensor operands; the operand list and its types are
+    // printed only when present.
+    let arguments = (ins Variadic<F64Tensor>:$input);
+    let assemblyFormat = "($input^ `:` type($input))? attr-dict ";
+
+    // Builder for a `return` with no operands.
+    let builders = [
+        OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]>
+    ];
+
+    let extraClassDeclaration = [{
+        bool hasOperand() {
+            return getNumOperands() != 0;
+        }
+    }];
+
+    // verify() must be provided in C++.
+    let hasVerifier = 1;
+}
+
+
+#endif //OBS_OPS
diff --git a/obs/obs-ir/Dialect.cpp b/obs/obs-ir/Dialect.cpp
new file mode 100644
index 0000000000000..b33174b8355cd
--- /dev/null
+++ b/obs/obs-ir/Dialect.cpp
@@ -0,0 +1,165 @@
+
+#include "Dialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <string>
+
+using namespace mlir;
+using namespace mlir::obs;
+
+#include "Dialect.cpp.inc"
+
+// Registers the operations generated from Ops.td with the OBS dialect.
+// With GET_OP_LIST defined, Ops.cpp.inc expands to the list of generated op
+// classes, which becomes the template argument list of addOperations<>.
+void OBSDialect::initialize() {
+    addOperations<
+    #define GET_OP_LIST
+    #include "Ops.cpp.inc"
+    >();
+}
+
+/// Parses the custom form shared by the binary arithmetic ops:
+///   %lhs, %rhs attr-dict : type
+/// `type` is either one type used for both operands and the result, or a
+/// function type `(lhsTy, rhsTy) -> resultTy` spelling each type out.
+static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, mlir::OperationState &result) {
+    SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operandRefs;
+    SMLoc operandsStart = parser.getCurrentLocation();
+    Type parsedType;
+
+    bool headerFailed = parser.parseOperandList(operandRefs, 2) ||
+                        parser.parseOptionalAttrDict(result.attributes) ||
+                        parser.parseColonType(parsedType);
+    if (headerFailed)
+        return mlir::failure();
+
+    auto fnType = llvm::dyn_cast<FunctionType>(parsedType);
+    if (!fnType) {
+        // Short form: a single type shared by both operands and the result.
+        if (parser.resolveOperands(operandRefs, parsedType, result.operands))
+            return mlir::failure();
+        result.addTypes(parsedType);
+        return mlir::success();
+    }
+
+    // Functional form: operand and result types are listed individually.
+    if (parser.resolveOperands(operandRefs, fnType.getInputs(), operandsStart,
+                               result.operands))
+        return mlir::failure();
+    result.addTypes(fnType.getResults());
+    return mlir::success();
+}
+
+/// Prints the custom form shared by the binary arithmetic ops, mirroring
+/// parseBinaryOp: `%lhs, %rhs attr-dict : type`.
+/// If every operand already has the result type, only that single type is
+/// printed; otherwise the functional type `(lhsTy, rhsTy) -> resultTy` is
+/// printed so the parser can recover each operand type.
+static void printBinary(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
+    printer << " " << op->getOperands();
+    printer.printOptionalAttrDict(op->getAttrs());
+    printer << " : ";
+
+    Type resultType = *op->result_type_begin();
+    if (llvm::all_of(op->getOperandTypes(), [=](Type type) { return type == resultType; })) {
+        printer << resultType;
+        // Stop here: falling through would also print the functional type,
+        // producing `: T(T, T) -> T`, which parseBinaryOp cannot read back.
+        return;
+    }
+
+    printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
+}
+
+/// Builds a constant that holds the scalar `value` as a rank-0 f64 tensor.
+void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, double value) {
+    mlir::Type f64 = builder.getF64Type();
+    RankedTensorType scalarTensor = RankedTensorType::get({}, f64);
+    ConstantOp::build(builder, state, scalarTensor,
+                      DenseElementsAttr::get(scalarTensor, value));
+}
+
+/// Parses an optional attribute dictionary followed by the dense `value`
+/// attribute; the op's result type is taken from that attribute's type.
+mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) {
+    mlir::DenseElementsAttr valueAttr;
+    if (failed(parser.parseOptionalAttrDict(result.attributes)))
+        return failure();
+    if (failed(parser.parseAttribute(valueAttr, "value", result.attributes)))
+        return failure();
+
+    result.addTypes(valueAttr.getType());
+    return success();
+}
+
+/// Prints the attribute dictionary (minus `value`) followed by the `value`
+/// attribute itself, matching ConstantOp::parse.
+void ConstantOp::print(mlir::OpAsmPrinter &printer) {
+    mlir::Operation *op = getOperation();
+    printer << " ";
+    printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
+    printer << getValue();
+}
+
+/// Verifies that a ranked result type agrees with the `value` attribute: both
+/// must have the same rank and the same extent in every dimension. Unranked
+/// result types are accepted as-is.
+mlir::LogicalResult ConstantOp::verify() {
+    auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
+    if (!resultType)
+        return success();
+
+    auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
+    // Compare ranks first: the loop below indexes the result shape with the
+    // attribute's dimension indices and would read out of bounds otherwise.
+    if (attrType.getRank() != resultType.getRank()) {
+        return emitOpError("return type rank mismatches its attribute rank: ")
+               << attrType.getRank() << " != " << resultType.getRank();
+    }
+
+    for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
+        if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+            return emitOpError("return type shape mismatches its attribute at dimension ")
+                   << dim << ": " << attrType.getShape()[dim]
+                   << " != " << resultType.getShape()[dim];
+        }
+    }
+    return mlir::success();
+}
+
+/// Builds `lhs + rhs`; the result is typed as an unranked f64 tensor.
+void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                  mlir::Value lhs, mlir::Value rhs) {
+    mlir::Type resultTy = UnrankedTensorType::get(builder.getF64Type());
+    state.addTypes(resultTy);
+    state.addOperands(lhs);
+    state.addOperands(rhs);
+}
+
+/// Parses an add op using the custom form shared by all binary ops.
+mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) {
+    mlir::ParseResult status = parseBinaryOp(parser, result);
+    return status;
+}
+
+/// Prints an add op using the custom form shared by all binary ops.
+void AddOp::print(mlir::OpAsmPrinter &p) {
+    mlir::Operation *op = getOperation();
+    printBinary(p, op);
+}
+
+/// Builds a call to the function named `callee` with `arguments` as operands.
+/// The result is typed as an unranked f64 tensor.
+void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                          StringRef callee, ArrayRef<mlir::Value> arguments) {
+    mlir::MLIRContext *ctx = builder.getContext();
+    auto calleeRef = mlir::SymbolRefAttr::get(ctx, callee);
+    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+    state.addOperands(arguments);
+    state.addAttribute("callee", calleeRef);
+}
+
+/// Builds a function `name` of type `type` together with an entry block whose
+/// arguments have the function's input types.
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  llvm::ArrayRef<mlir::Type> entryArgTypes = type.getInputs();
+  buildWithEntryBlock(builder, state, name, type, attrs, entryArgTypes);
+}
+
+/// Parses a function with the shared function-op syntax (variadic arguments
+/// are not allowed); the signature is built as a plain FunctionType.
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  auto makeFunctionType = [](mlir::Builder &b, llvm::ArrayRef<mlir::Type> inputs,
+                             llvm::ArrayRef<mlir::Type> outputs,
+                             mlir::function_interface_impl::VariadicFlag,
+                             std::string &) {
+    return b.getFunctionType(inputs, outputs);
+  };
+  mlir::StringAttr typeAttrName = getFunctionTypeAttrName(result.name);
+  mlir::StringAttr argAttrsName = getArgAttrsAttrName(result.name);
+  mlir::StringAttr resAttrsName = getResAttrsAttrName(result.name);
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
+      argAttrsName, resAttrsName);
+}
+
+/// Prints the function with the shared function-op syntax (non-variadic),
+/// mirroring FuncOp::parse.
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  const bool isVariadic = false;
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, isVariadic, getFunctionTypeAttrName(), getArgAttrsAttrName(),
+      getResAttrsAttrName());
+}
+
+
+#define GET_OP_CLASSES
+#include "Ops.cpp.inc"
+
+
+
+
diff --git a/obs/obs-ir/test.cpp b/obs/obs-ir/obs.cpp
similarity index 90%
rename from obs/obs-ir/test.cpp
rename to obs/obs-ir/obs.cpp
index 829797eaf4cf5..379054e0f4ea7 100644
--- a/obs/obs-ir/test.cpp
+++ b/obs/obs-ir/obs.cpp
@@ -2,15 +2,18 @@
 #include "Lexer.h"
 #include "Parser.h"
 
+#include "Dialect.h"
+#include "mlir/IR/MLIRContext.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
+#include <iostream>
 #include <memory>
 #include <string>
 #include <system_error>
-#include <iostream>
+
 
 using namespace obs;
 namespace cl = llvm::cl;
@@ -22,7 +25,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
 
 namespace {
 enum Action { None, DumpAST }; 
-}
+} // namespace
 
 static cl::opt<enum Action> emitAction("emit", cl::desc("Select the kind of output desired"), 
                                        cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
@@ -39,6 +42,12 @@ std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
   return parser.parseModule();
 }
 
+/// Creates an MLIR context and loads the OBS dialect into it. No IR is
+/// generated yet; always returns 0.
+int dumpMLIR() {
+  mlir::MLIRContext ctx;
+  ctx.getOrLoadDialect<mlir::obs::OBSDialect>();
+  return 0;
+}
+
 int main(int argc, char **argv) {
   cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
   auto moduleAST = parseInputFile(inputFilename);

>From fe9af3502a7728b0bc7e1e3923bf7cd8241fa6b9 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Mon, 1 Apr 2024 15:12:58 +0200
Subject: [PATCH 4/5] initial: finish toy example chapter 2

---
 obs/CMakeLists.txt     |   1 +
 obs/include/Lexer.h    |   3 +
 obs/include/MLIRGen.h  |  21 +++
 obs/include/Ops.td     |   8 +-
 obs/obs-ir/Dialect.cpp |  60 ++++++++
 obs/obs-ir/MLIRGen.cpp | 334 +++++++++++++++++++++++++++++++++++++++++
 obs/obs-ir/obs.cpp     |  93 ++++++++++--
 7 files changed, 505 insertions(+), 15 deletions(-)
 create mode 100644 obs/include/MLIRGen.h
 create mode 100644 obs/obs-ir/MLIRGen.cpp

diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index 5b7a50c53125d..05dcf074a73a6 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -29,6 +29,7 @@ include_directories(include/)
 add_obs_executable(obstest 
     obs-ir/obs.cpp 
     obs-ir/Dialect.cpp
+    obs-ir/MLIRGen.cpp
     parser/AST.cpp
   DEPENDS
     OBSOpGen
diff --git a/obs/include/Lexer.h b/obs/include/Lexer.h
index bf47f17dfa92c..6abbb6ec53886 100644
--- a/obs/include/Lexer.h
+++ b/obs/include/Lexer.h
@@ -65,6 +65,9 @@ class Lexer {
       lastChar = Token(getNextChar());
     }
 
+    lastLocation.line = curLineNum;
+    lastLocation.col = curCol;
+
     if (isalpha(lastChar)) {
       identifierStr = (char)lastChar;
       while(isalnum(lastChar = Token(getNextChar())) || lastChar == '_') {
diff --git a/obs/include/MLIRGen.h b/obs/include/MLIRGen.h
new file mode 100644
index 0000000000000..b6c856fac5eec
--- /dev/null
+++ b/obs/include/MLIRGen.h
@@ -0,0 +1,21 @@
+
+#ifndef OBS_MLIRGEN_H
+#define OBS_MLIRGEN_H
+
+#include <memory>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/MLIRContext.h>
+
+// Forward declarations of the MLIR types used below. NOTE(review): these are
+// likely redundant with the mlir/IR includes above — confirm and drop.
+namespace mlir{
+class MLIRContext;
+template <typename OpTy>
+class OwningOpRef;
+class ModuleOp;
+} // namespace mlir
+
+namespace obs {
+class ModuleAST;
+/// Lowers `moduleAST` into a new builtin ModuleOp populated with OBS dialect
+/// operations created in `context`. Returns a null OwningOpRef on failure.
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
+} //namespace obs
+
+#endif //OBS_MLIRGEN_H
\ No newline at end of file
diff --git a/obs/include/Ops.td b/obs/include/Ops.td
index 96e9d02b72160..9463de72f2614 100644
--- a/obs/include/Ops.td
+++ b/obs/include/Ops.td
@@ -9,7 +9,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def OBS_Dialect: Dialect {
     //The namespace of the Dialect.
-    let name = "toy";
+    let name = "obs";
 
     let summary = "A high-level dialect for analyzing and optimzing the Language";
 
@@ -66,12 +66,12 @@ def AddOp : OBS_op<"add"> {
 def FuncOp : OBS_op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
     let summary = "user defined function operation";
     let description = [{
-        The "toy.func" operation represents a user defined function. These are callable SSA-region operations
-        that contain toy computations.
+        The "obs.func" operation represents a user defined function. These are callable SSA-region operations
+        that contain obs computations.
     }];
 
     let arguments = (ins
-        SymbolNameAttr:$sysm_name,
+        SymbolNameAttr:$sym_name,
         TypeAttrOf<FunctionType>:$function_type,
         OptionalAttr<DictArrayAttr>:$arg_attrs,
         OptionalAttr<DictArrayAttr>:$res_attrs
diff --git a/obs/obs-ir/Dialect.cpp b/obs/obs-ir/Dialect.cpp
index b33174b8355cd..c5a4ef2e5c310 100644
--- a/obs/obs-ir/Dialect.cpp
+++ b/obs/obs-ir/Dialect.cpp
@@ -156,6 +156,66 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
 		    getResAttrsAttrName());
 }
 
+/// Builds `lhs * rhs`; the result is typed as an unranked f64 tensor.
+void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                  mlir::Value lhs, mlir::Value rhs) {
+  mlir::Type f64Tensor = UnrankedTensorType::get(builder.getF64Type());
+  state.addTypes(f64Tensor);
+  state.addOperands(lhs);
+  state.addOperands(rhs);
+}
+
+/// Parses a mul op using the custom form shared by all binary ops.
+mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
+                               mlir::OperationState &result) {
+  mlir::ParseResult parsed = parseBinaryOp(parser, result);
+  return parsed;
+}
+
+/// Prints a mul op using the custom form shared by all binary ops.
+void MulOp::print(mlir::OpAsmPrinter &p) {
+  mlir::Operation *op = getOperation();
+  printBinary(p, op);
+}
+
+/// Verifies the return against its enclosing function. It checks that there is
+/// at most one operand, that the operand count equals the function's result
+/// count, and that the operand type equals the result type. An unranked type
+/// on either side is also accepted.
+mlir::LogicalResult ReturnOp::verify() {
+    // The HasParent<"FuncOp"> trait guarantees the parent is a FuncOp.
+    auto function = cast<FuncOp>((*this)->getParentOp());
+    if (getNumOperands() > 1) {
+        return emitOpError() << "expects at most 1 return operand";
+    }
+    const auto &results = function.getFunctionType().getResults();
+    if (getNumOperands() != results.size()) {
+        return emitOpError() << "does not return the same number of values ("
+                             << getNumOperands() << ") as the enclosing function ("
+                             << results.size() << ")";
+    }
+
+    if (!hasOperand()) {
+        return mlir::success();
+    }
+
+    auto inputType = *operand_type_begin();
+    auto resultType = results.front();
+
+    if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+        llvm::isa<mlir::UnrankedTensorType>(resultType)) {
+        return mlir::success();
+    }
+    return emitError() << "type of return operand (" << inputType
+                       << ") doesn't match function result type (" << resultType
+                       << ")";
+}
+
+/// Builds a transpose of `value`; the result is typed as an unranked f64 tensor.
+void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value value) {
+    state.addOperands(value);
+    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+}
+
+/// When both input and result are ranked, verifies that the result shape is
+/// the input shape reversed (e.g. 2x3 -> 3x2). Unranked types are accepted.
+mlir::LogicalResult TransposeOp::verify() {
+    auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+    auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
+    if (!inputType || !resultType)
+        return mlir::success();
+
+    auto inputShape = inputType.getShape();
+    auto resultShape = resultType.getShape();
+    // Compare ranks first: std::equal walks resultShape for as many elements
+    // as inputShape holds and would read past its end if the result were
+    // lower-rank.
+    if (inputShape.size() != resultShape.size() ||
+        !std::equal(inputShape.begin(), inputShape.end(), resultShape.rbegin())) {
+        return emitError() << "expected result shape to be a transpose of the input";
+    }
+    return mlir::success();
+}
 
 #define GET_OP_CLASSES
 #include "Ops.cpp.inc"
diff --git a/obs/obs-ir/MLIRGen.cpp b/obs/obs-ir/MLIRGen.cpp
new file mode 100644
index 0000000000000..388d7e5b5fdb0
--- /dev/null
+++ b/obs/obs-ir/MLIRGen.cpp
@@ -0,0 +1,334 @@
+
+#include "MLIRGen.h"
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "AST.h"
+#include "Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Verifier.h"
+#include "Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
+#include <cassert>
+#include <cstdint>
+#include <functional>
+#include <mlir/IR/BuiltinAttributes.h>
+#include <mlir/IR/Location.h>
+#include <mlir/IR/OwningOpRef.h>
+#include <numeric>
+#include <optional>
+#include <vector>
+
+using namespace mlir::obs;
+using namespace obs;
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Lowers the OBS AST into operations of the OBS MLIR dialect.
+///
+/// Variable names are resolved through a scoped symbol table: a scope is
+/// opened for each function (holding its arguments) and for each expression
+/// list. Tensor types have f64 elements; they are ranked when a shape is
+/// known and unranked otherwise.
+class MLIRGenImpl {
+
+public:
+  explicit MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
+
+  /// Builds a ModuleOp containing one obs.func per function in `moduleAST`.
+  /// Returns null, after destroying the partially built module, if any
+  /// function fails to lower or the module fails verification.
+  mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
+    theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
+
+    for (FunctionAST &f : moduleAST) {
+      // A function that failed to lower has been erased from the module;
+      // returning the rest would silently drop it, so fail the whole module.
+      if (!mlirGen(f)) {
+        theModule->erase();
+        return nullptr;
+      }
+    }
+
+    if (failed(mlir::verify(theModule))) {
+      theModule->emitError("module verification error");
+      // The caller only takes ownership on success, so free the module here.
+      theModule->erase();
+      return nullptr;
+    }
+    return theModule;
+  }
+
+private:
+  /// Module being populated; created by mlirGen(ModuleAST &).
+  mlir::ModuleOp theModule;
+  /// Builder used to create every operation at the current insertion point.
+  mlir::OpBuilder builder;
+  /// Maps each visible variable name to the SSA value bound to it.
+  llvm::ScopedHashTable<StringRef, mlir::Value> symbolTable;
+
+  /// Converts an AST source location into an MLIR FileLineColLoc.
+  mlir::Location loc(const Location &loc) {
+    return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line,
+                                     loc.col);
+  }
+
+  /// Binds `var` to `value` in the current scope. Fails if `var` is already
+  /// bound in any active scope.
+  mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
+    if (symbolTable.count(var))
+      return mlir::failure();
+    symbolTable.insert(var, value);
+    return mlir::success();
+  }
+
+  /// Returns a ranked f64 tensor type for a non-empty `shape`, otherwise an
+  /// unranked f64 tensor type.
+  mlir::Type getType(ArrayRef<int64_t> shape) {
+    if (shape.empty())
+      return mlir::UnrankedTensorType::get(builder.getF64Type());
+    return mlir::RankedTensorType::get(shape, builder.getF64Type());
+  }
+
+  /// Returns the tensor type corresponding to a declared variable type.
+  mlir::Type getType(const VarType &type) { return getType(type.shape); }
+
+  /// Creates an obs.func for `proto`; every argument is an unranked f64
+  /// tensor and the function initially has no results.
+  mlir::obs::FuncOp mlirGen(PrototypeAST &proto) {
+    auto location = loc(proto.loc());
+
+    llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
+                                              getType(VarType{}));
+
+    auto funcType = builder.getFunctionType(argTypes, std::nullopt);
+    return builder.create<mlir::obs::FuncOp>(location, proto.getName(), funcType);
+  }
+
+  /// Appends every number of the (possibly nested) literal `expr` to `data`,
+  /// in the order the numbers appear in the literal.
+  void collectData(ExprAST &expr, std::vector<double> &data) {
+    if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
+      // Visit every element before returning; returning inside the loop would
+      // keep only the first element of each nesting level.
+      for (auto &value : lit->getValues())
+        collectData(*value, data);
+      return;
+    }
+
+    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+    data.push_back(cast<NumberExprAST>(expr).getValue());
+  }
+
+  /// Emits an obs.constant holding the flattened contents of `lit`.
+  mlir::Value mlirGen(LiteralExprAST &lit) {
+    // Bind the dims once so begin()/end() below refer to the same range.
+    const auto &dims = lit.getDims();
+    auto type = getType(dims);
+
+    std::vector<double> data;
+    data.reserve(std::accumulate(dims.begin(), dims.end(), int64_t{1},
+                                 std::multiplies<int64_t>()));
+    collectData(lit, data);
+
+    mlir::Type elementType = builder.getF64Type();
+    auto dataType = mlir::RankedTensorType::get(dims, elementType);
+
+    auto dataAttribute =
+        mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
+
+    return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
+  }
+
+  /// Dispatches code generation on the dynamic kind of `expr`. Returns null
+  /// on failure.
+  mlir::Value mlirGen(ExprAST &expr) {
+    switch (expr.getKind()) {
+    case obs::ExprAST::Expr_BinOp:
+      return mlirGen(cast<BinaryExprAST>(expr));
+    case obs::ExprAST::Expr_Var:
+      return mlirGen(cast<VariableExprAST>(expr));
+    case obs::ExprAST::Expr_Literal:
+      return mlirGen(cast<LiteralExprAST>(expr));
+    case obs::ExprAST::Expr_Call:
+      return mlirGen(cast<CallExprAST>(expr));
+    case obs::ExprAST::Expr_Num:
+      return mlirGen(cast<NumberExprAST>(expr));
+    default:
+      mlir::emitError(loc(expr.loc()))
+          << "MLIR codegen encounter an unhandled expr kind '"
+          << Twine(expr.getKind()) << "'";
+      return nullptr;
+    }
+  }
+
+  /// Emits an obs.print of the value computed for the print argument.
+  mlir::LogicalResult mlirGen(PrintExprAST &call) {
+    auto arg = mlirGen(*call.getArg());
+    if (!arg)
+      return mlir::failure();
+
+    builder.create<PrintOp>(loc(call.loc()), arg);
+    return mlir::success();
+  }
+
+  /// Generates every expression of a block inside a fresh variable scope.
+  /// Generation stops at the first `return` expression.
+  mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
+    ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
+
+    for (auto &expr : blockAST) {
+      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+        if (!mlirGen(*vardecl))
+          return mlir::failure();
+        continue;
+      }
+      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
+        return mlirGen(*ret);
+      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+        // Propagate the failure instead of reporting success for a block
+        // whose print argument could not be generated.
+        if (mlir::failed(mlirGen(*print)))
+          return mlir::failure();
+        continue;
+      }
+
+      // Generic expression dispatch codegen.
+      if (!mlirGen(*expr))
+        return mlir::failure();
+    }
+    return mlir::success();
+  }
+
+  /// Emits a scalar obs.constant for a number literal.
+  mlir::Value mlirGen(NumberExprAST &num) {
+    return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
+  }
+
+  /// Emits an obs.func for `funcAST` at the end of the module. Returns null,
+  /// after erasing the partially built function, on failure.
+  mlir::obs::FuncOp mlirGen(FunctionAST &funcAST) {
+    // Scope holding the function's arguments; ends with this function.
+    llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
+
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::obs::FuncOp function = mlirGen(*funcAST.getProto());
+    if (!function)
+      return nullptr;
+
+    mlir::Block &entryBlock = function.front();
+
+    auto protoArgs = funcAST.getProto()->getArgs();
+
+    // Bind each prototype argument name to the matching entry block argument.
+    for (const auto nameValue : llvm::zip(protoArgs, entryBlock.getArguments())) {
+      if (failed(declare(std::get<0>(nameValue)->getName(), std::get<1>(nameValue)))) {
+        mlir::emitError(loc(funcAST.getProto()->loc()), "duplicate argument name '")
+            << std::get<0>(nameValue)->getName() << "'";
+        function->erase();
+        return nullptr;
+      }
+    }
+
+    builder.setInsertionPointToStart(&entryBlock);
+
+    if (mlir::failed(mlirGen(*funcAST.getBody()))) {
+      function->erase();
+      return nullptr;
+    }
+
+    // Guarantee the body ends with an obs.return. If the existing return
+    // carries a value, give the function an unranked f64 tensor result.
+    ReturnOp returnOp;
+    if (!entryBlock.empty())
+      returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+    if (!returnOp) {
+      builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+    } else if (returnOp.hasOperand()) {
+      function.setType(builder.getFunctionType(
+          function.getFunctionType().getInputs(), getType(VarType{})));
+    }
+    return function;
+  }
+
+  /// Emits obs.add for `+` and obs.mul for `*`. Returns null on failure.
+  mlir::Value mlirGen(BinaryExprAST &binop) {
+    mlir::Value lhs = mlirGen(*binop.getLHS());
+    if (!lhs)
+      return nullptr;
+    mlir::Value rhs = mlirGen(*binop.getRHS());
+    if (!rhs)
+      return nullptr;
+
+    auto location = loc(binop.loc());
+
+    switch (binop.getOp()) {
+    case '+':
+      return builder.create<AddOp>(location, lhs, rhs);
+    case '*':
+      return builder.create<MulOp>(location, lhs, rhs);
+    }
+
+    emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+    return nullptr;
+  }
+
+  /// Returns the value bound to a variable, or null (with a diagnostic) if
+  /// the name is unknown.
+  mlir::Value mlirGen(VariableExprAST &expr) {
+    if (auto variable = symbolTable.lookup(expr.getName()))
+      return variable;
+    mlir::emitError(loc(expr.loc()), "error: unknown variable '")
+        << expr.getName() << "'";
+    return nullptr;
+  }
+
+  /// Emits an obs.return, with an operand if the return has an expression.
+  mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
+    auto location = loc(ret.loc());
+
+    mlir::Value expr = nullptr;
+    if (ret.getExpr().has_value()) {
+      if (!(expr = mlirGen(**ret.getExpr())))
+        return mlir::failure();
+    }
+    builder.create<ReturnOp>(location, expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
+    return mlir::success();
+  }
+
+  /// Emits a call. `transpose` maps to obs.transpose (exactly one argument);
+  /// every other callee becomes an obs.generic_call.
+  mlir::Value mlirGen(CallExprAST &call) {
+    llvm::StringRef callee = call.getCallee();
+    auto location = loc(call.loc());
+
+    SmallVector<mlir::Value, 4> operands;
+    for (auto &expr : call.getArgs()) {
+      auto arg = mlirGen(*expr);
+      if (!arg)
+        return nullptr;
+      operands.push_back(arg);
+    }
+
+    if (callee == "transpose") {
+      if (call.getArgs().size() != 1) {
+        mlir::emitError(location, "MLIR codegen encountered an error: obs.transpose "
+                                  "expects exactly one argument.");
+        return nullptr;
+      }
+      return builder.create<TransposeOp>(location, operands[0]);
+    }
+    return builder.create<GenericCallOp>(location, callee, operands);
+  }
+
+  /// Generates the initializer of a variable declaration and binds the name.
+  /// Inserts an obs.reshape when the declaration carries an explicit shape.
+  mlir::Value mlirGen(VarDeclExprAST &vardecl) {
+    auto *init = vardecl.getInitVal();
+    if (!init) {
+      mlir::emitError(loc(vardecl.loc()), "missing initializer in variable declaration");
+      return nullptr;
+    }
+
+    mlir::Value value = mlirGen(*init);
+    if (!value)
+      return nullptr;
+
+    if (!vardecl.getType().shape.empty()) {
+      value = builder.create<ReshapeOp>(loc(vardecl.loc()),
+                                        getType(vardecl.getType()), value);
+    }
+
+    if (failed(declare(vardecl.getName(), value))) {
+      mlir::emitError(loc(vardecl.loc()), "redeclaration of variable '")
+          << vardecl.getName() << "'";
+      return nullptr;
+    }
+    return value;
+  }
+
+};
+
+} //namespace
+
+namespace obs {
+/// Lowers `moduleAST` into a new ModuleOp in `context`; the returned reference
+/// is null when MLIRGenImpl reports failure.
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST) {
+    MLIRGenImpl generator(context);
+    mlir::ModuleOp module = generator.mlirGen(moduleAST);
+    return module;
+}
+} //namespace obs
+
+
+
+
+
+
+
+
diff --git a/obs/obs-ir/obs.cpp b/obs/obs-ir/obs.cpp
index 379054e0f4ea7..ef5ed4d96bcd1 100644
--- a/obs/obs-ir/obs.cpp
+++ b/obs/obs-ir/obs.cpp
@@ -3,12 +3,19 @@
 #include "Parser.h"
 
 #include "Dialect.h"
+#include "MLIRGen.h"
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser/Parser.h"
+
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/SourceMgr.h"
 #include <iostream>
 #include <memory>
 #include <string>
@@ -19,16 +26,26 @@ using namespace obs;
 namespace cl = llvm::cl;
 
 static cl::opt<std::string> inputFilename(cl::Positional, 
-                                          cl::desc("<input toy file>"),
+                                          cl::desc("<input obs file>"),
                                           cl::init("-"),
                                           cl::value_desc("filename"));
 
 namespace {
-enum Action { None, DumpAST }; 
+  enum InputType { OBS, MLIR };
+} //namespace
+// `-x=obs|mlir`: selects how the input file is interpreted.
+static cl::opt<enum InputType> inputType(
+    "x", cl::init(OBS), cl::desc("Decide the kind of input provided"),
+    cl::values(clEnumValN(OBS, "obs", "load the input file as an OBS source.")),
+    cl::values(clEnumValN(MLIR, "mlir",
+                          "load the input file as an MLIR file")));
+
+namespace {
+// Output selected with `-emit`.
+enum Action { None, DumpAST, DumpMLIR };
 } // namespace
 
 static cl::opt<enum Action> emitAction("emit", cl::desc("Select the kind of output desired"), 
-                                       cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
+                           cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
+                           cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
 
 std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
@@ -45,20 +62,74 @@ std::unique_ptr<obs::ModuleAST> parseInputFile(llvm::StringRef filename) {
 int dumpMLIR() {
   mlir::MLIRContext context;
   context.getOrLoadDialect<mlir::obs::OBSDialect>();
+
+  // Treat the input as OBS source unless `-x=mlir` was given or the file
+  // name ends in ".mlir".
+  if (inputType != InputType::MLIR &&
+      !llvm::StringRef(inputFilename).ends_with(".mlir")) {
+    auto moduleAST = parseInputFile(inputFilename);
+    if (!moduleAST)
+      return 6;
+    mlir::OwningOpRef<mlir::ModuleOp> module = mlirGen(context, *moduleAST);
+    if (!module)
+      return 1;
+
+    module->dump();
+    return 0;
+  }
+
+  // Otherwise, parse the input as MLIR.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code ec = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+    return -1;
+  }
+
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  mlir::OwningOpRef<mlir::ModuleOp> module =
+      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Error can't load file " << inputFilename << "\n";
+    return 3;
+  }
+
+  module->dump();
   return 0;
 }
 
-int main(int argc, char **argv) {
-  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+/// Parses the input file as OBS source and dumps its AST.
+/// Returns 5 when the input was declared as MLIR (`-x=mlir`), 1 when parsing
+/// fails, and 0 otherwise.
+int dumpAST() {
+  if (inputType == InputType::MLIR) {
+    llvm::errs() << "Can't dump a OBS AST when the input is MLIR\n";
+    return 5;
+  }
+
   auto moduleAST = parseInputFile(inputFilename);
+  if (!moduleAST)
+    return 1;
+
+  dump(*moduleAST);
+  return 0;
+}
+
+int main(int argc, char **argv) {
+  // Register the MLIR printer/context options so they are accepted on the
+  // command line.
+  mlir::registerAsmPrinterCLOptions();
+  mlir::registerMLIRContextCLOptions();
+  cl::ParseCommandLineOptions(argc, argv, "obs compiler\n");

-  switch(emitAction) {
-    case Action::DumpAST:
-      dump(*moduleAST);
-      return 0;
-    default:
-      llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+  // Dispatch on the action requested with `-emit`.
+  switch (emitAction) {
+  case Action::DumpAST:
+    return dumpAST();
+  case Action::DumpMLIR:
+    return dumpMLIR();
+  default:
+    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
   }
+
   return 0;
 }
 

>From e5caa51de291b64d2cb2f0cd661893df80e643e6 Mon Sep 17 00:00:00 2001
From: Shuanglong Kan <kanshuanglong at outlook.com>
Date: Fri, 31 May 2024 08:03:09 +0100
Subject: [PATCH 5/5] feat: add OBS operations

---
 mlir/test/Examples/Toy/Ch2/codegen.toy |  31 -----
 obs/.clang-format                      |   1 +
 obs/CMakeLists.txt                     |  51 ++++++--
 obs/codegen/CodeGen.cpp                |  55 +++++++++
 obs/codegen/CodeGenAction.cpp          |  30 +++++
 obs/codegen/OBSGen.cpp                 |  27 ++++
 obs/include/CodeGen.h                  |  41 +++++++
 obs/include/CodeGenAction.h            |  43 +++++++
 obs/include/Dialect.h                  |  50 ++++++++
 obs/include/OBSGen.h                   |   2 +
 obs/include/Ops.td                     | 125 ++++++++++++++++++-
 obs/obs-ir/Dialect.cpp                 | 163 +++++++++++++++++++++++++
 obs/obs-ir/MLIRGen.cpp                 |  50 +++++++-
 obs/test/codegen.toy                   |   5 +
 obs/test/test1.cpp                     |   9 ++
 15 files changed, 636 insertions(+), 47 deletions(-)
 delete mode 100644 mlir/test/Examples/Toy/Ch2/codegen.toy
 create mode 100644 obs/.clang-format
 create mode 100644 obs/codegen/CodeGen.cpp
 create mode 100644 obs/codegen/CodeGenAction.cpp
 create mode 100644 obs/codegen/OBSGen.cpp
 create mode 100644 obs/include/CodeGen.h
 create mode 100644 obs/include/CodeGenAction.h
 create mode 100644 obs/include/OBSGen.h
 create mode 100644 obs/test/codegen.toy
 create mode 100644 obs/test/test1.cpp

diff --git a/mlir/test/Examples/Toy/Ch2/codegen.toy b/mlir/test/Examples/Toy/Ch2/codegen.toy
deleted file mode 100644
index 12178d6afb309..0000000000000
--- a/mlir/test/Examples/Toy/Ch2/codegen.toy
+++ /dev/null
@@ -1,31 +0,0 @@
-# RUN: toyc-ch2 %s -emit=mlir 2>&1 | FileCheck %s
-
-# User defined generic function that operates on unknown shaped arguments
-def multiply_transpose(a, b) {
-  return transpose(a) * transpose(b);
-}
-
-def main() {
-  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
-  var b<2, 3> = [1, 2, 3, 4, 5, 6];
-  var c = multiply_transpose(a, b);
-  var d = multiply_transpose(b, a);
-  print(d);
-}
-
-# CHECK-LABEL: toy.func @multiply_transpose(
-# CHECK-SAME:                               [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
-# CHECK:         [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
-# CHECK-NEXT:    [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
-# CHECK-NEXT:    [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] :  tensor<*xf64>
-# CHECK-NEXT:    toy.return [[VAL_4]] : tensor<*xf64>
-
-# CHECK-LABEL: toy.func @main()
-# CHECK-NEXT:    [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
-# CHECK-NEXT:    [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
-# CHECK-NEXT:    [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
-# CHECK-NEXT:    [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64>
-# CHECK-NEXT:    [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
-# CHECK-NEXT:    [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
-# CHECK-NEXT:    toy.print [[VAL_10]] : tensor<*xf64>
-# CHECK-NEXT:    toy.return
diff --git a/obs/.clang-format b/obs/.clang-format
new file mode 100644
index 0000000000000..468ceea01f1a0
--- /dev/null
+++ b/obs/.clang-format
@@ -0,0 +1 @@
+BasedOnStyle: LLVM
\ No newline at end of file
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
index 05dcf074a73a6..5dd9548f6a956 100644
--- a/obs/CMakeLists.txt
+++ b/obs/CMakeLists.txt
@@ -5,9 +5,12 @@ set(LLVM_LINK_COMPONENTS
 include(CMakeDependentOption)
 include(GNUInstallDirs)
 
-set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --src-root
-set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) # --includedir
-set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)
+set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include/ ) # --includedir
+set(CLANG_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../clang/include/ ) # --includedir
+set(CLANG_INCLUDE_DIR_INC ${CMAKE_BINARY_DIR}/tools/clang/include/ ) # --includedir
+set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include/)
+include_directories(SYSTEM ${CLANG_INCLUDE_DIR})
+include_directories(SYSTEM ${CLANG_INCLUDE_DIR_INC})
 include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
 include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
 
@@ -26,21 +29,49 @@ include(CMakeParseArguments)
 include(AddOBS)
 
 include_directories(include/)
-add_obs_executable(obstest 
-    obs-ir/obs.cpp 
-    obs-ir/Dialect.cpp
-    obs-ir/MLIRGen.cpp
-    parser/AST.cpp
+# NOTE(review): with the Toy-based `obstest` driver disabled, no target below
+# compiles obs-ir/Dialect.cpp or obs-ir/MLIRGen.cpp. Unless another CMakeLists
+# builds them, the new OBS types/ops go uncompiled; re-enable the driver or add
+# those sources to `codegen`.
+#add_obs_executable(obstest
+#    obs-ir/obs.cpp
+#    obs-ir/Dialect.cpp
+#    obs-ir/MLIRGen.cpp
+#    parser/AST.cpp
+#  DEPENDS
+#    OBSOpGen
+#)
+
+#target_link_libraries(obstest
+#  PRIVATE
+#  MLIRAnalysis
+#  MLIRFunctionInterfaces
+#  MLIRIR
+#  MLIRParser
+#  MLIRSideEffectInterfaces
+#  MLIRTransforms
+#)
+
+# Clang-based driver (codegen/OBSGen.cpp) meant to translate C sources to OBS.
+add_obs_executable(codegen
+    codegen/OBSGen.cpp
+    codegen/CodeGenAction.cpp
+    codegen/CodeGen.cpp
   DEPENDS
     OBSOpGen
-  )
+)
 
-target_link_libraries(obstest
+target_link_libraries(codegen
   PRIVATE
+  clangAST
+  clangBasic
+  clangFrontend
+  clangSerialization
+  clangTooling
   MLIRAnalysis
   MLIRFunctionInterfaces
   MLIRIR
   MLIRParser
   MLIRSideEffectInterfaces
   MLIRTransforms
-)
\ No newline at end of file
+)
diff --git a/obs/codegen/CodeGen.cpp b/obs/codegen/CodeGen.cpp
new file mode 100644
index 0000000000000..d50de0431cc3b
--- /dev/null
+++ b/obs/codegen/CodeGen.cpp
@@ -0,0 +1,60 @@
+
+#include "MLIRGen.h"
+
+#include "AST.h"
+#include "Dialect.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cassert>
+#include <cstdint>
+#include <functional>
+
+#include <mlir/IR/BuiltinAttributes.h>
+#include <mlir/IR/Location.h>
+#include <mlir/IR/OwningOpRef.h>
+
+#include <iostream>
+#include <numeric>
+#include <optional>
+#include <vector>
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+#include "CodeGen.h"
+
+namespace mlir {
+namespace obs {
+
+/// RecursiveASTVisitor hook, called for each clang::FunctionDecl it traverses.
+///
+/// For now it only prints the function's name to stdout.
+/// \returns true so the traversal continues; returning false would make
+///          RecursiveASTVisitor stop after the first function it visits.
+bool MLIRGenImpl::VisitFunctionDecl(clang::FunctionDecl *funcDecl) {
+  // Stream the name directly: DeclarationName::dump() writes to llvm::errs(),
+  // which would split this single log line across stderr and stdout.
+  llvm::outs() << "VisitFunctionDecl: " << funcDecl->getDeclName() << "\n";
+  return true;
+}
+
+} // namespace obs
+} // namespace mlir
diff --git a/obs/codegen/CodeGenAction.cpp b/obs/codegen/CodeGenAction.cpp
new file mode 100644
index 0000000000000..94fa18dcd188d
--- /dev/null
+++ b/obs/codegen/CodeGenAction.cpp
@@ -0,0 +1,37 @@
+
+
+#include "CodeGenAction.h"
+#include "CodeGen.h"
+
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <clang/AST/ASTContext.h>
+#include <clang/AST/Decl.h>
+#include <clang/AST/DeclGroup.h>
+
+#include <iostream>
+#include <mlir/IR/MLIRContext.h>
+#include <ostream>
+
+using namespace clang;
+
+namespace mlir {
+namespace obs {
+
+/// Called by clang once the whole translation unit has been parsed.
+///
+/// Creates an MLIRContext local to this call and runs MLIRGenImpl over the
+/// translation-unit declaration.
+void CodeGenConsumer::HandleTranslationUnit(ASTContext &astContext) {
+  llvm::outs() << "Enter HandleTranslationUnit\n";
+
+  MLIRContext mlirContext;
+  MLIRGenImpl visitor(mlirContext);
+
+  TranslationUnitDecl *unit = astContext.getTranslationUnitDecl();
+  visitor.TraverseDecl(unit);
+}
+
+} // namespace obs
+} // namespace mlir
diff --git a/obs/codegen/OBSGen.cpp b/obs/codegen/OBSGen.cpp
new file mode 100644
index 0000000000000..7cd866aa124ca
--- /dev/null
+++ b/obs/codegen/OBSGen.cpp
@@ -0,0 +1,34 @@
+
+#include "llvm/Support/CommandLine.h"
+#include <iostream>
+#include <vector>
+
+#include "OBSGen.h"
+
+// `main` translates a C program into OBS.
+//
+// The command line is parsed by clang::tooling::CommonOptionsParser. Returns 1
+// if it cannot be parsed, otherwise the status reported by ClangTool::run().
+int main(int argc, const char **argv) {
+
+  llvm::cl::OptionCategory CodeGenCategory("OBS code generation");
+  auto OptionsParser =
+      clang::tooling::CommonOptionsParser::create(argc, argv, CodeGenCategory);
+
+  if (!OptionsParser) {
+    // Fail gracefully for unsupported options. toString() consumes the
+    // llvm::Error; merely streaming it would leave it unchecked, which LLVM
+    // reports as a fatal error in builds with ABI-breaking checks enabled.
+    llvm::errs() << "error: " << llvm::toString(OptionsParser.takeError())
+                 << "\n";
+    return 1;
+  }
+
+  clang::tooling::ClangTool Tool(OptionsParser->getCompilations(),
+                                 OptionsParser->getSourcePathList());
+
+  // ClangTool::run() does not take ownership of the factory, so keep it on
+  // the stack instead of leaking a `new`-allocated one.
+  mlir::obs::CodeGenFrontendActionFactory Factory;
+  return Tool.run(&Factory);
+}
diff --git a/obs/include/CodeGen.h b/obs/include/CodeGen.h
new file mode 100644
index 0000000000000..d36a8f5ff4e9b
--- /dev/null
+++ b/obs/include/CodeGen.h
@@ -0,0 +1,52 @@
+
+#ifndef OBS_CODEGEN_H
+#define OBS_CODEGEN_H
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/SourceManager.h"
+#include <clang/AST/ASTContext.h>
+#include <clang/AST/Decl.h>
+#include <memory>
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Verifier.h"
+
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/MLIRContext.h>
+#include <mlir/IR/OwningOpRef.h>
+
+namespace mlir {
+namespace obs {
+
+/// Walks a clang AST (via RecursiveASTVisitor) with the goal of emitting OBS IR.
+///
+/// NOTE(review): only VisitFunctionDecl is implemented so far, and nothing in
+/// this patch populates `theModule`.
+class MLIRGenImpl : public clang::RecursiveASTVisitor<MLIRGenImpl> {
+public:
+  /// `context` must outlive this object: `builder` keeps a pointer to it.
+  explicit MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
+
+  /// Visitor hook for function declarations (returning false aborts traversal).
+  bool VisitFunctionDecl(clang::FunctionDecl *funcDecl);
+
+private:
+  /// Module that generated IR is meant to go into (currently never set).
+  mlir::ModuleOp theModule;
+  /// Builder bound to the MLIRContext passed to the constructor.
+  mlir::OpBuilder builder;
+};
+
+/// Emit an OBS module for the translation unit `decl`.
+/// NOTE(review): declared here, but codegen/CodeGen.cpp in this patch does
+/// not define it.
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
+                                          clang::TranslationUnitDecl &decl);
+
+} // namespace obs
+} // namespace mlir
+
+#endif // OBS_CODEGEN_H
diff --git a/obs/include/CodeGenAction.h b/obs/include/CodeGenAction.h
new file mode 100644
index 0000000000000..226a24ee78253
--- /dev/null
+++ b/obs/include/CodeGenAction.h
@@ -0,0 +1,52 @@
+
+#ifndef OBS_CODEGENACTION_H
+#define OBS_CODEGENACTION_H
+
+#include "clang/AST/AST.h"
+#include "clang/AST/ASTConsumer.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/PrecompiledPreamble.h"
+#include "clang/Tooling/CommonOptionsParser.h"
+#include "clang/Tooling/Tooling.h"
+#include <iostream>
+#include <memory>
+
+namespace mlir {
+namespace obs {
+
+/// ASTConsumer that runs OBS code generation once a translation unit has been
+/// parsed (see CodeGenConsumer::HandleTranslationUnit).
+class CodeGenConsumer : public clang::ASTConsumer {
+public:
+  CodeGenConsumer() {}
+  void HandleTranslationUnit(clang::ASTContext &context) override;
+};
+
+/// Frontend action that installs a fresh CodeGenConsumer for each input.
+class CodeGenFrontendAction : public clang::ASTFrontendAction {
+protected:
+  std::unique_ptr<clang::ASTConsumer>
+  CreateASTConsumer(clang::CompilerInstance &ci, clang::StringRef) override {
+    return std::make_unique<CodeGenConsumer>();
+  }
+};
+
+/// Factory passed to clang::tooling::ClangTool::run(); every create() call
+/// returns a new CodeGenFrontendAction.
+class CodeGenFrontendActionFactory
+    : public clang::tooling::FrontendActionFactory {
+public:
+  CodeGenFrontendActionFactory() {}
+
+  std::unique_ptr<clang::FrontendAction> create() override {
+    return std::make_unique<CodeGenFrontendAction>();
+  }
+};
+
+} // namespace obs
+} // namespace mlir
+
+#endif // OBS_CODEGENACTION_H
diff --git a/obs/include/Dialect.h b/obs/include/Dialect.h
index a298ed2f086a0..e87c75f08d864 100644
--- a/obs/include/Dialect.h
+++ b/obs/include/Dialect.h
@@ -9,9 +9,65 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "Dialect.h.inc"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include <mlir/IR/BuiltinTypes.h>
+#include <mlir/IR/Types.h>
+
+namespace mlir {
+namespace obs {
+namespace user_types {
+// Storage classes are defined in obs-ir/Dialect.cpp; they are only
+// forward-declared here so the type classes below can name them.
+struct OwnTypeStorage;
+struct RefTypeStorage;
+} // namespace user_types
+} // namespace obs
+} // namespace mlir
 
 #define GET_OP_CLASSES
 #include "Ops.h.inc"
 
 
+
+namespace mlir {
+namespace obs {
+
+/// Type of an owner: a uniqued (resource name, dimensions) pair, printed by
+/// OBSDialect::printType as `Own( <resName>, [<dims>] )`.
+class OwnType : public mlir::Type::TypeBase<OwnType, mlir::Type, user_types::OwnTypeStorage> {
+public:
+    using Base::Base;
+
+    /// Create (or look up) the uniqued `OwnType` for `resName` and `dims`.
+    static OwnType get(mlir::MLIRContext *ctx, StringRef resName, ArrayRef<unsigned int> dims);
+
+    /// Return the owned resource name.
+    StringRef getResName();
+
+    /// Return the dimensions of the owned resource.
+    ArrayRef<unsigned int> getDims();
+
+    static constexpr llvm::StringLiteral name = "obs.OWN";
+};
+
+/// Type of a reference to owner type(s); the array usually holds exactly one.
+class RefType : public mlir::Type::TypeBase<RefType, mlir::Type, user_types::RefTypeStorage> {
+public:
+    using Base::Base;
+
+    /// Create (or look up) the uniqued `RefType` for `ownerType`. The
+    /// MLIRContext is taken from the first element, so it must not be empty.
+    static RefType get(ArrayRef<mlir::Type> ownerType);
+
+    /// Return the owner type(s) this reference refers to.
+    ArrayRef<mlir::Type> getOwnerType();
+
+    static constexpr llvm::StringLiteral name = "obs.REF";
+};
+
+} // namespace obs
+} // namespace mlir
+
+
 #endif //MLIR_DIALECT_TRAITS_H
diff --git a/obs/include/OBSGen.h b/obs/include/OBSGen.h
new file mode 100644
index 0000000000000..740bfcdd7ba1e
--- /dev/null
+++ b/obs/include/OBSGen.h
@@ -0,0 +1,8 @@
+
+#ifndef OBS_OBSGEN_H
+#define OBS_OBSGEN_H
+
+// Umbrella header for the `codegen` driver (codegen/OBSGen.cpp).
+#include "CodeGenAction.h"
+
+#endif // OBS_OBSGEN_H
diff --git a/obs/include/Ops.td b/obs/include/Ops.td
index 9463de72f2614..40c8e6cb330ff 100644
--- a/obs/include/Ops.td
+++ b/obs/include/Ops.td
@@ -11,18 +11,137 @@ def OBS_Dialect: Dialect {
     //The namespace of the Dialect.
     let name = "obs";
 
-    let summary = "A high-level dialect for analyzing and optimzing the Language";
+    let summary = "A dialect for representing abstract memory operations.";
 
     let description = [{
-        The Toy language is a tensor-based language that allows you to define functions, perform some math computation, and
-        print results. This dialect provides a reprentation of the language that is amenable to analysis and optimization.
+        The OBS dialect is an abstraction over memory operations. 
+        It is used to reason about Ownership and Borrowing in a programming language.
+        It contains the operations:
+        (1) OwnOp(type, Dim);
+        (2) RefOp(x)
+        (3) Read(x);
+        (4) Write(x);
+        (5) delete(x); 
+        (6) Function;
     }];
 
     let cppNamespace = "::mlir::obs";
+
+    let useDefaultTypePrinterParser = 1;
 }
 
 class OBS_op<string mnemonic, list<Trait> traits = []> : Op<OBS_Dialect, mnemonic, traits>;
 
+// Constraint satisfied by the C++ OwnType class declared in Dialect.h.
+def OBS_OwnType : DialectType<OBS_Dialect,
+    CPred<"::llvm::isa<OwnType>($_self)">, "OBS own type">;
+// Constraint satisfied by the C++ RefType class declared in Dialect.h.
+def OBS_RefType : DialectType<OBS_Dialect,
+    CPred<"::llvm::isa<RefType>($_self)">, "OBS ref type">;
+def OBS_Type : AnyTypeOf<[OBS_OwnType, OBS_RefType]>; // owner or reference
+
+def OwnOp : OBS_op<"Own", []> {
+
+    let summary = "Allocate a memory resource and bind it to an owner.";
+
+    let description = [{
+        This operation creates a memory resource and binds it to an owner.
+    }];
+
+    // `type` names the resource kind; OwnOp::verify() currently accepts only "vector".
+    let arguments = (ins StrAttr:$type, F64Tensor:$dim);
+    let results = (outs OBS_OwnType);
+
+    let assemblyFormat = [{
+       `(` attr-dict `,` $dim `:` type($dim) `)` `:` type(results)
+    }];
+
+    // Implemented by OwnOp::verify() in obs-ir/Dialect.cpp.
+    let hasVerifier = 1;
+    // NOTE(review): this patch adds no definition for this custom builder.
+    let builders = [
+        OpBuilder<(ins "StringRef":$type, "Value":$dim )>
+    ];
+}
+
+def RefOp : OBS_op<"Ref", []> {
+    let summary = "Create a reference to a resource.";
+    let description = [{
+        Produces a value of `OBS_RefType` that refers to the resource held by
+        `$ownType`, which may itself be an owner or another reference.
+    }];
+
+    // Operand: any OBS value (owner or reference). Result: the new reference.
+    let arguments = (ins OBS_Type:$ownType);
+    let results = (outs OBS_RefType);
+
+    let assemblyFormat = [{
+       `(` attr-dict $ownType `:` type($ownType) `)` `:` type(results)
+    }];
+
+    // Implemented by RefOp::verify() in obs-ir/Dialect.cpp.
+    let hasVerifier = 1;
+    // NOTE(review): this patch adds no definition for this custom builder.
+    let builders = [
+        OpBuilder<(ins "Value":$ownType )>
+    ];
+}
+
+def ReadOp : OBS_op<"Read", []> {
+
+    let summary = "Read the resource through a variable";
+
+    let description = [{
+        Reads the resource accessed through `$ownType` (an owner or a reference).
+    }];
+
+    let arguments = (ins OBS_Type:$ownType);
+    // No results are declared for this op.
+    let assemblyFormat = [{
+       `(` attr-dict $ownType `:` type($ownType) `)`
+    }];
+
+    // Implemented by ReadOp::verify() in obs-ir/Dialect.cpp.
+    let hasVerifier = 1;
+}
+
+def WriteOp : OBS_op<"Write", []> {
+
+    let summary = "Write the resource through a variable";
+
+    let description = [{
+        Writes the resource accessed through `$ownType` (an owner or a reference).
+    }];
+
+    let arguments = (ins OBS_Type:$ownType);
+    // No results are declared for this op.
+    let assemblyFormat = [{
+       `(` attr-dict $ownType `:` type($ownType) `)`
+    }];
+
+    // Implemented by WriteOp::verify() in obs-ir/Dialect.cpp.
+    let hasVerifier = 1;
+}
+
+def DeleteOp : OBS_op<"Delete", []> {
+    let summary = "Delete a resource through its owner.";
+
+    let description = [{
+        Deletes the resource reached through `$ownType`.
+    }];
+
+    // NOTE(review): the summary says "through its owner", but the operand
+    // constraint also admits references, and DeleteOp::verify() is a stub.
+    let arguments = (ins OBS_Type:$ownType);
+
+    let assemblyFormat = [{
+       `(` attr-dict $ownType `:` type($ownType) `)`
+    }];
+
+    let hasVerifier = 1;
+}
+
+
 def ConstantOp : OBS_op<"constant", [Pure]> {
 
     let summary = "constant operation" ;
diff --git a/obs/obs-ir/Dialect.cpp b/obs/obs-ir/Dialect.cpp
index c5a4ef2e5c310..d115c4a887efa 100644
--- a/obs/obs-ir/Dialect.cpp
+++ b/obs/obs-ir/Dialect.cpp
@@ -11,11 +11,21 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/VersionTuple.h"
 #include <algorithm>
+#include <mlir/IR/DialectImplementation.h>
+#include <mlir/IR/MLIRContext.h>
+#include <mlir/IR/TypeRange.h>
+#include <mlir/IR/TypeSupport.h>
+#include <mlir/IR/Types.h>
+#include <mlir/Support/TypeID.h>
 #include <string>
+#include <utility>
+#include <vector>
 
 using namespace mlir;
 using namespace mlir::obs;
@@ -27,8 +37,12 @@ void OBSDialect::initialize() {
     #define GET_OP_LIST
     #include "Ops.cpp.inc"
     >();
+
+    addTypes<OwnType>();
+    addTypes<RefType>();
 }
 
+
 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, mlir::OperationState &result) {
 
     SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
@@ -67,6 +81,7 @@ static void printBinary(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
     printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
 }
 
+
 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, double value) {
     auto dataType = RankedTensorType::get({}, builder.getF64Type());
     auto dataAttribute = DenseElementsAttr::get(dataType, value);
@@ -90,6 +105,34 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
     printer << getValue();
 }
 
+mlir::LogicalResult OwnOp::verify() {
+    // Only "vector" resources are supported so far. Report why verification
+    // fails instead of returning a bare failure() with no diagnostic.
+    if (getType() != "vector")
+        return emitOpError("unsupported resource type '") << getType() << "'; expected 'vector'";
+    return success();
+}
+
+mlir::LogicalResult RefOp::verify() {
+    // TODO: add real checks; until then every RefOp is accepted.
+    return mlir::success();
+}
+
+mlir::LogicalResult WriteOp::verify() {
+    // TODO: add real checks; until then every WriteOp is accepted.
+    return mlir::success();
+}
+
+mlir::LogicalResult ReadOp::verify() {
+    // TODO: add real checks; until then every ReadOp is accepted.
+    return mlir::success();
+}
+
+mlir::LogicalResult DeleteOp::verify() {
+    // TODO: add real checks; until then every DeleteOp is accepted.
+    return mlir::success();
+}
+
 mlir::LogicalResult ConstantOp::verify() {
     auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
     if (!resultType)
@@ -217,6 +260,126 @@ mlir::LogicalResult TransposeOp::verify() {
     return mlir::success();
 }
 
+namespace mlir {
+namespace obs {
+
+namespace user_types {
+/*
+ * Storage for OwnType: a (resource name, dimensions) pair uniqued by MLIR.
+ **/
+struct OwnTypeStorage : public mlir::TypeStorage {
+    // Uniquing key:
+    //   first  - the name of the owned resource;
+    //   second - its dimensions, as a plain list of unsigned values.
+    // construct() copies both into the type allocator, so the stored
+    // StringRef/ArrayRef do not point into the caller's buffers.
+    using KeyTy = std::pair<StringRef, ArrayRef<unsigned int>>;
+
+    OwnTypeStorage(StringRef resName, ArrayRef<unsigned int> dims): resName(resName), dims(dims) {}
+
+    static OwnTypeStorage *construct(mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
+        StringRef resName = allocator.copyInto(key.first);
+        ArrayRef<unsigned int> dim = allocator.copyInto(key.second);
+        return new (allocator.allocate<OwnTypeStorage>()) OwnTypeStorage(resName, dim);
+    }
+
+    static llvm::hash_code hashKey(const KeyTy &key) {
+        return llvm::hash_value(key);
+    }
+
+    static KeyTy getKey(StringRef resName, ArrayRef<unsigned int> dims) {
+        return KeyTy(resName, dims);
+    }
+
+    bool operator==(const KeyTy &key) const {
+        return (key.first == resName && key.second == dims);
+    }
+
+    StringRef resName;
+    ArrayRef<unsigned int> dims;
+};
+
+/*
+ * Storage for RefType: the owner type(s) a reference points to.
+ **/
+struct RefTypeStorage : public mlir::TypeStorage {
+
+    // An array is used as the key, but it usually holds exactly one type.
+    using KeyTy = ArrayRef<mlir::Type>;
+
+    RefTypeStorage(ArrayRef<mlir::Type> ownerType): ownerType(ownerType) {}
+
+    static RefTypeStorage *construct(mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
+        // Copy the key into the type allocator so the storage owns its array.
+        ArrayRef<mlir::Type> ownerType = allocator.copyInto(key);
+        return new (allocator.allocate<RefTypeStorage>()) RefTypeStorage(ownerType);
+    }
+
+    static llvm::hash_code hashKey(const KeyTy &key) {
+        return llvm::hash_value(key);
+    }
+
+    static KeyTy getKey(ArrayRef<mlir::Type> ownerType) {
+        return KeyTy(ownerType);
+    }
+
+    bool operator==(const KeyTy &key) const {
+        return (key == ownerType);
+    }
+
+    // The owner type(s) this reference refers to.
+    ArrayRef<mlir::Type> ownerType;
+};
+
+} // namespace user_types
+} // namespace obs
+} // namespace mlir
+
+OwnType OwnType::get(mlir::MLIRContext *context, StringRef name, ArrayRef<unsigned int> shape) {
+    return Base::get(context, name, shape); // uniqued on (name, shape)
+}
+
+StringRef OwnType::getResName() {
+    return this->getImpl()->resName; // read back from OwnTypeStorage
+}
+
+ArrayRef<unsigned int> OwnType::getDims() {
+    return this->getImpl()->dims; // read back from OwnTypeStorage
+}
+
+RefType RefType::get(ArrayRef<mlir::Type> ownType) {
+    // The context is taken from the first owner type; front() on an empty array is UB.
+    assert(!ownType.empty() && "RefType::get requires at least one owner type");
+    return Base::get(ownType.front().getContext(), ownType);
+}
+ArrayRef<mlir::Type> RefType::getOwnerType() {
+    return getImpl()->ownerType;
+}
+
+// Prints `Own( <resName>, [<dims>] )` or `Ref( <owner types> )`; other types print nothing.
+void OBSDialect::printType(::mlir::Type type,
+                 ::mlir::DialectAsmPrinter &printer) const {
+    if (auto ownType = llvm::dyn_cast<OwnType>(type)) {
+        printer << "Own( ";
+        printer << ownType.getResName();
+        printer << ", [";
+        llvm::interleaveComma(ownType.getDims(), printer);
+        printer << " ] )";
+    } else if (auto refType = llvm::dyn_cast<RefType>(type)) {
+        printer << "Ref( ";
+        printer << refType.getOwnerType() ;
+        printer << " )";
+    }
+}
+
+mlir::Type OBSDialect::parseType(mlir::DialectAsmParser &parser) const {
+    // TODO: parse the `Own( ... )` / `Ref( ... )` syntax produced by printType.
+    // Until then, report a diagnostic rather than returning a null Type silently.
+    parser.emitError(parser.getNameLoc(), "parsing of OBS types is not implemented yet");
+    return Type();
+}
+
+
 #define GET_OP_CLASSES
 #include "Ops.cpp.inc"
 
diff --git a/obs/obs-ir/MLIRGen.cpp b/obs/obs-ir/MLIRGen.cpp
index 388d7e5b5fdb0..286bdf63cb245 100644
--- a/obs/obs-ir/MLIRGen.cpp
+++ b/obs/obs-ir/MLIRGen.cpp
@@ -21,13 +21,18 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include <cassert>
 #include <cstdint>
 #include <functional>
+#include <mlir-c/BuiltinTypes.h>
 #include <mlir/IR/BuiltinAttributes.h>
 #include <mlir/IR/Location.h>
 #include <mlir/IR/OwningOpRef.h>
+#include <mlir/IR/Types.h>
+#include <mlir/Support/LLVM.h>
 #include <numeric>
 #include <optional>
 #include <vector>
@@ -150,12 +155,23 @@ class MLIRGenImpl {
   }
 
   mlir::LogicalResult mlirGen(PrintExprAST &call) {
+    // Lower `print(x)` to an obs.Ref of x; the PrintOp lowering is commented out below.
+    auto arg = mlirGen(*call.getArg());
+    if (!arg) {
+      return mlir::failure();
+    }
+
+    // Build the RefType from the operand's actual type so the two always agree,
+    // whatever owner/reference type `arg` carries.
+    auto refType = mlir::obs::RefType::get({arg.getType()});
+    builder.create<RefOp>(loc(call.loc()), refType, arg);
+    /*
     auto arg = mlirGen(*call.getArg());
     if (!arg) {
         return mlir::failure();
     }
 
-    builder.create<PrintOp>(loc(call.loc()), arg);
+    builder.create<PrintOp>(loc(call.loc()), arg); */
     return mlir::success();
   } 
 
@@ -248,6 +264,7 @@ class MLIRGenImpl {
 
   mlir::Value mlirGen(VariableExprAST &expr) {
     if (auto variable = symbolTable.lookup(expr.getName())) {
+        builder.create<ReadOp>(loc(expr.loc()), variable);
         return variable;
     }
     mlir::emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'";
@@ -293,6 +310,34 @@ class MLIRGenImpl {
   }
 
   mlir::Value mlirGen(VarDeclExprAST &vardecl) {
+
+    auto *init = vardecl.getInitVal();
+
+    if (!init) {
+        mlir::emitError(loc(vardecl.loc()), "missing initializer in variable declaration");
+        return nullptr;
+    }
+
+    mlir::Value value = mlirGen(*init);
+    // Bail out if the initializer failed to lower (presumably it already emitted
+    // a diagnostic); OwnOp must not be built on a null operand.
+    if (!value)
+        return nullptr;
+
+    // Each variable becomes the owner of a newly allocated resource.
+    // NOTE(review): the resource kind and dims are hard-coded ("vector",
+    // {1, 2}) instead of being derived from the declaration.
+    StringRef type = "vector";
+    auto ownType = mlir::obs::OwnType::get(theModule.getContext(), type, {1, 2});
+    value = builder.create<OwnOp>(loc(vardecl.loc()), ownType, type, value);
+
+    if (failed(declare(vardecl.getName(), value))) {
+        return nullptr;
+    }
+
+    return value;
+
+/*
     auto *init = vardecl.getInitVal();
 
     if (!init) {
@@ -312,7 +357,7 @@ class MLIRGenImpl {
     if (failed(declare(vardecl.getName(), value))) {
         return nullptr;
     }
-    return value;
+    return value; */
   }
 
 };
diff --git a/obs/test/codegen.toy b/obs/test/codegen.toy
new file mode 100644
index 0000000000000..05c2d9d1b9f04
--- /dev/null
+++ b/obs/test/codegen.toy
@@ -0,0 +1,5 @@
+# Sample input: declares a 2x3 variable and prints it (presumably exercising the VarDecl/print lowering in obs-ir/MLIRGen.cpp).
+def main() {
+  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  print(a);
+}
diff --git a/obs/test/test1.cpp b/obs/test/test1.cpp
new file mode 100644
index 0000000000000..3d6f2a47fb915
--- /dev/null
+++ b/obs/test/test1.cpp
@@ -0,0 +1,9 @@
+// Sample C++ input, presumably for the clang-based `codegen` tool (codegen/OBSGen.cpp); defines two functions.
+void foo() {
+    
+}
+
+int main(){
+    int x;
+    return 1;
+}
\ No newline at end of file



More information about the Mlir-commits mailing list