[Mlir-commits] [mlir] [mlir][emitc] Add a `declare_func` operation (PR #80297)

Marius Brehler llvmlistbot at llvm.org
Mon Feb 5 08:04:45 PST 2024


https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/80297

>From f5d94d2a55ad53bc37067acfe3565943ffa2ec30 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Thu, 1 Feb 2024 14:40:45 +0000
Subject: [PATCH 1/2] [mlir][emitc] Add a `declare_func` operation

This adds the `emitc.declare_func` operation, which allows emitting the
declaration of an `emitc.func` at a specific location.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 42 +++++++++++++++++++
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 18 +++++++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp      | 45 ++++++++++++++++++---
 mlir/test/Dialect/EmitC/invalid_ops.mlir    | 10 +++++
 mlir/test/Dialect/EmitC/ops.mlir            |  2 +
 mlir/test/Target/Cpp/declare_func.mlir      | 16 ++++++++
 6 files changed, 127 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/declare_func.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 6871948d14cfc0..6a42e508411d14 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -460,6 +460,48 @@ def EmitC_CallOp : EmitC_Op<"call",
   }];
 }
 
+def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+  let summary = "An operation to declare a function";
+  let description = [{
+    The `declare_func` operation allows inserting a function declaration for an
+    `emitc.func` at a specific position. The operation only requires the symbol
+    reference of the `emitc.func` to be specified as an attribute.
+
+    Example:
+
+    ```mlir
+    emitc.declare_func @bar
+    emitc.func @foo(%arg0: i32) -> i32 {
+      %0 = emitc.call @bar(%arg0) : (i32) -> (i32)
+      emitc.return %0 : i32
+    }
+
+    emitc.func @bar(%arg0: i32) -> i32 {
+      emitc.return %arg0 : i32
+    }
+    ```
+
+    ```c++
+    // Code emitted for the operations above.
+    int32_t bar(int32_t v1);
+    int32_t foo(int32_t v1) {
+      int32_t v2 = bar(v1);
+      return v2;
+    }
+
+    int32_t bar(int32_t v1) {
+      return v1;
+    }
+    ```
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$callee);
+  let assemblyFormat = [{
+    $callee attr-dict
+  }];
+}
+
 def EmitC_FuncOp : EmitC_Op<"func", [
   AutomaticAllocationScope,
   FunctionOpInterface, IsolatedFromAbove
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f384fcbefcfdbb..57345806298012 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -393,6 +393,24 @@ FunctionType CallOp::getCalleeType() {
   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
 }
 
+//===----------------------------------------------------------------------===//
+// DeclareFuncOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
+  if (!fn)
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+  // The referenced symbol names an emitc.func in the nearest symbol table.
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FuncOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0e73122dcc0bfa..bbe3f98f7a107a 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/Cpp/CppEmitter.h"
@@ -855,8 +856,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
       // needs to be printed after the closing brace.
       // When generating code for an emitc.for and emitc.verbatim op, printing a
       // trailing semicolon is handled within the printOperation function.
-      bool trailingSemicolon = !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp,
-                                    emitc::LiteralOp, emitc::VerbatimOp>(op);
+      bool trailingSemicolon =
+          !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
+               emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -938,6 +940,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    DeclareFuncOp declareFuncOp) {
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+
+  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
+      declareFuncOp, declareFuncOp.getCalleeAttr());
+
+  if (!functionOp)
+    return failure();
+
+  // Print the specifiers (e.g. "static inline") in the order they are listed.
+  if (ArrayAttr specifiers = functionOp.getSpecifiersAttr()) {
+    for (Attribute specifier : specifiers)
+      os << cast<StringAttr>(specifier).getValue() << " ";
+  }
+
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName() << "(";
+
+  // Print the parameter list; a declaration ends with ";" instead of a body.
+  if (failed(printFunctionArgs(emitter, functionOp.getOperation(),
+                               functionOp.getArguments())))
+    return failure();
+  os << ");";
+
+  return success();
+}
+
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
   valueInScopeCount.push(0);
@@ -1251,10 +1284,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           // EmitC ops.
           .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
-                emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
-                emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp,
-                emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
-                emitc::VariableOp, emitc::VerbatimOp>(
+                emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
+                emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
+                emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
+                emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 6d2471b4d2b486..f0477f6342804c 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -321,3 +321,13 @@ func.func @return_inside_func.func(%0: i32) -> (i32) {
 
 // expected-error at +1 {{expected non-function type}}
 emitc.func @func_variadic(...)
+
+// -----
+
+// expected-error at +1 {{'emitc.declare_func' op 'bar' does not reference a valid function}}
+emitc.declare_func @bar
+
+// -----
+
+// expected-error at +1 {{'emitc.declare_func' op requires attribute 'callee'}}
+"emitc.declare_func"()  : () -> ()
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index e03c3d58c3e847..93119be14c908b 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -15,6 +15,8 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) {
   return
 }
 
+emitc.declare_func @func
+
 emitc.func @func(%arg0 : i32) {
   emitc.call_opaque "foo"(%arg0) : (i32) -> ()
   emitc.return
diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir
new file mode 100644
index 00000000000000..72c087a3388e20
--- /dev/null
+++ b/mlir/test/Target/Cpp/declare_func.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]);
+emitc.declare_func @bar
+// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) {
+emitc.func @bar(%arg0: i32) -> i32 {
+    emitc.return %arg0 : i32
+}
+
+
+// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]);
+emitc.declare_func @foo
+// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) {
+emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} {
+    emitc.return %arg0 : i32
+}

>From 0864be4cef1aa0e757fef2c8403da1bced41d2ca Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Sat, 3 Feb 2024 15:48:15 +0000
Subject: [PATCH 2/2] Rename `callee` to `sym_name` and revise the check whether
 it is specified

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 ++--
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 6 +++---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp      | 2 +-
 mlir/test/Dialect/EmitC/invalid_ops.mlir    | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 6a42e508411d14..39cc360cef41d4 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -496,9 +496,9 @@ def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [
     }
     ```
   }];
-  let arguments = (ins FlatSymbolRefAttr:$callee);
+  let arguments = (ins FlatSymbolRefAttr:$sym_name);
   let assemblyFormat = [{
-    $callee attr-dict
+    $sym_name attr-dict
   }];
 }
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 57345806298012..0fe2c0dcfc7c53 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -399,10 +399,10 @@ FunctionType CallOp::getCalleeType() {
 
 LogicalResult
 DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // Check that the callee attribute was specified.
-  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  // Check that the sym_name attribute was specified.
+  auto fnAttr = getSymNameAttr();
   if (!fnAttr)
-    return emitOpError("requires a 'callee' symbol reference attribute");
+    return emitOpError("requires a 'sym_name' symbol reference attribute");
   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
   if (!fn)
     return emitOpError() << "'" << fnAttr.getValue()
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index bbe3f98f7a107a..a53d7d1701a90b 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -946,7 +946,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
   raw_indented_ostream &os = emitter.ostream();
 
   auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
-      declareFuncOp, declareFuncOp.getCalleeAttr());
+      declareFuncOp, declareFuncOp.getSymNameAttr());
 
   if (!functionOp)
     return failure();
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index f0477f6342804c..121a2163d38320 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -329,5 +329,5 @@ emitc.declare_func @bar
 
 // -----
 
-// expected-error at +1 {{'emitc.declare_func' op requires attribute 'callee'}}
+// expected-error at +1 {{'emitc.declare_func' op requires attribute 'sym_name'}}
 "emitc.declare_func"()  : () -> ()



More information about the Mlir-commits mailing list