[mlir] [llvm] [mlir][EmitC] Add func, call and return operations and conversions (PR #79612)

Marius Brehler via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 28 23:01:52 PST 2024


https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/79612

>From a03e89ce76f265fbe548a8d2853468b94c7355cd Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Tue, 23 Jan 2024 14:54:15 +0000
Subject: [PATCH 1/3] [mlir][EmitC] Add func, call and return operations

This adds `func`, `call` and `return` operations to the EmitC dialect,
closely modeled on the corresponding operations of the Func dialect. In
contrast to the Func dialect operations, the EmitC operations do not
support multiple results. The `emitc.func` op features a `specifiers`
attribute which, with corresponding support in the emitter, allows
emitting e.g. `inline static` functions.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |   1 +
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 186 ++++++++++++++++++
 mlir/lib/Dialect/EmitC/IR/CMakeLists.txt      |   2 +
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 182 +++++++++++++++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 142 ++++++++++---
 mlir/test/Dialect/EmitC/invalid_ops.mlir      |  32 +++
 mlir/test/Dialect/EmitC/ops.mlir              |  15 ++
 mlir/test/Target/Cpp/func.mlir                |  39 ++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |   4 +
 9 files changed, 574 insertions(+), 29 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/func.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 4dff26e23c4285..3d38744527d599 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index b8f8f1e2d818d5..df0418a2ac372a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -16,8 +16,10 @@
 include "mlir/Dialect/EmitC/IR/EmitCAttributes.td"
 include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
 
+include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/RegionKindInterface.td"
 
@@ -386,6 +388,190 @@ def EmitC_ForOp : EmitC_Op<"for",
   let hasRegionVerifier = 1;
 }
 
