[Mlir-commits] [mlir] 03881dc - [mlir][emitc] Add a `declare_func` operation (#80297)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 5 23:49:14 PST 2024
Author: Marius Brehler
Date: 2024-02-06T08:49:10+01:00
New Revision: 03881dc0a7695f4c499cc07042b8c59ad7b7335a
URL: https://github.com/llvm/llvm-project/commit/03881dc0a7695f4c499cc07042b8c59ad7b7335a
DIFF: https://github.com/llvm/llvm-project/commit/03881dc0a7695f4c499cc07042b8c59ad7b7335a.diff
LOG: [mlir][emitc] Add a `declare_func` operation (#80297)
This adds the `emitc.declare_func` operation, which allows emitting the
declaration of an `emitc.func` at a specific location.
Added:
mlir/test/Target/Cpp/declare_func.mlir
Modified:
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Dialect/EmitC/invalid_ops.mlir
mlir/test/Dialect/EmitC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 6871948d14cfc..39cc360cef41d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -460,6 +460,48 @@ def EmitC_CallOp : EmitC_Op<"call",
}];
}
+def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+ let summary = "An operation to declare a function";
+ let description = [{
+ The `declare_func` operation allows to insert a function declaration for an
+ `emitc.func` at a specific position. The operation only requires the `callee`
+ of the `emitc.func` to be specified as an attribute.
+
+ Example:
+
+ ```mlir
+ emitc.declare_func @bar
+ emitc.func @foo(%arg0: i32) -> i32 {
+ %0 = emitc.call @bar(%arg0) : (i32) -> (i32)
+ emitc.return %0 : i32
+ }
+
+ emitc.func @bar(%arg0: i32) -> i32 {
+ emitc.return %arg0 : i32
+ }
+ ```
+
+ ```c++
+ // Code emitted for the operations above.
+ int32_t bar(int32_t v1);
+ int32_t foo(int32_t v1) {
+ int32_t v2 = bar(v1);
+ return v2;
+ }
+
+ int32_t bar(int32_t v1) {
+ return v1;
+ }
+ ```
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$sym_name);
+ let assemblyFormat = [{
+ $sym_name attr-dict
+ }];
+}
+
def EmitC_FuncOp : EmitC_Op<"func", [
AutomaticAllocationScope,
FunctionOpInterface, IsolatedFromAbove
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f384fcbefcfdb..0fe2c0dcfc7c5 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -393,6 +393,24 @@ FunctionType CallOp::getCalleeType() {
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
}
+//===----------------------------------------------------------------------===//
+// DeclareFuncOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ // Check that the sym_name attribute was specified.
+ auto fnAttr = getSymNameAttr();
+ if (!fnAttr)
+ return emitOpError("requires a 'sym_name' symbol reference attribute");
+ FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
+ if (!fn)
+ return emitOpError() << "'" << fnAttr.getValue()
+ << "' does not reference a valid function";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0e73122dcc0bf..a53d7d1701a90 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
@@ -855,8 +856,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
// needs to be printed after the closing brace.
// When generating code for an emitc.for and emitc.verbatim op, printing a
// trailing semicolon is handled within the printOperation function.
- bool trailingSemicolon = !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp,
- emitc::LiteralOp, emitc::VerbatimOp>(op);
+ bool trailingSemicolon =
+ !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
+ emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
if (failed(emitter.emitOperation(
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -938,6 +940,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ DeclareFuncOp declareFuncOp) {
+ CppEmitter::Scope scope(emitter);
+ raw_indented_ostream &os = emitter.ostream();
+
+ auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
+ declareFuncOp, declareFuncOp.getSymNameAttr());
+
+ if (!functionOp)
+ return failure();
+
+ if (functionOp.getSpecifiers()) {
+ for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+ os << cast<StringAttr>(specifier).str() << " ";
+ }
+ }
+
+ if (failed(emitter.emitTypes(functionOp.getLoc(),
+ functionOp.getFunctionType().getResults())))
+ return failure();
+ os << " " << functionOp.getName();
+
+ os << "(";
+ Operation *operation = functionOp.getOperation();
+ if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+ return failure();
+ os << ");";
+
+ return success();
+}
+
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
: os(os), declareVariablesAtTop(declareVariablesAtTop) {
valueInScopeCount.push(0);
@@ -1251,10 +1284,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
- emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
- emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp,
- emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
- emitc::VariableOp, emitc::VerbatimOp>(
+ emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
+ emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
+ emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 6d2471b4d2b48..121a2163d3832 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -321,3 +321,13 @@ func.func @return_inside_func.func(%0: i32) -> (i32) {
// expected-error @+1 {{expected non-function type}}
emitc.func @func_variadic(...)
+
+// -----
+
+// expected-error @+1 {{'emitc.declare_func' op 'bar' does not reference a valid function}}
+emitc.declare_func @bar
+
+// -----
+
+// expected-error @+1 {{'emitc.declare_func' op requires attribute 'sym_name'}}
+"emitc.declare_func"() : () -> ()
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index e03c3d58c3e84..93119be14c908 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -15,6 +15,8 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) {
return
}
+emitc.declare_func @func
+
emitc.func @func(%arg0 : i32) {
emitc.call_opaque "foo"(%arg0) : (i32) -> ()
emitc.return
diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir
new file mode 100644
index 0000000000000..72c087a3388e2
--- /dev/null
+++ b/mlir/test/Target/Cpp/declare_func.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]);
+emitc.declare_func @bar
+// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) {
+emitc.func @bar(%arg0: i32) -> i32 {
+ emitc.return %arg0 : i32
+}
+
+
+// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]);
+emitc.declare_func @foo
+// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) {
+emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} {
+ emitc.return %arg0 : i32
+}
More information about the Mlir-commits
mailing list