[Mlir-commits] [mlir] d6efb6f - Rework ExecutionEngine::invoke() to make it more friendly to use from C++
Mehdi Amini
llvmlistbot at llvm.org
Fri Feb 5 17:33:00 PST 2021
Author: Mehdi Amini
Date: 2021-02-06T01:32:50Z
New Revision: d6efb6fc86a6ded63fc50e3a31377a1f4aa33c6e
URL: https://github.com/llvm/llvm-project/commit/d6efb6fc86a6ded63fc50e3a31377a1f4aa33c6e
DIFF: https://github.com/llvm/llvm-project/commit/d6efb6fc86a6ded63fc50e3a31377a1f4aa33c6e.diff
LOG: Rework ExecutionEngine::invoke() to make it more friendly to use from C++
This new invoke will pack a list of arguments before calling the
`invokePacked` method. It accepts returned values as output arguments
wrapped in `ExecutionEngine::Result<T>`, and delegates the packing of
arguments to a trait to allow for customization for some types.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D95961
Added:
mlir/unittests/ExecutionEngine/CMakeLists.txt
mlir/unittests/ExecutionEngine/Invoke.cpp
Modified:
mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/unittests/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
index 1b6b0a8f670c..5c9f0fa1c3c9 100644
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -101,9 +101,66 @@ class ExecutionEngine {
/// pointer to it. Propagates errors in case of failure.
llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
+  /// Calls the function named `name`, forwarding `args` untouched: every
+  /// element is an opaque pointer to one actual argument. This is the packed
+  /// form that the variadic `invoke` below assembles on the caller's behalf.
+  llvm::Error invokePacked(StringRef name,
+                           MutableArrayRef<void *> args = llvm::None);
+
+  /// Customization point deciding how a value of type `T` is handed to the
+  /// JIT-compiled code. The primary template appends the address of the value
+  /// to the packed argument list; specialize it to change that per type.
+  template <typename T>
+  struct Argument {
+    static void pack(SmallVectorImpl<void *> &args, T &val) {
+      void *address = &val;
+      args.push_back(address);
+    }
+  };
+
+  /// Marks an `invoke` argument as an output parameter. Only a reference to
+  /// the caller's storage is kept, so that storage must outlive the call.
+  template <typename T>
+  struct Result {
+    Result(T &storage) : value(storage) {}
+    T &value;
+  };
+
+  /// Convenience factory for `Result<T>` that lets the template argument be
+  /// deduced, e.g. `ExecutionEngine::result(out)` when calling `invoke`.
+  template <typename T>
+  static Result<T> result(T &t) {
+    Result<T> wrapped(t);
+    return wrapped;
+  }
+
+  // Specialization for output parameters: instead of the address of the
+  // `Result` wrapper, the address of the wrapped storage is what gets packed
+  // and therefore forwarded to the native code.
+  template <typename T>
+  struct Argument<Result<T>> {
+    static void pack(SmallVectorImpl<void *> &args, Result<T> &result) {
+      T &storage = result.value;
+      args.push_back(&storage);
+    }
+  };
+
/// Invokes the function with the given name passing it the list of arguments
- /// as a list of opaque pointers.
- llvm::Error invoke(StringRef name, MutableArrayRef<void *> args = llvm::None);
+  /// by value. Function results are obtained through output parameters
+  /// wrapped in `Result` (or built with the `result()` helper) defined above.
+  /// For example:
+  ///
+  ///   func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface }
+  ///
+  /// can be invoked with:
+  ///
+  ///   int32_t result = 0;
+  ///   llvm::Error error = jit->invoke("foo", 42,
+  ///                                   ExecutionEngine::result(result));
+  ///
+  /// The call targets the symbol `_mlir_ciface_<funcName>` (hence the
+  /// `llvm.emit_c_interface` attribute in the example) and is forwarded to
+  /// `invokePacked`. Each argument is packed through the `Argument` trait.
+  template <typename... Args>
+  llvm::Error invoke(StringRef funcName, Args... args) {
+    const std::string adapterName =
+        std::string("_mlir_ciface_") + funcName.str();
+    llvm::SmallVector<void *> argsArray;
+    // Arguments are taken by value so that each one has addressable storage
+    // in this frame for the whole duration of the invokePacked call below.
+    // Pack every argument into an array of pointers. Packing is delegated to
+    // a trait so that it can be overridden per argument type.
+    // The leading 0 keeps the array non-empty when `Args` is empty.
+    // TODO: replace with a fold expression when migrating to C++17.
+    int dummy[] = {0, ((void)Argument<Args>::pack(argsArray, args), 0)...};
+    (void)dummy;
+    return invokePacked(adapterName, argsArray);
+  }
/// Set the target triple on the module. This is implicitly done when creating
/// the engine.
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 9cd2054530ef..5d9d4e19a14e 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -339,7 +339,8 @@ Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
return fptr;
}
-Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
+Error ExecutionEngine::invokePacked(StringRef name,
+ MutableArrayRef<void *> args) {
auto expectedFPtr = lookup(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 851092c5b56a..d8d52fc1fd53 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -7,6 +7,7 @@ endfunction()
add_subdirectory(Analysis)
add_subdirectory(Dialect)
+add_subdirectory(ExecutionEngine)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(SDBM)
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
new file mode 100644
index 000000000000..14a2ffe0e242
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_unittest(MLIRExecutionEngineTests
+  Invoke.cpp
+)
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+# Invoke.cpp directly uses the parser (parseSourceString), the pass manager
+# and the Standard-to-LLVM lowering, so link them explicitly rather than
+# relying on transitive dependencies. ${dialect_libs} backs the
+# registerAllDialects() call.
+target_link_libraries(MLIRExecutionEngineTests
+  PRIVATE
+  MLIRExecutionEngine
+  MLIRLinalgToLLVM
+  MLIRParser
+  MLIRPass
+  MLIRStandardToLLVM
+  ${dialect_libs}
+)
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
new file mode 100644
index 000000000000..c9abf2e108b8
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -0,0 +1,92 @@
+//===- Invoke.cpp ------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/RunnerUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "gmock/gmock.h"
+
+using namespace mlir;
+
+/// Registers the native target and its assembly printer from the constructor
+/// of a file-static object, i.e. during static initialization, before any of
+/// the tests below creates an ExecutionEngine.
+struct LLVMInitializer {
+  LLVMInitializer() {
+    llvm::InitializeNativeTarget();
+    llvm::InitializeNativeTargetAsmPrinter();
+  }
+};
+static LLVMInitializer initializer;
+
+/// Minimal lowering used by these tests: runs a pipeline made of the single
+/// createLowerToLLVMPass() pass over `module` and reports whether it
+/// succeeded.
+static LogicalResult lowerToLLVMDialect(ModuleOp module) {
+  MLIRContext *ctx = module.getContext();
+  PassManager passManager(ctx);
+  passManager.addPass(mlir::createLowerToLLVMPass());
+  return passManager.run(module);
+}
+
+// The JIT isn't supported on Windows at this time.
+#ifndef _WIN32
+
+// JIT-compiles an i32 `addi` of the argument with itself and checks the
+// doubled value comes back through a `Result` output parameter.
+TEST(MLIRExecutionEngine, AddInteger) {
+  std::string moduleStr = R"mlir(
+ func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
+ %res = std.addi %arg0, %arg0 : i32
+ return %res : i32
+ }
+ )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  // On failure, take and print the error: an llvm::Expected/llvm::Error that
+  // is destroyed while still holding an unchecked error aborts the process
+  // (in builds with ABI-breaking checks) instead of failing just this test.
+  // The streamed message is only evaluated when the assertion fails.
+  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+  // The result of the function must be passed as output argument.
+  int result = 0;
+  llvm::Error error =
+      jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
+  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
+  ASSERT_EQ(result, 42 + 42);
+}
+
+// JIT-compiles an f32 `subf` and checks 43 - 1 comes back through the
+// `ExecutionEngine::result()` output-parameter helper.
+TEST(MLIRExecutionEngine, SubtractFloat) {
+  std::string moduleStr = R"mlir(
+ func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
+ %res = std.subf %arg0, %arg1 : f32
+ return %res : f32
+ }
+ )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  // Consume the error when reporting a failure so the unchecked
+  // llvm::Expected/llvm::Error does not abort the whole test binary.
+  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+  // The result of the function must be passed as output argument.
+  float result = -1;
+  llvm::Error error =
+      jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
+  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
+  // 43 - 1 is exactly representable in f32, so exact equality is intended.
+  ASSERT_EQ(result, 42.f);
+}
+
+#endif // _WIN32
More information about the Mlir-commits
mailing list