+def EmitC_CallOp : EmitC_Op<"call",
+    [CallOpInterface,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "call operation";
+  let description = [{
+    The `emitc.call` operation represents a direct call to an `emitc.func`
+    that is within the same symbol scope as the call. The operands and result type
+    of the call must match the specified function type. The callee is encoded as a
+    symbol reference attribute named "callee".
+
+    Example:
+
+    ```mlir
+    %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
+    ```
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>); // Callee emitc.func has <= 1 result.
+
+  let builders = [
+    OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", SymbolRefAttr::get(callee));
+      $_state.addTypes(callee.getFunctionType().getResults());
+    }]>,
+    OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", callee);
+      $_state.addTypes(results);
+    }]>,
+    OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, SymbolRefAttr::get(callee), results, operands);
+    }]>,
+    OpBuilder<(ins "StringRef":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
+            results, operands);
+    }]>];
+
+  let extraClassDeclaration = [{
+    FunctionType getCalleeType(); ///< Built from operand and result types.
+
+    /// Get the argument operands to the called function (all operands).
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+    /// Mutable view of the argument operands; identical to all operands.
+    MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+    /// Argument-operand iterators; they span every operand of the op.
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+
+    /// Return the callee, read from the "callee" symbol reference attribute.
+    CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+    }
+
+    /// Set the "callee" attribute; `callee` must hold a SymbolRefAttr.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
+  }];
+
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+}
+
+def EmitC_FuncOp : EmitC_Op<"func", [
+  AutomaticAllocationScope,
+  FunctionOpInterface, IsolatedFromAbove
+]> {
+  let summary = "An operation with a name containing a single `SSACFG` region";
+  let description = [{
+    Operations within the function cannot implicitly capture values defined
+    outside of the function, i.e. functions are `IsolatedFromAbove`. All
+    external references must use function arguments or attributes that establish
+    a symbolic connection (e.g. symbols referenced by name via a string
+    attribute like SymbolRefAttr). While the MLIR textual form provides a nice
+    inline syntax for function arguments, they are internally represented as
+    “block arguments” to the first block in the region.
+
+    Besides inherent attributes such as `specifiers`, only dialect attribute
+    names may be given for function arguments, results, or the function itself.
+
+    Example:
+
+    ```mlir
+    // A function with no results:
+    emitc.func @foo(%arg0 : i32) {
+      emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+      emitc.return
+    }
+
+    // A function with its argument as single result:
+    emitc.func @identity(%arg0 : i32) -> i32 {
+      emitc.return %arg0 : i32
+    }
+
+    // A function with the specifiers attribute:
+    emitc.func @example_specifiers_fn_attr() -> i32
+                attributes {specifiers = ["static","inline"]} {
+      %0 = emitc.call_opaque "foo" (): () -> i32
+      emitc.return %0 : i32
+    }
+
+    ```
+  }];
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<StrArrayAttr>:$specifiers,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+    CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Create a deep copy of this function and all of its blocks, remapping any
+    /// operands that use values outside of the function using the map that is
+    /// provided (leaving them alone if no entry is present). If the mapper
+    /// contains entries for function arguments, these arguments are not
+    /// included in the new function. Replaces references to cloned sub-values
+    /// with the corresponding value that is copied, and adds those mappings to
+    /// the mapper.
+    FuncOp clone(IRMapping &mapper);
+    FuncOp clone();
+
+    /// Clone the internal blocks and attributes from this function into dest.
+    /// Any cloned blocks are appended to the back of dest. This function
+    /// asserts that the attributes of the current function and dest are
+    /// compatible.
+    void cloneInto(FuncOp dest, IRMapping &mapper);
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
+                                ReturnLike, Terminator]> {
+  let summary = "Function return operation";
+  let description = [{
+    The `emitc.return` operation represents a return operation within a function.
+    The operation takes zero or exactly one operand and produces no results.
+    The operand number and type must match the signature of the function
+    that contains the operation.
+
+    Example:
+
+    ```mlir
+    emitc.func @foo() -> i32 {
+      ...
+      emitc.return %0 : i32
+    }
+    ```
+  }];
+  let arguments = (ins Optional<AnyType>:$operand);
+
+  let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?";
+  let hasVerifier = 1;
+}
+
 def EmitC_IncludeOp
     : EmitC_Op<"include", [HasParent<"ModuleOp">]> {
   let summary = "Include operation";
diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
index 4665c41a62e80b..4cc54201d2745d 100644
--- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
@@ -9,8 +9,10 @@ add_mlir_dialect_library(MLIREmitCDialect
   MLIREmitCAttributesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRCallInterfaces
   MLIRCastInterfaces
   MLIRControlFlowInterfaces
+  MLIRFunctionInterfaces
   MLIRIR
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 5f502f1f7a1714..ef67764e732d22 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -8,7 +8,10 @@
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -346,6 +349,185 @@ LogicalResult ForOp::verifyRegions() {
 
   return success();
 }
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+/// Verify that "callee" names an emitc.func whose signature matches this call.
+LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
+  if (!fn)
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+
+  // Verify that the operand and result types match the callee.
+  auto fnType = fn.getFunctionType();
+  if (fnType.getNumInputs() != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+    if (getOperand(i).getType() != fnType.getInput(i))
+      return emitOpError("operand type mismatch: expected operand type ")
+             << fnType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+
+  if (fnType.getNumResults() != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+    if (getResult(i).getType() != fnType.getResult(i)) {
+      auto diag = emitOpError("result type mismatch at index ") << i;
+      diag.attachNote() << "      op result types: " << getResultTypes();
+      diag.attachNote() << "function result types: " << fnType.getResults();
+      return diag;
+    }
+
+  return success();
+}
+/// Return the function type implied by this call's operand and result types.
+FunctionType CallOp::getCalleeType() {
+  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+/// Build an emitc.func; argAttrs, if non-empty, needs one entry per input.
+void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs,
+                   ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addRegion();
+
+  if (argAttrs.empty())
+    return;
+  assert(type.getNumInputs() == argAttrs.size());
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
+}
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+  // Variadic (`...`) argument lists are not accepted for emitc.func.
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
+}
+
+/// Clone the blocks of this function into dest and merge this function's
+/// attributes into dest; attributes already present on dest take precedence.
+void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
+  // Merge attributes; dest's entries go in first and are never overwritten.
+  llvm::MapVector<StringAttr, Attribute> newAttrMap;
+  for (const auto &attr : dest->getAttrs())
+    newAttrMap.insert({attr.getName(), attr.getValue()});
+  for (const auto &attr : (*this)->getAttrs())
+    newAttrMap.insert({attr.getName(), attr.getValue()});
+
+  auto newAttrs = llvm::to_vector(llvm::map_range(
+      newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
+        return NamedAttribute(attrPair.first, attrPair.second);
+      }));
+  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
+
+  // Clone the body.
+  getBody().cloneInto(&dest.getBody(), mapper);
+}
+
+/// Create a deep copy of this function and all of its blocks, remapping
+/// any operands that use values outside of the function using the map that is
+/// provided (leaving them alone if no entry is present). Replaces references
+/// to cloned sub-values with the corresponding value that is copied, and adds
+/// those mappings to the mapper.
+FuncOp FuncOp::clone(IRMapping &mapper) {
+  // Create the new function.
+  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
+
+  // If the function has a body, then the user might be deleting arguments to
+  // the function by specifying them in the mapper. If so, we don't add the
+  // argument to the input type vector.
+  if (!isExternal()) {
+    FunctionType oldType = getFunctionType();
+
+    unsigned oldNumArgs = oldType.getNumInputs();
+    SmallVector<Type, 4> newInputs;
+    newInputs.reserve(oldNumArgs);
+    for (unsigned i = 0; i != oldNumArgs; ++i)
+      if (!mapper.contains(getArgument(i)))
+        newInputs.push_back(oldType.getInput(i));
+
+    // If any of the arguments were dropped, update the type and drop any
+    // necessary argument attributes.
+    if (newInputs.size() != oldNumArgs) {
+      newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
+                                        oldType.getResults()));
+
+      if (ArrayAttr argAttrs = getAllArgAttrs()) {
+        SmallVector<Attribute> newArgAttrs;
+        newArgAttrs.reserve(newInputs.size());
+        for (unsigned i = 0; i != oldNumArgs; ++i)
+          if (!mapper.contains(getArgument(i)))
+            newArgAttrs.push_back(argAttrs[i]);
+        newFunc.setAllArgAttrs(newArgAttrs);
+      }
+    }
+  }
+
+  // Clone the current function into the new one and return it.
+  cloneInto(newFunc, mapper);
+  return newFunc;
+}
+FuncOp FuncOp::clone() {
+  IRMapping mapper;
+  return clone(mapper);
+}
+
+LogicalResult FuncOp::verify() {
+  if (getNumResults() <= 1) // EmitC functions have at most one result.
+    return success();
+  return emitOpError("requires zero or exactly one result, but has ")
+         << getNumResults();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+/// Check the returned value against the enclosing emitc.func's result type.
+LogicalResult ReturnOp::verify() {
+  auto function = cast<FuncOp>((*this)->getParentOp());
+  // The operand number and types must match the function signature.
+  if (getNumOperands() != function.getNumResults())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing function (@"
+           << function.getName() << ") returns " << function.getNumResults();
+
+  // Use emitOpError here too, so both diagnostics name the op consistently.
+  if (function.getNumResults() == 1 &&
+      getOperand().getType() != function.getResultTypes()[0])
+    return emitOpError() << "type of the return operand ("
+                         << getOperand().getType()
+                         << ") doesn't match function result type ("
+                         << function.getResultTypes()[0] << ")"
+                         << " in function @" << function.getName();
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // IfOp
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index c32cb03caf9db6..9fe16f5b11df65 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -495,18 +495,33 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
-  if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
+static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
+                                        StringRef callee) {
+  if (failed(emitter.emitAssignPrefix(*callOp)))
     return failure();
 
   raw_ostream &os = emitter.ostream();
-  os << callOp.getCallee() << "(";
-  if (failed(emitter.emitOperands(*callOp.getOperation())))
+  os << callee << "(";
+  if (failed(emitter.emitOperands(*callOp)))
     return failure();
   os << ")";
   return success();
 }
 
+/// Print a `func.call` as a direct `callee(operands)` call.
+static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
+  Operation *operation = callOp.getOperation();
+  StringRef callee = callOp.getCallee();
+  return printCallOperation(emitter, operation, callee);
+}
+
+/// Print an `emitc.call` as a direct `callee(operands)` call.
+static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
+  Operation *operation = callOp.getOperation();
+  StringRef callee = callOp.getCallee();
+  return printCallOperation(emitter, operation, callee);
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::CallOpaqueOp callOpaqueOp) {
   raw_ostream &os = emitter.ostream();
@@ -724,6 +739,19 @@ static LogicalResult printOperation(CppEmitter &emitter,
   }
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::ReturnOp returnOp) {
+  raw_ostream &os = emitter.ostream();
+  os << "return";
+  // A void return carries no operand; emit a bare `return`.
+  Value operand = returnOp.getOperand();
+  if (!operand)
+    return success();
+
+  os << " ";
+  return emitter.emitOperand(operand);
+}
+
 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
   CppEmitter::Scope scope(emitter);
 
