[Mlir-commits] [mlir] mlir: add an operation to EmitC for function template instantiation (PR #100895)
Rohan Yadav
llvmlistbot at llvm.org
Mon Jul 29 09:50:22 PDT 2024
https://github.com/rohany updated https://github.com/llvm/llvm-project/pull/100895
>From 49943226eced9535abc3da7a3659ea32ce8f789e Mon Sep 17 00:00:00 2001
From: Rohan Yadav <rohany at alumni.cmu.edu>
Date: Sat, 27 Jul 2024 12:37:30 -0700
Subject: [PATCH] mlir: add an operation to EmitC for function template
instantiation
This commit adds an `emitc.instantiate_function_template` operation for
expressing function template instantiation. Without this operation,
there is no easy way to express a C++ program like:
```
auto x = ...;
auto y = ...;
const void* fptr = &f<decltype(x), decltype(y)>;
```
Taking the address of a template instantiation like this is necessary to
generate code that interacts with some lower-level APIs for launching
parallel work in runtime systems.
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 27 +++++++++++++++++++
mlir/include/mlir/IR/DialectInterface.h | 2 ++
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 26 ++++++++++++++++++
mlir/test/Dialect/EmitC/ops.mlir | 6 +++++
.../Cpp/instantiate_function_template.mlir | 17 ++++++++++++
5 files changed, 78 insertions(+)
create mode 100644 mlir/test/Target/Cpp/instantiate_function_template.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139c..231d041cf608ae 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1260,5 +1260,32 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}
+// NOTE(review): The C++ printer initializes the result directly from
+// `&callee<...>` without a cast. Standard C++ has no implicit conversion from
+// a pointer to function to an object pointer such as `void*`, so the example
+// output below is rejected by conforming compilers; consider emitting an
+// explicit cast to the result type.
+def EmitC_InstantiateFunctionTemplateOp : EmitC_Op<"instantiate_function_template", []> {
+  let summary = "Function template instantiation operation";
+  let description = [{
+    Instantiates the C++ function template named by `callee` and yields the
+    address of the resulting specialization as a pointer. Each operand
+    contributes one template type argument, spelled `decltype(<operand>)`;
+    only the operands' types matter, since `decltype` does not evaluate its
+    operand.
+
+    Example:
+
+    ```mlir
+    %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+    %0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
+    ```
+
+    Translates to the following C++:
+
+    ```c++
+    int32_t v1 = 7;
+    void* v2 = &func_template<decltype(v1)>;
+    ```
+  }];
+  let arguments = (ins
+    Arg<StrAttr, "the C++ function template to instantiate">:$callee,
+    Arg<Variadic<EmitCType>, "values whose types are used as template arguments">:$args
+  );
+  let results = (outs EmitC_PointerType:$result);
+  let assemblyFormat = "$callee `<` $args `>` attr-dict `:` functional-type($args, results)";
+}
+
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 3a7ad87b161eea..1cec1e0d7ce9c6 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -13,6 +13,8 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
+#include <vector>
+
namespace mlir {
class Dialect;
class MLIRContext;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 626638282efe1d..60c5c7fb906fbc 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -656,6 +656,30 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
+/// Prints an `emitc.instantiate_function_template` op as an assignment of the
+/// address of a function template specialization, for example:
+///   void* v2 = &func_template<decltype(v1)>;
+/// Each operand becomes one `decltype(<operand>)` template argument, in order.
+// NOTE(review): C++ has no implicit conversion from a pointer to function to
+// `void*`, so output like the example above is rejected by conforming
+// compilers unless a cast is emitted.
+static LogicalResult
+printOperation(CppEmitter &emitter,
+               emitc::InstantiateFunctionTemplateOp instOp) {
+  // Emit the `<type> <name> = ` prefix (or just `<name> = ` when variables
+  // are declared at the top of the function) for the op's result.
+  if (failed(emitter.emitAssignPrefix(*instOp.getOperation())))
+    return failure();
+
+  raw_ostream &os = emitter.ostream();
+  os << "&" << instOp.getCallee() << "<";
+  LogicalResult argsStatus = interleaveCommaWithError(
+      instOp.getArgs(), os, [&](Value arg) -> LogicalResult {
+        // `decltype` is an unevaluated context: only the operand's type is
+        // used as the template argument, never its value.
+        os << "decltype(";
+        if (failed(emitter.emitOperand(arg)))
+          return failure();
+        os << ")";
+        return success();
+      });
+  if (failed(argsStatus))
+    return failure();
+  os << ">";
+  return success();
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::ApplyOp applyOp) {
raw_ostream &os = emitter.ostream();
@@ -1508,6 +1532,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
.Case<emitc::LiteralOp>([&](auto op) { return success(); })
+ .Case<emitc::InstantiateFunctionTemplateOp>(
+ [&](auto op) { return printOperation(*this, op); })
.Default([&](Operation *) {
return op.emitOpError("unable to find printer for op");
});
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 51c484a633eec9..6db53070d833f7 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -224,6 +224,12 @@ func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>
return
}
+func.func @test_instantiate_template() {
+  %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+  %0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
+  return
+}
+
+// The operand list is variadic: exercise more than one template argument.
+func.func @test_instantiate_template_multiple_args(%arg0 : f32) {
+  %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+  %0 = emitc.instantiate_function_template "func_template"<%c1, %arg0> : (i32, f32) -> !emitc.ptr<!emitc.opaque<"void">>
+  return
+}
+
emitc.verbatim "#ifdef __cplusplus"
emitc.verbatim "extern \"C\" {"
emitc.verbatim "#endif // __cplusplus"
diff --git a/mlir/test/Target/Cpp/instantiate_function_template.mlir b/mlir/test/Target/Cpp/instantiate_function_template.mlir
new file mode 100644
index 00000000000000..d6aeb9f6f7a578
--- /dev/null
+++ b/mlir/test/Target/Cpp/instantiate_function_template.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
+
+func.func @emitc_instantiate_template() {
+  %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+  %0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
+  return
+}
+// CPP-DEFAULT: void emitc_instantiate_template() {
+// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 7;
+// CPP-DEFAULT-NEXT: void* [[V1:[^ ]*]] = &func_template<decltype([[V0]])>;
+
+// CPP-DECLTOP: void emitc_instantiate_template() {
+// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
+// CPP-DECLTOP-NEXT: void* [[V1:[^ ]*]];
+// CPP-DECLTOP-NEXT: [[V0]] = 7;
+// CPP-DECLTOP-NEXT: [[V1]] = &func_template<decltype([[V0]])>;
+
+// Multiple operands are printed as a comma-separated template argument list.
+func.func @emitc_instantiate_template_multiple_args() {
+  %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+  %c2 = "emitc.constant"() <{value = 42 : i64}> : () -> i64
+  %0 = emitc.instantiate_function_template "func_template"<%c1, %c2> : (i32, i64) -> !emitc.ptr<!emitc.opaque<"void">>
+  return
+}
+// CPP-DEFAULT: void emitc_instantiate_template_multiple_args() {
+// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 7;
+// CPP-DEFAULT-NEXT: int64_t [[V1:[^ ]*]] = 42;
+// CPP-DEFAULT-NEXT: void* [[V2:[^ ]*]] = &func_template<decltype([[V0]]), decltype([[V1]])>;
+
+// CPP-DECLTOP: void emitc_instantiate_template_multiple_args() {
+// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
+// CPP-DECLTOP-NEXT: int64_t [[V1:[^ ]*]];
+// CPP-DECLTOP-NEXT: void* [[V2:[^ ]*]];
+// CPP-DECLTOP-NEXT: [[V0]] = 7;
+// CPP-DECLTOP-NEXT: [[V1]] = 42;
+// CPP-DECLTOP-NEXT: [[V2]] = &func_template<decltype([[V0]]), decltype([[V1]])>;
More information about the Mlir-commits
mailing list