[Mlir-commits] [mlir] [MLIR] Add C-API for parsing bytecode (PR #83825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 03:24:29 PST 2024


https://github.com/SpriteOvO updated https://github.com/llvm/llvm-project/pull/83825

>From c33d6e683eef6c4e4a10cb3fe2cdfcd4e8be6a5c Mon Sep 17 00:00:00 2001
From: Asuna <SpriteOvO at gmail.com>
Date: Mon, 4 Mar 2024 12:12:45 +0100
Subject: [PATCH] [MLIR] Add C-API for parsing bytecode

---
 mlir/include/mlir-c/IR.h                      | 21 +++++++
 .../mlir/Bytecode/BytecodeReaderConfig.h      | 16 ++---
 mlir/include/mlir/CAPI/IR.h                   |  2 +
 mlir/include/mlir/IR/AsmState.h               |  6 +-
 mlir/lib/CAPI/IR/IR.cpp                       | 38 +++++++++++
 mlir/test/CAPI/ir.c                           | 63 +++++++++++++++++++
 6 files changed, 136 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..67b0b7dfa533ef 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
 
 DEFINE_C_API_STRUCT(MlirAsmState, void);
 DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
+DEFINE_C_API_STRUCT(MlirBytecodeReaderConfig, void);
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
@@ -468,6 +469,16 @@ MLIR_CAPI_EXPORTED void
 mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
                                            int64_t version);
 
+//===----------------------------------------------------------------------===//
+// Bytecode parsing flags API.
+//===----------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED MlirBytecodeReaderConfig
+mlirBytecodeReaderConfigCreate(void);
+
+MLIR_CAPI_EXPORTED void
+mlirBytecodeReaderConfigDestroy(MlirBytecodeReaderConfig config);
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
@@ -820,6 +831,16 @@ MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block);
 MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block,
                                                       MlirOperation operation);
 
+/// Parses the MLIR bytecode contained in `buffer` and reads the operations it
+/// defines into `block` using a default reader config.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirBlockAppendParseBytecode(
+    MlirContext context, MlirBlock block, MlirStringRef buffer);
+
+/// Same as mlirBlockAppendParseBytecode, but parses with the given `config`.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirBlockAppendParseBytecodeWithConfig(
+    MlirContext context, MlirBlock block, MlirStringRef buffer,
+    MlirBytecodeReaderConfig config);
+
 /// Takes an operation owned by the caller and inserts it as `pos` to the block.
 /// This is an expensive operation that scans the block linearly, prefer
 /// insertBefore/After instead.
