[Mlir-commits] [mlir] d6efb6f - Rework ExecutionEngine::invoke() to make it more friendly to use from C++

Mehdi Amini llvmlistbot at llvm.org
Fri Feb 5 17:33:00 PST 2021


Author: Mehdi Amini
Date: 2021-02-06T01:32:50Z
New Revision: d6efb6fc86a6ded63fc50e3a31377a1f4aa33c6e

URL: https://github.com/llvm/llvm-project/commit/d6efb6fc86a6ded63fc50e3a31377a1f4aa33c6e
DIFF: https://github.com/llvm/llvm-project/commit/d6efb6fc86a6ded63fc50e3a31377a1f4aa33c6e.diff

LOG: Rework ExecutionEngine::invoke() to make it more friendly to use from C++

This new invoke will pack a list of arguments before calling the
`invokePacked` method. It accepts the returned value as an output argument
wrapped in `ExecutionEngine::Result<T>`, and delegates the packing of
arguments to a trait to allow for customization for some types.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D95961

Added: 
    mlir/unittests/ExecutionEngine/CMakeLists.txt
    mlir/unittests/ExecutionEngine/Invoke.cpp

Modified: 
    mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
    mlir/lib/ExecutionEngine/ExecutionEngine.cpp
    mlir/unittests/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
index 1b6b0a8f670c..5c9f0fa1c3c9 100644
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -101,9 +101,66 @@ class ExecutionEngine {
   /// pointer to it.  Propagates errors in case of failure.
   llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
 
+  /// Invokes the function with the given name passing it the list of opaque
+  /// pointers to the actual arguments.
+  llvm::Error invokePacked(StringRef name,
+                           MutableArrayRef<void *> args = llvm::None);
+
+  /// Customization point deciding how a value of type `T` is handed to the
+  /// JIT-compiled code: by default the address of the value is forwarded.
+  template <typename T>
+  struct Argument {
+    static void pack(SmallVectorImpl<void *> &args, T &value) {
+      args.push_back(&value);
+    }
+  };
+
+  /// Marks an `invoke` argument as an output parameter bound to `value`.
+  template <typename T>
+  struct Result {
+    T &value;
+    Result(T &result) : value(result) {}
+  };
+
+  /// Deduces `T` so callers can write `ExecutionEngine::result(x)` instead of
+  /// spelling out `ExecutionEngine::Result<T>(x)` when calling `invoke`.
+  template <typename T>
+  static Result<T> result(T &storage) {
+    return Result<T>{storage};
+  }
+
+  // Output parameters are unwrapped: the address of the referenced storage,
+  // not of the `Result` wrapper itself, is forwarded to the native code.
+  template <typename T>
+  struct Argument<Result<T>> {
+    static void pack(SmallVectorImpl<void *> &args, Result<T> &out) {
+      args.push_back(&out.value);
+    }
+  };
+
   /// Invokes the function with the given name passing it the list of arguments
-  /// as a list of opaque pointers.
-  llvm::Error invoke(StringRef name, MutableArrayRef<void *> args = llvm::None);
+  /// by value. The function result can be obtained through an output parameter
+  /// wrapped with `ExecutionEngine::result()` (or `Result<T>`). For example:
+  ///
+  ///     func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface }
+  ///
+  /// can be invoked:
+  ///
+  ///     int32_t result = 0;
+  ///     llvm::Error error = jit->invoke("foo", 42,
+  ///                                     ExecutionEngine::result(result));
+  template <typename... Args>
+  llvm::Error invoke(StringRef funcName, Args... args) {
+    const std::string adapterName =
+        std::string("_mlir_ciface_") + funcName.str();
+    llvm::SmallVector<void *> argsArray;
+    // Pack a pointer to each by-value argument; the copies outlive the call
+    // below. Packing is delegated to the `Argument` trait so that it can be
+    // overridden per type. TODO: use a fold expression once on C++17.
+    int dummy[] = {0, ((void)Argument<Args>::pack(argsArray, args), 0)...};
+    (void)dummy;
+    return invokePacked(adapterName, argsArray);
+  }
 
   /// Set the target triple on the module. This is implicitly done when creating
   /// the engine.

diff  --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 9cd2054530ef..5d9d4e19a14e 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -339,7 +339,8 @@ Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
   return fptr;
 }
 
-Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
+Error ExecutionEngine::invokePacked(StringRef name,
+                                    MutableArrayRef<void *> args) {
   auto expectedFPtr = lookup(name);
   if (!expectedFPtr)
     return expectedFPtr.takeError();

diff  --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 851092c5b56a..d8d52fc1fd53 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -7,6 +7,7 @@ endfunction()
 
 add_subdirectory(Analysis)
 add_subdirectory(Dialect)
+add_subdirectory(ExecutionEngine)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(SDBM)

diff  --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
new file mode 100644
index 000000000000..14a2ffe0e242
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -0,0 +1,11 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+add_mlir_unittest(MLIRExecutionEngineTests
+  Invoke.cpp
+)
+target_link_libraries(MLIRExecutionEngineTests
+  PRIVATE
+  MLIRExecutionEngine
+  MLIRLinalgToLLVM
+  ${dialect_libs}
+)

diff  --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
new file mode 100644
index 000000000000..c9abf2e108b8
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -0,0 +1,92 @@
+//===- Invoke.cpp ------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/RunnerUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "gmock/gmock.h"
+
+using namespace mlir;
+
+static struct LLVMInitializer {
+  LLVMInitializer() {
+    llvm::InitializeNativeTarget();
+    llvm::InitializeNativeTargetAsmPrinter();
+  }
+} initializer;
+
+/// Lowers `module` in place to the LLVM dialect with a minimal pipeline, so
+/// that test inputs written in higher-level dialects can be JIT-compiled.
+static LogicalResult lowerToLLVMDialect(ModuleOp module) {
+  PassManager passManager(module.getContext());
+  passManager.addPass(mlir::createLowerToLLVMPass());
+  return passManager.run(module);
+}
+
+// The JIT isn't supported on Windows at this time.
+#ifndef _WIN32
+
+TEST(MLIRExecutionEngine, AddInteger) {
+  std::string moduleStr = R"mlir(
+  func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
+    %res = std.addi %arg0, %arg0 : i32
+    return %res : i32
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+  // The i32 result of `foo` comes back through an output argument.
+  int32_t result = 0;
+  llvm::Error error =
+      jit->invoke("foo", 42, ExecutionEngine::Result<int32_t>(result));
+  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
+  ASSERT_EQ(result, 42 + 42);
+}
+
+TEST(MLIRExecutionEngine, SubtractFloat) {
+  std::string moduleStr = R"mlir(
+  func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
+    %res = std.subf %arg0, %arg1 : f32
+    return %res : f32
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+  // The f32 result of `foo` comes back through an output argument.
+  float result = -1;
+  llvm::Error error =
+      jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
+  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
+  ASSERT_FLOAT_EQ(result, 42.f);
+}
+
+#endif // _WIN32


        


More information about the Mlir-commits mailing list