[llvm] [mlir] Op definition (PR #93931)

via llvm-commits (llvm-commits at lists.llvm.org)
Fri May 31 00:04:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Shuanglong Kan (ShlKan)

<details>
<summary>Changes</summary>



---

The patch is 94.03 KiB and is truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/93931.diff


27 files affected:

- (modified) llvm/CMakeLists.txt (+1-1) 
- (modified) llvm/tools/CMakeLists.txt (+1) 
- (removed) mlir/test/Examples/Toy/Ch2/codegen.toy (-31) 
- (added) obs/.clang-format (+1) 
- (added) obs/.clang-tidy (+26) 
- (added) obs/CMakeLists.txt (+75) 
- (added) obs/cmake/modules/AddOBS.cmake (+5) 
- (added) obs/cmake/modules/CMakeLists.txt (empty file) 
- (added) obs/codegen/CodeGen.cpp (+55) 
- (added) obs/codegen/CodeGenAction.cpp (+30) 
- (added) obs/codegen/OBSGen.cpp (+27) 
- (added) obs/include/AST.h (+227) 
- (added) obs/include/CMakeLists.txt (+6) 
- (added) obs/include/CodeGen.h (+41) 
- (added) obs/include/CodeGenAction.h (+43) 
- (added) obs/include/Dialect.h (+67) 
- (added) obs/include/Lexer.h (+187) 
- (added) obs/include/MLIRGen.h (+21) 
- (added) obs/include/OBSGen.h (+2) 
- (added) obs/include/Ops.td (+322) 
- (added) obs/include/Parser.h (+452) 
- (added) obs/obs-ir/Dialect.cpp (+388) 
- (added) obs/obs-ir/MLIRGen.cpp (+380) 
- (added) obs/obs-ir/obs.cpp (+135) 
- (added) obs/parser/AST.cpp (+222) 
- (added) obs/test/codegen.toy (+5) 
- (added) obs/test/test1.cpp (+9) 


``````````diff
diff --git a/llvm/CMakeLists.txt b/llvm/CMakeLists.txt
index 6f5647d70d8bc..d3eaa7670eb15 100644
--- a/llvm/CMakeLists.txt
+++ b/llvm/CMakeLists.txt
@@ -114,7 +114,7 @@ endif()
 # LLVM_EXTERNAL_${project}_SOURCE_DIR using LLVM_ALL_PROJECTS
 # This allows an easy way of setting up a build directory for llvm and another
 # one for llvm+clang+... using the same sources.
-set(LLVM_ALL_PROJECTS "bolt;clang;clang-tools-extra;compiler-rt;cross-project-tests;libc;libclc;lld;lldb;mlir;openmp;polly;pstl")
+set(LLVM_ALL_PROJECTS "bolt;clang;clang-tools-extra;compiler-rt;cross-project-tests;libc;libclc;lld;lldb;mlir;obs;openmp;polly;pstl")
 # The flang project is not yet part of "all" projects (see C++ requirements)
 set(LLVM_EXTRA_PROJECTS "flang")
 # List of all known projects in the mono repo
diff --git a/llvm/tools/CMakeLists.txt b/llvm/tools/CMakeLists.txt
index c6116ac81d12b..ac1728854a765 100644
--- a/llvm/tools/CMakeLists.txt
+++ b/llvm/tools/CMakeLists.txt
@@ -41,6 +41,7 @@ add_llvm_external_project(clang)
 add_llvm_external_project(lld)
 add_llvm_external_project(lldb)
 add_llvm_external_project(mlir)
+add_llvm_external_project(obs)
 # Flang depends on mlir, so place it afterward
 add_llvm_external_project(flang)
 add_llvm_external_project(bolt)
diff --git a/mlir/test/Examples/Toy/Ch2/codegen.toy b/mlir/test/Examples/Toy/Ch2/codegen.toy
deleted file mode 100644
index 12178d6afb309..0000000000000
--- a/mlir/test/Examples/Toy/Ch2/codegen.toy
+++ /dev/null
@@ -1,31 +0,0 @@
-# RUN: toyc-ch2 %s -emit=mlir 2>&1 | FileCheck %s
-
-# User defined generic function that operates on unknown shaped arguments
-def multiply_transpose(a, b) {
-  return transpose(a) * transpose(b);
-}
-
-def main() {
-  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
-  var b<2, 3> = [1, 2, 3, 4, 5, 6];
-  var c = multiply_transpose(a, b);
-  var d = multiply_transpose(b, a);
-  print(d);
-}
-
-# CHECK-LABEL: toy.func @multiply_transpose(
-# CHECK-SAME:                               [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
-# CHECK:         [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
-# CHECK-NEXT:    [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
-# CHECK-NEXT:    [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] :  tensor<*xf64>
-# CHECK-NEXT:    toy.return [[VAL_4]] : tensor<*xf64>
-
-# CHECK-LABEL: toy.func @main()
-# CHECK-NEXT:    [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
-# CHECK-NEXT:    [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
-# CHECK-NEXT:    [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
-# CHECK-NEXT:    [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64>
-# CHECK-NEXT:    [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
-# CHECK-NEXT:    [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
-# CHECK-NEXT:    toy.print [[VAL_10]] : tensor<*xf64>
-# CHECK-NEXT:    toy.return
diff --git a/obs/.clang-format b/obs/.clang-format
new file mode 100644
index 0000000000000..468ceea01f1a0
--- /dev/null
+++ b/obs/.clang-format
@@ -0,0 +1 @@
+BasedOnStyle: LLVM
\ No newline at end of file
diff --git a/obs/.clang-tidy b/obs/.clang-tidy
new file mode 100644
index 0000000000000..4b7b8e9479f29
--- /dev/null
+++ b/obs/.clang-tidy
@@ -0,0 +1,26 @@
+Checks: '-*,clang-diagnostic-*,llvm-*,misc-*,-misc-const-correctness,-misc-unused-parameters,-misc-non-private-member-variables-in-classes,-misc-no-recursion,-misc-use-anonymous-namespace,readability-identifier-naming,-misc-include-cleaner'
+CheckOptions:
+  - key:             readability-identifier-naming.ClassCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.EnumCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.FunctionCase
+    value:           camelBack
+  # Exclude from scanning as this is an exported symbol used for fuzzing
+  # throughout the code base.
+  - key:             readability-identifier-naming.FunctionIgnoredRegexp
+    value:           "LLVMFuzzerTestOneInput"
+  - key:             readability-identifier-naming.MemberCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.ParameterCase
+    value:           camelCase
+  - key:             readability-identifier-naming.UnionCase
+    value:           CamelCase
+  - key:             readability-identifier-naming.VariableCase
+    value:           camelCase
+  - key:             readability-identifier-naming.IgnoreMainLikeFunctions
+    value:           1
+  - key:             readability-redundant-member-init.IgnoreBaseInCopyConstructors
+    value:           1
+  - key:             modernize-use-default-member-init.UseAssignment
+    value:           1
diff --git a/obs/CMakeLists.txt b/obs/CMakeLists.txt
new file mode 100644
index 0000000000000..5dd9548f6a956
--- /dev/null
+++ b/obs/CMakeLists.txt
@@ -0,0 +1,75 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+  
+include(CMakeDependentOption)
+include(GNUInstallDirs)
+
+set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include/ ) # --includedir
+set(CLANG_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../clang/include/ ) # --includedir
+set(CLANG_INCLUDE_DIR_INC ${CMAKE_BINARY_DIR}/tools/clang/include/ ) # --includedir
+set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include/)
+include_directories(SYSTEM ${CLANG_INCLUDE_DIR})
+include_directories(SYSTEM ${CLANG_INCLUDE_DIR_INC})
+include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
+include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
+
+
+include_directories(SYSTEM ${CMAKE_CURRENT_BINARY_DIR}/include)
+add_subdirectory(include)
+
+# Make sure that our source directory is on the current cmake module path so
+# that we can include cmake files from this directory.
+list(INSERT CMAKE_MODULE_PATH 0
+  "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules"
+  "${LLVM_COMMON_CMAKE_UTILS}/Modules"
+  )
+
+include(CMakeParseArguments)
+include(AddOBS)
+
+include_directories(include/)
+#add_obs_executable(obstest 
+#    obs-ir/obs.cpp 
+#    obs-ir/Dialect.cpp
+#    obs-ir/MLIRGen.cpp
+#    parser/AST.cpp
+#  DEPENDS
+#    OBSOpGen
+#)
+
+#target_link_libraries(obstest
+#  PRIVATE
+#  MLIRAnalysis
+#  MLIRFunctionInterfaces
+#  MLIRIR
+#  MLIRParser
+#  MLIRSideEffectInterfaces
+#  MLIRTransforms
+#)
+
+
+add_obs_executable(codegen
+    codegen/OBSGen.cpp
+    codegen/CodeGenAction.cpp
+    codegen/CodeGen.cpp
+  DEPENDS
+    OBSOpGen
+)
+
+target_link_libraries(codegen
+  PRIVATE
+  clangAST
+  clangBasic
+  clangFrontend
+  clangSerialization
+  clangTooling
+  MLIRAnalysis
+  MLIRFunctionInterfaces
+  MLIRIR
+  MLIRParser
+  MLIRSideEffectInterfaces
+  MLIRTransforms
+)
+
+
diff --git a/obs/cmake/modules/AddOBS.cmake b/obs/cmake/modules/AddOBS.cmake
new file mode 100644
index 0000000000000..5c7420aad0312
--- /dev/null
+++ b/obs/cmake/modules/AddOBS.cmake
@@ -0,0 +1,5 @@
+macro(add_obs_executable name)
+  add_llvm_executable( ${name} ${ARGN} )
+  set_target_properties(${name} PROPERTIES FOLDER "OBS executables")
+endmacro(add_obs_executable)
+
diff --git a/obs/cmake/modules/CMakeLists.txt b/obs/cmake/modules/CMakeLists.txt
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/obs/codegen/CodeGen.cpp b/obs/codegen/CodeGen.cpp
new file mode 100644
index 0000000000000..d50de0431cc3b
--- /dev/null
+++ b/obs/codegen/CodeGen.cpp
@@ -0,0 +1,55 @@
+
+#include "MLIRGen.h"
+
+#include "AST.h"
+#include "Dialect.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cassert>
+#include <cstdint>
+#include <functional>
+
+#include <mlir/IR/BuiltinAttributes.h>
+#include <mlir/IR/Location.h>
+#include <mlir/IR/OwningOpRef.h>
+
+#include <iostream>
+#include <numeric>
+#include <optional>
+#include <vector>
+
+using llvm::ArrayRef;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+#include "CodeGen.h"
+
+namespace mlir {
+namespace obs {
+
+bool MLIRGenImpl::VisitFunctionDecl(clang::FunctionDecl *funcDecl) {
+  llvm::outs() << "VisitFunctionDecl: ";
+  funcDecl->getDeclName().dump();
+  llvm::outs() << "\n";
+  return false;
+}
+
+} // namespace obs
+} // namespace mlir
diff --git a/obs/codegen/CodeGenAction.cpp b/obs/codegen/CodeGenAction.cpp
new file mode 100644
index 0000000000000..94fa18dcd188d
--- /dev/null
+++ b/obs/codegen/CodeGenAction.cpp
@@ -0,0 +1,30 @@
+
+
+#include "CodeGenAction.h"
+#include "CodeGen.h"
+
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <clang/AST/ASTContext.h>
+#include <clang/AST/Decl.h>
+#include <clang/AST/DeclGroup.h>
+
+#include <iostream>
+#include <mlir/IR/MLIRContext.h>
+#include <ostream>
+
+using namespace clang;
+
+namespace mlir {
+namespace obs {
+
+void CodeGenConsumer::HandleTranslationUnit(ASTContext &context) {
+  llvm::outs() << "Enter HandleTranslationUnit\n";
+  MLIRContext codegenContext;
+  MLIRGenImpl mlirGen(codegenContext);
+  mlirGen.TraverseDecl(context.getTranslationUnitDecl());
+}
+
+} // namespace obs
+} // namespace mlir
\ No newline at end of file
diff --git a/obs/codegen/OBSGen.cpp b/obs/codegen/OBSGen.cpp
new file mode 100644
index 0000000000000..7cd866aa124ca
--- /dev/null
+++ b/obs/codegen/OBSGen.cpp
@@ -0,0 +1,27 @@
+
+#include "llvm/Support/CommandLine.h"
+#include <iostream>
+#include <vector>
+
+#include "OBSGen.h"
+
+// `main` function translates a C program into a OBS.
+int main(int argc, const char **argv) {
+
+  llvm::cl::OptionCategory CodeGenCategory("OBS code generation");
+  auto OptionsParser =
+      clang::tooling::CommonOptionsParser::create(argc, argv, CodeGenCategory);
+
+  if (!OptionsParser) {
+    // Fail gracefully for unsupported options.
+    std::cout << "error ----" << std::endl;
+    llvm::errs() << OptionsParser.takeError() << "error";
+    return 1;
+  }
+
+  auto sources = OptionsParser->getSourcePathList();
+
+  clang::tooling::ClangTool Tool(OptionsParser->getCompilations(), sources);
+
+  Tool.run(new mlir::obs::CodeGenFrontendActionFactory());
+}
diff --git a/obs/include/AST.h b/obs/include/AST.h
new file mode 100644
index 0000000000000..46d3f876143b2
--- /dev/null
+++ b/obs/include/AST.h
@@ -0,0 +1,227 @@
+#ifndef OBS_AST_H
+#define OBS_AST_H
+
+#include "Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <utility>
+#include <vector>
+#include <optional>
+
+namespace obs {
+
+struct VarType {
+    std::vector<int64_t> shape ;
+};
+
+class ExprAST {
+public:
+    enum ExprASTKind {
+        Expr_VarDecl,
+        Expr_Return,
+        Expr_Num,
+        Expr_Literal,
+        Expr_Var,
+        Expr_BinOp,
+        Expr_Call,
+        Expr_Print,
+    };
+    ExprAST(ExprASTKind kind, Location location): kind(kind), location(std::move(location)) {}
+
+    ExprASTKind getKind() const { return kind; }
+
+    const Location &loc() {
+        return location;
+    }
+
+    ~ExprAST() = default; 
+
+private:
+    const ExprASTKind kind;
+    Location location;
+};
+
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+class NumberExprAST: public ExprAST {
+    double val;
+public:
+    NumberExprAST(Location loc, double val): ExprAST(Expr_Num, loc), val(val) { }
+    double getValue() {
+        return val;
+    }
+    static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
+};
+
+class LiteralExprAST : public ExprAST {
+    std::vector<std::unique_ptr<ExprAST>> values;
+    std::vector<int64_t> dims;
+public:
+    LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, std::vector<int64_t> dims ):
+    ExprAST(Expr_Literal, loc), values(std::move(values)), dims(std::move(dims)) { }
+
+    llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() {
+        return values;
+    }
+
+    llvm::ArrayRef<int64_t> getDims() {
+        return dims;
+    }
+
+    static bool classof(const ExprAST * c) {
+        return c->getKind() == Expr_Literal;
+    }
+
+};
+
+class VariableExprAST : public ExprAST {
+    std::string name;
+
+public:
+    VariableExprAST(Location loc, llvm::StringRef name): ExprAST(Expr_Var, std::move(loc)), name(name) { }
+    llvm::StringRef getName() {
+        return name;
+    }
+
+    static bool classof(const ExprAST *c) {
+        return c->getKind() == Expr_Var;
+    }
+};
+
+class VarDeclExprAST : public ExprAST {
+    std::string name;
+    VarType type;
+    std::unique_ptr<ExprAST> initVal;
+
+public:
+    VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, std::unique_ptr<ExprAST>initVal ):
+    ExprAST(Expr_VarDecl, loc), name(name), type(type), initVal(std::move(initVal)) {}
+
+    llvm::StringRef getName() { return name; }
+    VarType getType() { return type; }
+    ExprAST * getInitVal() { return initVal.get(); }
+
+    static bool classof(const ExprAST * c) {
+        return c -> getKind() == Expr_VarDecl;
+    }
+};
+
+class ReturnExprAST : public ExprAST {
+    std::optional<std::unique_ptr<ExprAST>> expr;
+public:
+    ReturnExprAST(Location loc, std::optional<std::unique_ptr<ExprAST>> expr): 
+        ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+    std::optional<ExprAST *> getExpr() {
+        if (expr.has_value()) {
+            return expr->get();
+        }
+        return std::nullopt;
+    }
+
+    static bool classof(const ExprAST * c) {
+        return c->getKind() == Expr_Return;
+    }
+};
+
+class BinaryExprAST : public ExprAST {
+    char op;
+    std::unique_ptr<ExprAST> lhs, rhs;
+public:
+    char getOp() { return op; }
+    ExprAST * getLHS() { return lhs.get(); }
+    ExprAST * getRHS() { return rhs.get(); }
+    BinaryExprAST(Location location, char op, std::unique_ptr<ExprAST> lhs, std::unique_ptr<ExprAST> rhs):
+    ExprAST(Expr_BinOp, location), op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {}
+
+    static bool classof(const ExprAST * c) {
+        return c->getKind() == Expr_BinOp;
+    }
+};
+
+class CallExprAST : public ExprAST {
+    std::string callee;
+    std::vector<std::unique_ptr<ExprAST>> args;
+
+public:
+    CallExprAST(Location loc, std::string callee, std::vector<std::unique_ptr<ExprAST>> args):
+    ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
+
+    llvm::StringRef getCallee() {
+        return callee;
+    }
+    llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() {
+        return args;
+    }
+
+    static bool classof(const ExprAST *c) {
+        return c->getKind() == Expr_Call;
+    }
+
+};
+
+/// Expression class for builtin print calls.                                                                                                                         
+class PrintExprAST : public ExprAST {
+  std::unique_ptr<ExprAST> arg;
+
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
+      : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {}
+
+  ExprAST *getArg() { return arg.get(); }
+
+  /// LLVM style RTTI                                                                                                                                                 
+  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
+};
+
+class PrototypeAST {
+  Location location;
+  std::string name;
+  std::vector<std::unique_ptr<VariableExprAST>> args;
+
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : location(std::move(location)), name(name), args(std::move(args)) {}
+
+  const Location &loc() { return location; }
+  llvm::StringRef getName() const { return name; }
+  llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
+};
+
+/// This class represents a function definition itself.                                                                                                               
+class FunctionAST {
+  std::unique_ptr<PrototypeAST> proto;
+  std::unique_ptr<ExprASTList> body;
+
+public:
+  FunctionAST(std::unique_ptr<PrototypeAST> proto,
+              std::unique_ptr<ExprASTList> body)
+      : proto(std::move(proto)), body(std::move(body)) {}
+  PrototypeAST *getProto() { return proto.get(); }
+  ExprASTList *getBody() { return body.get(); }
+};
+
+/// This class represents a list of functions to be processed together                                                                                                
+class ModuleAST {
+  std::vector<FunctionAST> functions;
+
+public:
+  ModuleAST(std::vector<FunctionAST> functions)
+      : functions(std::move(functions)) {}
+
+  auto begin() { return functions.begin(); }
+  auto end() { return functions.end(); }
+};
+
+void dump(ModuleAST &);
+
+
+}
+
+
+#endif //OBS_AST_H
+
+
diff --git a/obs/include/CMakeLists.txt b/obs/include/CMakeLists.txt
new file mode 100644
index 0000000000000..83793c91c12e2
--- /dev/null
+++ b/obs/include/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls)
+mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
+add_public_tablegen_target(OBSOpGen)
\ No newline at end of file
diff --git a/obs/include/CodeGen.h b/obs/include/CodeGen.h
new file mode 100644
index 0000000000000..d36a8f5ff4e9b
--- /dev/null
+++ b/obs/include/CodeGen.h
@@ -0,0 +1,41 @@
+
+#ifndef _OBS_CODEGEN_H
+#define _OBS_CODEGEN_H
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/SourceManager.h"
+#include <clang/AST/ASTContext.h>
+#include <clang/AST/Decl.h>
+#include <memory>
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Verifier.h"
+
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/MLIRContext.h>
+#include <mlir/IR/OwningOpRef.h>
+
+namespace mlir {
+namespace obs {
+
+class MLIRGenImpl : public clang::RecursiveASTVisitor<MLIRGenImpl> {
+public:
+  MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
+
+  bool VisitFunctionDecl(clang::FunctionDecl *funcDecl);
+
+private:
+  mlir::ModuleOp theModule;
+  mlir::OpBuilder builder;
+};
+
+mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
+                                          clang::TranslationUnitDecl &decl);
+
+} // namespace obs
+} // namespace mlir
+
+#endif
\ No newline at end of file
diff --git a/obs/include/CodeGenAction.h b/obs/include/CodeGenAction.h
new file mode 100644
index 0000000000000..226a24ee78253
--- /dev/null
+++ b/obs/include/CodeGenAction.h
@@ -0,0 +1,43 @@
+
+
+#include "clang/AST/AST.h"
+#include "clang/AST/ASTConsumer.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/PrecompiledPreamble.h"
+#include "clang/Tooling/CommonOptionsParser.h"
+#include "clang/Tooling/Tooling.h"
+#include <iostream>
+#include <memory>
+
+namespace mlir {
+namespace obs {
+
+class CodeGenConsumer : public clang::ASTConsumer {
+public:
+  CodeGenConsumer() {}
+  void HandleTranslationUnit(clang::ASTContext &context) override;
+};
+
+class CodeGenFrontendAction : public clang::ASTFrontendAction {
+protected:
+  std::unique_ptr<clang::ASTConsumer>
+  CreateASTConsumer(clang::CompilerInstance &ci, clang::StringRef) override {
+    return std::make_unique<CodeGenConsumer>();
+  }
+};
+
+class CodeGenFrontendActionFactory
+    : public clang::tooling::FrontendActionFactory {
+public:
+  CodeGenFrontendActionFactory() {}
+
+  std::unique_ptr<clang::FrontendAction> create() override {
+    return std::make_unique<CodeGenFrontendAction>()...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/93931


More information about the llvm-commits mailing list