diff --git a/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
index d623d0da2c0c90..56629b9ea72e95 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
@@ -43,7 +43,7 @@ class AttrTypeBytecodeReader {
                     CallableT, std::function<LogicalResult(
                                    DialectBytecodeReader &, StringRef, T &)>>,
                 bool> = true>
-  static std::unique_ptr<AttrTypeBytecodeReader<T>>
+  static std::shared_ptr<AttrTypeBytecodeReader<T>>
   fromCallable(CallableT &&readFn) {
     struct Processor : public AttrTypeBytecodeReader<T> {
       Processor(CallableT &&readFn)
@@ -55,7 +55,7 @@ class AttrTypeBytecodeReader {
 
       std::decay_t<CallableT> readFn;
     };
-    return std::make_unique<Processor>(std::forward<CallableT>(readFn));
+    return std::make_shared<Processor>(std::forward<CallableT>(readFn));
   }
 };
 
@@ -69,11 +69,11 @@ class BytecodeReaderConfig {
   BytecodeReaderConfig() = default;
 
   /// Returns the callbacks available to the parser.
-  ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+  ArrayRef<std::shared_ptr<AttrTypeBytecodeReader<Attribute>>>
   getAttributeCallbacks() const {
     return attributeBytecodeParsers;
   }
-  ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+  ArrayRef<std::shared_ptr<AttrTypeBytecodeReader<Type>>>
   getTypeCallbacks() const {
     return typeBytecodeParsers;
   }
@@ -81,11 +81,11 @@ class BytecodeReaderConfig {
   /// Attach a custom bytecode parser callback to the configuration for parsing
   /// of custom type/attributes encodings.
   void attachAttributeCallback(
-      std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
+      std::shared_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
     attributeBytecodeParsers.emplace_back(std::move(parser));
   }
   void
-  attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
+  attachTypeCallback(std::shared_ptr<AttrTypeBytecodeReader<Type>> parser) {
     typeBytecodeParsers.emplace_back(std::move(parser));
   }
 
@@ -109,9 +109,9 @@ class BytecodeReaderConfig {
   }
 
 private:
-  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+  llvm::SmallVector<std::shared_ptr<AttrTypeBytecodeReader<Attribute>>>
       attributeBytecodeParsers;
-  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+  llvm::SmallVector<std::shared_ptr<AttrTypeBytecodeReader<Type>>>
       typeBytecodeParsers;
 };
 
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e7..c191004040c568 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_CAPI_IR_H
 #define MLIR_CAPI_IR_H
 
+#include "mlir/Bytecode/BytecodeReader.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/CAPI/Wrap.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -23,6 +24,7 @@
 
 DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
 DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
+DEFINE_C_API_PTR_METHODS(MlirBytecodeReaderConfig, mlir::BytecodeReaderConfig)
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
 DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f8837..2b58a0040afe68 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -464,9 +464,11 @@ class ParserConfig {
   /// `fallbackResourceMap` is an optional fallback handler that can be used to
   /// parse external resources not explicitly handled by another parser.
   ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
-               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+               BytecodeReaderConfig bytecodeReaderConfig = {})
       : context(context), verifyAfterParse(verifyAfterParse),
-        fallbackResourceMap(fallbackResourceMap) {
+        fallbackResourceMap(fallbackResourceMap),
+        bytecodeReaderConfig{std::move(bytecodeReaderConfig)} {
     assert(context && "expected valid MLIR context");
   }
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..8d15d0e16c0a57 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -10,6 +10,7 @@
 #include "mlir-c/Support.h"
 
 #include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeReader.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
@@ -28,6 +29,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
@@ -236,6 +238,18 @@ void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
   unwrap(flags)->setDesiredBytecodeVersion(version);
 }
 
+//===----------------------------------------------------------------------===//
+// Bytecode parsing flags API.
+//===----------------------------------------------------------------------===//
+
+MlirBytecodeReaderConfig mlirBytecodeReaderConfigCreate() {
+  return wrap(new BytecodeReaderConfig()); // heap-allocated; caller owns it
+}
+
+void mlirBytecodeReaderConfigDestroy(MlirBytecodeReaderConfig config) {
+  delete unwrap(config); // pairs with mlirBytecodeReaderConfigCreate
+}
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//
@@ -840,6 +854,30 @@ void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
   unwrap(block)->push_back(unwrap(operation));
 }
 
+MlirLogicalResult mlirBlockAppendParseBytecode(MlirContext context,
+                                               MlirBlock block,
+                                               MlirStringRef buffer) {
+  auto config = mlirBytecodeReaderConfigCreate(); // temporary default config
+  auto result =
+      mlirBlockAppendParseBytecodeWithConfig(context, block, buffer, config);
+  mlirBytecodeReaderConfigDestroy(config); // free the temporary config
+  return result;
+}
+
+MlirLogicalResult
+mlirBlockAppendParseBytecodeWithConfig(MlirContext context, MlirBlock block,
+                                       MlirStringRef buffer,
+                                       MlirBytecodeReaderConfig config) {
+  auto memBuffer = llvm::MemoryBuffer::getMemBuffer(
+      unwrap(buffer), /* BufferName */ __func__,
+      /* RequiresNullTerminator */ false); // length comes from MlirStringRef
+  auto parserConfig =
+      mlir::ParserConfig{unwrap(context), /* verifyAfterParse */ true,
+                         /* fallbackResourceMap */ nullptr,
+                         /* bytecodeReaderConfig */ *unwrap(config)}; // copied
+  return wrap(mlir::readBytecodeFile(*memBuffer, unwrap(block), parserConfig));
+}
+
 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
                                    MlirOperation operation) {
   auto &opList = unwrap(block)->getOperations();
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index a9850c0a132e75..586eb2443d32e9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2381,6 +2381,68 @@ void testDiagnostics(void) {
   mlirContextDestroy(ctx);
 }
 
+void callbackCollectStreamData(MlirStringRef input, void *userData) {
+  MlirStringRef *out = (MlirStringRef *)userData; // caller's accumulator
+  if (out->data == NULL) { // first chunk: allocate and copy
+    out->data = malloc(input.length); // NOTE(review): result not checked
+    memcpy(out->data, input.data, input.length);
+    out->length = input.length;
+  } else { // later chunks: grow (realloc result also unchecked) and append
+    out->data = realloc(out->data, out->length + input.length);
+    memcpy(out->data + out->length, input.data, input.length);
+    out->length += input.length;
+  }
+}
+
+void testMlirBytecodeReadWrite(MlirContext ctx) { // bytecode round-trip test
+  const char *moduleString = "module {\n"
+                             "  func.func @mlirbc_test() {\n"
+                             "    %1 = arith.constant 114: i32\n"
+                             "    %2 = arith.constant 514: i32\n"
+                             "    arith.addi %1, %2: i32\n"
+                             "    return\n"
+                             "  }\n"
+                             "}";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+
+  MlirStringRef bytecode = mlirStringRefCreate(NULL, 0); // empty accumulator
+  mlirOperationWriteBytecode(mlirModuleGetOperation(module),
+                             callbackCollectStreamData, &bytecode);
+
+  MlirBlock blockBc = mlirBlockCreate(0, NULL, NULL); // no block arguments
+  MlirLogicalResult result =
+      mlirBlockAppendParseBytecode(ctx, blockBc, bytecode);
+  assert(mlirLogicalResultIsSuccess(result)); // no-op if NDEBUG is defined
+
+  MlirModule moduleBc =
+      mlirModuleFromOperation(mlirBlockGetFirstOperation(blockBc));
+
+  fprintf(stderr, "===== mlirbc-test manually =====\n");
+  mlirOperationDump(mlirModuleGetOperation(module));
+  fprintf(stderr, "===== mlirbc-test parsed =====\n");
+  mlirOperationDump(mlirModuleGetOperation(moduleBc));
+
+  // CHECK:      ===== mlirbc-test manually =====
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   func.func @mlirbc_test() {
+  // CHECK-NEXT:     %c114_i32 = arith.constant 114 : i32
+  // CHECK-NEXT:     %c514_i32 = arith.constant 514 : i32
+  // CHECK-NEXT:     %0 = arith.addi %c114_i32, %c514_i32 : i32
+  // CHECK-NEXT:     return
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: ===== mlirbc-test parsed =====
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   func.func @mlirbc_test() {
+  // CHECK-NEXT:     %c114_i32 = arith.constant 114 : i32
+  // CHECK-NEXT:     %c514_i32 = arith.constant 514 : i32
+  // CHECK-NEXT:     %0 = arith.addi %c114_i32, %c514_i32 : i32
+  // CHECK-NEXT:     return
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+} // NOTE(review): module, blockBc, bytecode.data are not released here
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   registerAllUpstreamDialects(ctx);
@@ -2426,6 +2488,7 @@ int main(void) {
 
   testExplicitThreadPools();
   testDiagnostics();
+  testMlirBytecodeReadWrite(ctx);
 
   // CHECK: DESTROY MAIN CONTEXT
   // CHECK: reportResourceDelete: resource_i64_blob



More information about the Mlir-commits mailing list