[Mlir-commits] [mlir] mlir: add an operation to EmitC for function template instantiation (PR #100895)

Rohan Yadav llvmlistbot at llvm.org
Mon Jul 29 09:50:22 PDT 2024


https://github.com/rohany updated https://github.com/llvm/llvm-project/pull/100895

>From 49943226eced9535abc3da7a3659ea32ce8f789e Mon Sep 17 00:00:00 2001
From: Rohan Yadav <rohany at alumni.cmu.edu>
Date: Sat, 27 Jul 2024 12:37:30 -0700
Subject: [PATCH] mlir: add an operation to EmitC for function template
 instantiation

This commit adds an `emitc.instantiate_function_template` operation so
that function template instantiation can be expressed in EmitC. Without
this operation, there is no easy way to express a C++ program like:
```
auto x = ...;
auto y = ...;
const void* fptr = &f<decltype(x), decltype(y)>;
```
Doing so is necessary to generate code that interacts with some
lower-level APIs for launching parallel work on runtime systems.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 27 +++++++++++++++++++
 mlir/include/mlir/IR/DialectInterface.h       |  2 ++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 26 ++++++++++++++++++
 mlir/test/Dialect/EmitC/ops.mlir              |  6 +++++
 .../Cpp/instantiate_function_template.mlir    | 17 ++++++++++++
 5 files changed, 78 insertions(+)
 create mode 100644 mlir/test/Target/Cpp/instantiate_function_template.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139c..231d041cf608ae 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1260,5 +1260,32 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
   let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
 }
 
+def EmitC_InstantiateFunctionTemplateOp : EmitC_Op<"instantiate_function_template", []> {
+  let summary = "Function template instantiation operation";
+  let description = [{
+    Instantiates the function template named by `callee`, passing one
+    template argument per operand (the declared type of that operand), and
+    yields the address of the resulting specialization as a pointer.
+
+    Example:
+
+    ```mlir
+    %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+    %0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
+    ```
+    Translates to the C++:
+    ```c++
+    int32_t v1 = 7;
+    void* v2 = &func_template<decltype(v1)>;
+    ```
+  }];
+  let arguments = (ins
+    Arg<StrAttr, "the C++ function template to instantiate">:$callee,
+    Variadic<EmitCType>:$args
+  );
+  let results = (outs EmitC_PointerType);
+  let assemblyFormat = "$callee `<` $args `>` attr-dict `:` functional-type($args, results)";
+}
+
 
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 3a7ad87b161eea..1cec1e0d7ce9c6 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -13,6 +13,8 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 
+#include <vector>
+
 namespace mlir {
 class Dialect;
 class MLIRContext;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 626638282efe1d..60c5c7fb906fbc 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -656,6 +656,30 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
+/// Prints instantiate_function_template as `&callee<decltype(arg), ...>`.
+static LogicalResult printOperation(
+    CppEmitter &emitter, emitc::InstantiateFunctionTemplateOp templateOp) {
+  raw_ostream &out = emitter.ostream();
+  // NOTE(review): C++ has no implicit function-pointer -> `void*` conversion,
+  // so `void* v = &f<...>;` may not compile without a cast; confirm.
+  if (failed(emitter.emitAssignPrefix(*templateOp.getOperation())))
+    return failure();
+  out << "&" << templateOp.getCallee() << "<";
+  // One template argument per operand, spelled as `decltype(<operand>)`.
+  auto emitTemplateArg = [&](mlir::Value operand) -> LogicalResult {
+    out << "decltype(";
+    if (failed(emitter.emitOperand(operand)))
+      return failure();
+    out << ")";
+    return success();
+  };
+  if (failed(interleaveCommaWithError(templateOp.getArgs(), out,
+                                      emitTemplateArg)))
+    return failure();
+  out << ">";
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::ApplyOp applyOp) {
   raw_ostream &os = emitter.ostream();
@@ -1508,6 +1532,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
               [&](auto op) { return printOperation(*this, op); })
           .Case<emitc::LiteralOp>([&](auto op) { return success(); })
+          .Case<emitc::InstantiateFunctionTemplateOp>(
+              [&](auto op) { return printOperation(*this, op); })
           .Default([&](Operation *) {
             return op.emitOpError("unable to find printer for op");
           });
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 51c484a633eec9..6db53070d833f7 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -224,6 +224,12 @@ func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>
   return
 }
 
+func.func @test_instantiate_template() { // Custom assembly of emitc.instantiate_function_template.
+  %seven = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+  %fn_addr = emitc.instantiate_function_template "func_template"<%seven> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
+  return
+}
+
 emitc.verbatim "#ifdef __cplusplus"
 emitc.verbatim "extern \"C\" {"
 emitc.verbatim "#endif  // __cplusplus"
diff --git a/mlir/test/Target/Cpp/instantiate_function_template.mlir b/mlir/test/Target/Cpp/instantiate_function_template.mlir
new file mode 100644
index 00000000000000..d6aeb9f6f7a578
--- /dev/null
+++ b/mlir/test/Target/Cpp/instantiate_function_template.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
+
+func.func @emitc_instantiate_template() { // Each operand type is emitted as decltype(<operand>) in the template argument list.
+  %c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
+  %0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
+  return
+}
+// CPP-DEFAULT: void emitc_instantiate_template() {
+// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 7;
+// CPP-DEFAULT-NEXT: void* [[V1:[^ ]*]] = &func_template<decltype([[V0]])>;
+// With -declare-variables-at-top, declarations are printed first and the values are assigned afterwards.
+// CPP-DECLTOP: void emitc_instantiate_template() {
+// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
+// CPP-DECLTOP-NEXT: void* [[V1:[^ ]*]];
+// CPP-DECLTOP-NEXT: [[V0]] = 7;
+// CPP-DECLTOP-NEXT: [[V1]] = &func_template<decltype([[V0]])>;



More information about the Mlir-commits mailing list