@@ -734,39 +762,34 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    func::FuncOp functionOp) {
-  // We need to declare variables at top if the function has multiple blocks.
-  if (!emitter.shouldDeclareVariablesAtTop() &&
-      functionOp.getBlocks().size() > 1) {
-    return functionOp.emitOpError(
-        "with multiple blocks needs variables declared at top");
-  }
-
-  CppEmitter::Scope scope(emitter);
+static LogicalResult printFunctionArgs(CppEmitter &emitter,
+                                       Operation *functionOp,
+                                       Region::BlockArgListType arguments) {
   raw_indented_ostream &os = emitter.ostream();
-  if (failed(emitter.emitTypes(functionOp.getLoc(),
-                               functionOp.getFunctionType().getResults())))
-    return failure();
-  os << " " << functionOp.getName();
 
-  os << "(";
   if (failed(interleaveCommaWithError(
-          functionOp.getArguments(), os,
-          [&](BlockArgument arg) -> LogicalResult {
-            if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
+          arguments, os, [&](BlockArgument arg) -> LogicalResult {
+            if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
               return failure();
             os << " " << emitter.getOrCreateName(arg);
             return success();
           })))
     return failure();
-  os << ") {\n";
+  // Arguments were printed as a comma-separated "<type> <name>" list.
+  return success();
+}
+
+static LogicalResult printFunctionBody(CppEmitter &emitter,
+                                       Operation *functionOp,
+                                       Region::BlockListType &blocks) {
+  raw_indented_ostream &os = emitter.ostream();
   os.indent();
+
   if (emitter.shouldDeclareVariablesAtTop()) {
     // Declare all variables that hold op results including those from nested
     // regions.
     WalkResult result =
-        functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
+        functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
           if (isa<emitc::LiteralOp>(op) ||
               isa<emitc::ExpressionOp>(op->getParentOp()) ||
               (isa<emitc::ExpressionOp>(op) &&
@@ -785,7 +808,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
       return failure();
   }
 
-  Region::BlockListType &blocks = functionOp.getBlocks();
   // Create label names for basic blocks.
   for (Block &block : blocks) {
     emitter.getOrCreateName(block);
@@ -795,7 +817,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
   for (Block &block : llvm::drop_begin(blocks)) {
     for (BlockArgument &arg : block.getArguments()) {
       if (emitter.hasValueInScope(arg))
-        return functionOp.emitOpError(" block argument #")
+        return functionOp->emitOpError(" block argument #")
                << arg.getArgNumber() << " is out of scope";
       if (failed(
               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
@@ -825,7 +847,68 @@ static LogicalResult printOperation(CppEmitter &emitter,
         return failure();
     }
   }
+  return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    func::FuncOp functionOp) {
+  // Multi-block bodies are only supported when variables are declared at top.
+  if (!emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1) {
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+  }
+  // Emit "<result types> <name>(<args>) {", then the body and closing brace.
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName();
+
+  os << "(";
+  Operation *operation = functionOp.getOperation();
+  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+    return failure();
+  os << ") {\n";
+  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    return failure();
   os.unindent() << "}\n";
+
+  return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::FuncOp functionOp) {
+  // We need to declare variables at top if the function has multiple blocks.
+  if (!emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1) {
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+  }
+  // A body-less function has no named arguments or body to print.
+  if (functionOp.isExternal())
+    return functionOp.emitOpError("without a body cannot be emitted");
+
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+  // Print specifiers such as `static` or `inline` ahead of the return type.
+  if (ArrayAttr specifiers = functionOp.getSpecifiersAttr())
+    for (Attribute specifier : specifiers)
+      os << cast<StringAttr>(specifier).getValue() << " ";
+
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName() << "(";
+  Operation *operation = functionOp.getOperation();
+  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+    return failure();
+  os << ") {\n";
+  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    return failure();
+  os.unindent() << "}\n";
+
   return success();
 }
 
@@ -1140,11 +1223,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           .Case<cf::BranchOp, cf::CondBranchOp>(
               [&](auto op) { return printOperation(*this, op); })
           // EmitC ops.
-          .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
+          .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
                 emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
-                emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
-                emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+                emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp,
+                emitc::MulOp, emitc::ReturnOp, emitc::RemOp, emitc::SubOp,
+                emitc::VariableOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 46eccb1c24eea2..6d2471b4d2b486 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -289,3 +289,35 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 {
   }
   return %r : i32
 }
+
+// -----
+
+// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}}
+emitc.func @multiple_results(%0: i32) -> (i32, i32) {
+  emitc.return %0 : i32
+}
+
+// -----
+
+emitc.func @resulterror() -> i32 {
+^bb42:
+  emitc.return    // expected-error {{'emitc.return' op has 0 operands, but enclosing function (@resulterror) returns 1}}
+}
+
+// -----
+
+emitc.func @return_type_mismatch() -> i32 {
+  %0 = emitc.call_opaque "foo()"(): () -> f32
+  emitc.return %0 : f32  // expected-error {{type of the return operand ('f32') doesn't match function result type ('i32') in function @return_type_mismatch}}
+}
+
+// -----
+
+func.func @return_inside_func.func(%0: i32) -> (i32) {
+  // expected-error @+1 {{'emitc.return' op expects parent op 'emitc.func'}}
+  emitc.return %0 : i32
+}
+// -----
+
+// expected-error @+1 {{expected non-function type}}
+emitc.func @func_variadic(...)
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 45ce2bcb99092c..fd30364a0465f4 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -15,6 +15,21 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) {
   return
 }
 
+emitc.func @func(%arg0 : i32) {
+  emitc.call_opaque "foo"(%arg0) : (i32) -> ()
+  emitc.return
+}
+// A function carrying the `specifiers` attribute.
+emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} {
+  %0 = emitc.call_opaque "foo"(): () -> i32
+  emitc.return %0 : i32
+}
+// A direct call to an emitc.func defined in this module.
+emitc.func @call() -> i32 {
+  %0 = emitc.call @return_i32() : () -> (i32)
+  emitc.return %0 : i32
+}
+
 func.func @cast(%arg0: i32) {
   %1 = emitc.cast %arg0: i32 to f32
   return
diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir
new file mode 100644
index 00000000000000..d2e14a9e5a7aeb
--- /dev/null
+++ b/mlir/test/Target/Cpp/func.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
+
+emitc.func @emitc_func(%arg0 : i32) {
+  emitc.call_opaque "foo" (%arg0) : (i32) -> ()
+  emitc.return
+}
+// CPP-DEFAULT: void emitc_func(int32_t [[V0:[^ ]*]]) {
+// CPP-DEFAULT-NEXT: foo([[V0]]);
+// CPP-DEFAULT-NEXT: return;
+// CPP-DECLTOP: void emitc_func(int32_t [[V0:[^ ]*]]) {
+// CPP-DECLTOP-NEXT: foo([[V0]]);
+// CPP-DECLTOP-NEXT: return;
+
+emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} {
+  %0 = emitc.call_opaque "foo" (): () -> i32
+  emitc.return %0 : i32
+}
+// CPP-DEFAULT: static inline int32_t return_i32() {
+// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = foo();
+// CPP-DEFAULT-NEXT: return [[V0]];
+
+// CPP-DECLTOP: static inline int32_t return_i32() {
+// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
+// CPP-DECLTOP-NEXT: [[V0]] = foo();
+// CPP-DECLTOP-NEXT: return [[V0]];
+
+emitc.func @emitc_call() -> i32 {
+  %0 = emitc.call @return_i32() : () -> (i32)
+  emitc.return %0 : i32
+}
+// CPP-DEFAULT: int32_t emitc_call() {
+// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = return_i32();
+// CPP-DEFAULT-NEXT: return [[V0]];
+
+// CPP-DECLTOP: int32_t emitc_call() {
+// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
+// CPP-DECLTOP-NEXT: [[V0]] = return_i32();
+// CPP-DECLTOP-NEXT: return [[V0]];
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7a4495e28caed6..c405811c302ed4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1564,8 +1564,10 @@ td_library(
     includes = ["include"],
     deps = [
         ":BuiltinDialectTdFiles",
+        ":CallInterfacesTdFiles",
         ":CastInterfacesTdFiles",
         ":ControlFlowInterfacesTdFiles",
+        ":FunctionInterfacesTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
     ],
@@ -3659,10 +3661,12 @@ cc_library(
     ]),
     includes = ["include"],
     deps = [
+        ":CallOpInterfaces",
         ":CastInterfaces",
         ":ControlFlowInterfaces",
         ":EmitCAttributesIncGen",
         ":EmitCOpsIncGen",
+        ":FunctionInterfaces",
         ":IR",
         ":SideEffectInterfaces",
         "//llvm:Support",

>From 5d7f3e90ee2efd86e6399a62e4b8580f4d68c4d8 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Wed, 24 Jan 2024 15:55:11 +0000
Subject: [PATCH 2/3] [mlir][emitc] Add `func` to `emitc` conversion

This adds patterns and a pass to convert the Func dialect to EmitC.
A `func.func` op that is `private` is converted to `emitc.func` with a
`"static"` specifier.
---
 .../mlir/Conversion/FuncToEmitC/FuncToEmitC.h |  18 +++
 .../Conversion/FuncToEmitC/FuncToEmitCPass.h  |  24 ++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |   9 ++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../lib/Conversion/FuncToEmitC/CMakeLists.txt |  16 +++
 .../Conversion/FuncToEmitC/FuncToEmitC.cpp    | 119 ++++++++++++++++++
 .../FuncToEmitC/FuncToEmitCPass.cpp           |  49 ++++++++
 .../Conversion/FuncToEmitC/func-to-emitc.mlir |  55 ++++++++
 9 files changed, 292 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h
 create mode 100644 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h
 create mode 100644 mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
 create mode 100644 mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp
 create mode 100644 mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir

diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h
new file mode 100644
index 00000000000000..5c7f87e470306a
--- /dev/null
+++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h
@@ -0,0 +1,22 @@
+//===- FuncToEmitC.h - Func to EmitC Patterns -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
+#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
+
+namespace mlir {
+class RewritePatternSet;
+
+/// Collects the patterns that rewrite `func.call`, `func.func` and
+/// `func.return` into `emitc.call`, `emitc.func` and `emitc.return`. Calls
+/// and functions with more than one result, returns with more than one
+/// operand, and function declarations are not matched by these patterns.
+void populateFuncToEmitCPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h
new file mode 100644
index 00000000000000..8e0e0bc74fb02d
--- /dev/null
+++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h
@@ -0,0 +1,27 @@
+//===- FuncToEmitCPass.h - Func to EmitC Pass -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H
+#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+// The macro suffix must be the upper-cased TableGen record name
+// (`ConvertFuncToEmitC`); otherwise nothing is taken from Passes.h.inc.
+#define GEN_PASS_DECL_CONVERTFUNCTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Creates the `convert-func-to-emitc` pass.
+std::unique_ptr<Pass> createConvertFuncToEmitC();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..751b84d9288a85 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -28,6 +28,7 @@
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6193aeb545bc6b..e4f6012ca24fa8 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -344,6 +344,21 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// FuncToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertFuncToEmitC : Pass<"convert-func-to-emitc", "ModuleOp"> {
+  let summary = "Convert Func dialect to EmitC dialect";
+  let description = [{
+    Converts `func.func`, `func.call` and `func.return` ops to `emitc.func`,
+    `emitc.call` and `emitc.return`. Private functions get the `static`
+    specifier. Functions and calls with more than one result, as well as
+    function declarations, are not supported and make the pass fail.
+  }];
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // FuncToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..bc89f8621a29c1 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -17,6 +17,7 @@ add_subdirectory(ControlFlowToLLVM)
 add_subdirectory(ControlFlowToSCF)
 add_subdirectory(ControlFlowToSPIRV)
 add_subdirectory(ConvertToLLVM)
+add_subdirectory(FuncToEmitC)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(FuncToSPIRV)
 add_subdirectory(GPUCommon)
diff --git a/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..97752205bbcb40
--- /dev/null
+++ b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRFuncToEmitC
+  FuncToEmitC.cpp
+  FuncToEmitCPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/FuncToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIREmitCDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
new file mode 100644
index 00000000000000..de56a91be60af3
--- /dev/null
+++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
@@ -0,0 +1,131 @@
+//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert the Func dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Converts `func.call` to `emitc.call`. Calls with more than one result are
+/// rejected, mirroring the restriction on `emitc.func`.
+class CallOpConversion final : public OpConversionPattern<func::CallOp> {
+public:
+  using OpConversionPattern<func::CallOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Multi-result functions are not converted to emitc.func, so calls to
+    // them cannot be converted either.
+    if (callOp.getNumResults() > 1)
+      return rewriter.notifyMatchFailure(
+          callOp, "Only functions with zero or one result can be converted");
+
+    // Zero-result calls pass a null result type; this relies on the
+    // emitc::CallOp builder treating it as "no result". All attributes,
+    // including the `callee` symbol, are forwarded unchanged.
+    rewriter.replaceOpWithNewOp<emitc::CallOp>(
+        callOp,
+        callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr,
+        adaptor.getOperands(), callOp->getAttrs());
+
+    return success();
+  }
+};
+
+/// Converts a `func.func` definition with zero or one result to `emitc.func`.
+/// Private functions additionally get the `static` specifier.
+class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
+public:
+  using OpConversionPattern<func::FuncOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (funcOp.getFunctionType().getNumResults() > 1)
+      return rewriter.notifyMatchFailure(
+          funcOp, "Only functions with zero or one result can be converted");
+
+    // Declarations have no body to move into the new op.
+    if (funcOp.isDeclaration())
+      return rewriter.notifyMatchFailure(funcOp,
+                                         "Declarations cannot be converted");
+
+    // Create the converted emitc.func op.
+    emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
+        funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
+
+    // Copy over all attributes other than the function name and type. This
+    // also carries over `sym_visibility`, so private functions stay private.
+    for (const auto &namedAttr : funcOp->getAttrs()) {
+      if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+          namedAttr.getName() != SymbolTable::getSymbolAttrName())
+        newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+    }
+
+    // Add the `static` specifier if the `func.func` is private.
+    if (funcOp.isPrivate()) {
+      StringAttr specifier = rewriter.getStringAttr("static");
+      ArrayAttr specifiers = rewriter.getArrayAttr(specifier);
+      newFuncOp.setSpecifiersAttr(specifiers);
+    }
+
+    // Move the body (including block arguments) into the new op, then erase
+    // the now-empty original.
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    rewriter.eraseOp(funcOp);
+
+    return success();
+  }
+};
+
+/// Converts `func.return` to `emitc.return`, which takes at most one operand.
+class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern<func::ReturnOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (returnOp.getNumOperands() > 1)
+      return rewriter.notifyMatchFailure(
+          returnOp, "Only zero or one operand is supported");
+
+    // Use the adaptor's (possibly remapped) operand; a null value is passed
+    // for a `return` without operands.
+    rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
+        returnOp,
+        returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+
+  patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(ctx);
+}
diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp
new file mode 100644
index 00000000000000..e29b414b789c9c
--- /dev/null
+++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp
@@ -0,0 +1,53 @@
+//===- FuncToEmitCPass.cpp - Func to EmitC Pass -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Func dialect to the EmitC dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
+
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTFUNCTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Converts `func.func`, `func.call` and `func.return` ops to EmitC using the
+/// patterns from populateFuncToEmitCPatterns.
+struct ConvertFuncToEmitC
+    : public impl::ConvertFuncToEmitCBase<ConvertFuncToEmitC> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertFuncToEmitC::runOnOperation() {
+  ConversionTarget target(getContext());
+
+  // EmitC ops are legal and the three Func ops must be rewritten; ops not
+  // mentioned here are left alone by the partial conversion.
+  target.addLegalDialect<emitc::EmitCDialect>();
+  target.addIllegalOp<func::CallOp, func::FuncOp, func::ReturnOp>();
+
+  RewritePatternSet patterns(&getContext());
+  populateFuncToEmitCPatterns(patterns);
+
+  // Any Func op the patterns reject (e.g. a declaration or a multi-result
+  // function) stays illegal, which makes the conversion and the pass fail.
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir
new file mode 100644
index 00000000000000..a1c8af2587aa04
--- /dev/null
+++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -split-input-file -convert-func-to-emitc %s | FileCheck %s
+
+// CHECK-LABEL: emitc.func @foo()
+// CHECK-NEXT: emitc.return
+func.func @foo() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: emitc.func private @foo() attributes {specifiers = ["static"]}
+// CHECK-NEXT: emitc.return
+func.func private @foo() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: emitc.func @foo(%arg0: i32)
+func.func @foo(%arg0: i32) {
+  emitc.call_opaque "bar"(%arg0) : (i32) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: emitc.func @foo(%arg0: i32) -> i32
+// CHECK-NEXT: emitc.return %arg0 : i32
+func.func @foo(%arg0: i32) -> i32 {
+  return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: emitc.func @foo(%arg0: i32, %arg1: i32) -> i32
+func.func @foo(%arg0: i32, %arg1: i32) -> i32 {
+  %0 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: emitc.func private @return_i32(%arg0: i32) -> i32 attributes {specifiers = ["static"]}
+// CHECK-NEXT: emitc.return %arg0 : i32
+func.func private @return_i32(%arg0: i32) -> i32 {
+  return %arg0 : i32
+}
+
+// CHECK-LABEL: emitc.func @call(%arg0: i32) -> i32
+// CHECK-NEXT: [[RES:%.*]] = emitc.call @return_i32(%arg0) : (i32) -> i32
+// CHECK-NEXT: emitc.return [[RES]] : i32
+func.func @call(%arg0: i32) -> i32 {
+  %0 = call @return_i32(%arg0) : (i32) -> (i32)
+  return %0 : i32
+}

>From e6d60612b116122090feb97bf85595de01f87683 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Fri, 26 Jan 2024 15:07:34 +0000
Subject: [PATCH 3/3] [mlir][bazel] Add config for FuncToEmitC

---
 .../llvm-project-overlay/mlir/BUILD.bazel     | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c405811c302ed4..5cbb79737a67be 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3902,6 +3902,7 @@ cc_library(
         ":ControlFlowToSPIRV",
         ":ConversionPassIncGen",
         ":ConvertToLLVM",
+        ":FuncToEmitC",
         ":FuncToLLVM",
         ":FuncToSPIRV",
         ":GPUToGPURuntimeTransforms",
@@ -6834,6 +6835,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "FuncToEmitC",
+    srcs = glob([
+        "lib/Conversion/FuncToEmitC/*.cpp",
+        "lib/Conversion/FuncToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/FuncToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/FuncToEmitC",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":FuncDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "FuncToSPIRV",
     srcs = glob([



More information about the llvm-commits mailing list