[Mlir-commits] [mlir] [MLIR] Add C-API for parsing bytecode (PR #83825)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 4 03:24:29 PST 2024
https://github.com/SpriteOvO updated https://github.com/llvm/llvm-project/pull/83825
>From c33d6e683eef6c4e4a10cb3fe2cdfcd4e8be6a5c Mon Sep 17 00:00:00 2001
From: Asuna <SpriteOvO at gmail.com>
Date: Mon, 4 Mar 2024 12:12:45 +0100
Subject: [PATCH] [MLIR] Add C-API for parsing bytecode
---
mlir/include/mlir-c/IR.h | 21 +++++++
.../mlir/Bytecode/BytecodeReaderConfig.h | 16 ++---
mlir/include/mlir/CAPI/IR.h | 2 +
mlir/include/mlir/IR/AsmState.h | 6 +-
mlir/lib/CAPI/IR/IR.cpp | 38 +++++++++++
mlir/test/CAPI/ir.c | 63 +++++++++++++++++++
6 files changed, 136 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..67b0b7dfa533ef 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
DEFINE_C_API_STRUCT(MlirAsmState, void);
DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
+DEFINE_C_API_STRUCT(MlirBytecodeReaderConfig, void);
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
@@ -468,6 +469,16 @@ MLIR_CAPI_EXPORTED void
mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
int64_t version);
+//===----------------------------------------------------------------------===//
+// Bytecode parsing flags API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a new bytecode reader config with default settings. The returned
+/// config is owned by the caller and must be freed with a call to
+/// mlirBytecodeReaderConfigDestroy().
+MLIR_CAPI_EXPORTED MlirBytecodeReaderConfig
+mlirBytecodeReaderConfigCreate(void);
+
+/// Destroys a bytecode reader config previously created with
+/// mlirBytecodeReaderConfigCreate(). The handle must not be used afterwards.
+MLIR_CAPI_EXPORTED void
+mlirBytecodeReaderConfigDestroy(MlirBytecodeReaderConfig config);
+
//===----------------------------------------------------------------------===//
// Operation API.
//===----------------------------------------------------------------------===//
@@ -820,6 +831,16 @@ MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block);
MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block,
MlirOperation operation);
+/// Reads the operations encoded as MLIR bytecode in `buffer` into the provided
+/// `block`, using a default-constructed bytecode reader config. `buffer` does
+/// not need to be null-terminated. Parsing is requested with the
+/// verify-after-parse option enabled. Returns the logical result reported by
+/// the bytecode reader.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirBlockAppendParseBytecode(
+ MlirContext context, MlirBlock block, MlirStringRef buffer);
+
+/// Same as mlirBlockAppendParseBytecode but parses with the given reader
+/// `config` instead of a default one. The config is copied for the parse, so
+/// the caller keeps ownership of `config` and is still responsible for
+/// destroying it with mlirBytecodeReaderConfigDestroy().
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirBlockAppendParseBytecodeWithConfig(
+ MlirContext context, MlirBlock block, MlirStringRef buffer,
+ MlirBytecodeReaderConfig config);
+
/// Takes an operation owned by the caller and inserts it as `pos` to the block.
/// This is an expensive operation that scans the block linearly, prefer
/// insertBefore/After instead.
diff --git a/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
index d623d0da2c0c90..56629b9ea72e95 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
@@ -43,7 +43,7 @@ class AttrTypeBytecodeReader {
CallableT, std::function<LogicalResult(
DialectBytecodeReader &, StringRef, T &)>>,
bool> = true>
- static std::unique_ptr<AttrTypeBytecodeReader<T>>
+ static std::shared_ptr<AttrTypeBytecodeReader<T>>
fromCallable(CallableT &&readFn) {
struct Processor : public AttrTypeBytecodeReader<T> {
Processor(CallableT &&readFn)
@@ -55,7 +55,7 @@ class AttrTypeBytecodeReader {
std::decay_t<CallableT> readFn;
};
- return std::make_unique<Processor>(std::forward<CallableT>(readFn));
+ return std::make_shared<Processor>(std::forward<CallableT>(readFn));
}
};
@@ -69,11 +69,11 @@ class BytecodeReaderConfig {
BytecodeReaderConfig() = default;
/// Returns the callbacks available to the parser.
- ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+ ArrayRef<std::shared_ptr<AttrTypeBytecodeReader<Attribute>>>
getAttributeCallbacks() const {
return attributeBytecodeParsers;
}
- ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+ ArrayRef<std::shared_ptr<AttrTypeBytecodeReader<Type>>>
getTypeCallbacks() const {
return typeBytecodeParsers;
}
@@ -81,11 +81,11 @@ class BytecodeReaderConfig {
/// Attach a custom bytecode parser callback to the configuration for parsing
/// of custom type/attributes encodings.
void attachAttributeCallback(
- std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
+ std::shared_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
attributeBytecodeParsers.emplace_back(std::move(parser));
}
void
- attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
+ attachTypeCallback(std::shared_ptr<AttrTypeBytecodeReader<Type>> parser) {
typeBytecodeParsers.emplace_back(std::move(parser));
}
@@ -109,9 +109,9 @@ class BytecodeReaderConfig {
}
private:
- llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+ llvm::SmallVector<std::shared_ptr<AttrTypeBytecodeReader<Attribute>>>
attributeBytecodeParsers;
- llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+ llvm::SmallVector<std::shared_ptr<AttrTypeBytecodeReader<Type>>>
typeBytecodeParsers;
};
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e7..c191004040c568 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -15,6 +15,7 @@
#ifndef MLIR_CAPI_IR_H
#define MLIR_CAPI_IR_H
+#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/BuiltinOps.h"
@@ -23,6 +24,7 @@
DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
+DEFINE_C_API_PTR_METHODS(MlirBytecodeReaderConfig, mlir::BytecodeReaderConfig)
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f8837..2b58a0040afe68 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -464,9 +464,11 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
- FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+ BytecodeReaderConfig bytecodeReaderConfig = {})
: context(context), verifyAfterParse(verifyAfterParse),
- fallbackResourceMap(fallbackResourceMap) {
+ fallbackResourceMap(fallbackResourceMap),
+ bytecodeReaderConfig{std::move(bytecodeReaderConfig)} {
assert(context && "expected valid MLIR context");
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..8d15d0e16c0a57 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -10,6 +10,7 @@
#include "mlir-c/Support.h"
#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
@@ -28,6 +29,7 @@
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
+#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/ThreadPool.h"
#include <cstddef>
@@ -236,6 +238,18 @@ void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
unwrap(flags)->setDesiredBytecodeVersion(version);
}
+//===----------------------------------------------------------------------===//
+// Bytecode parsing flags API.
+//===----------------------------------------------------------------------===//
+
+/// Allocates a default-constructed reader config and hands it to the caller
+/// as an opaque C handle; release it with mlirBytecodeReaderConfigDestroy.
+MlirBytecodeReaderConfig mlirBytecodeReaderConfigCreate() {
+  BytecodeReaderConfig *readerConfig = new BytecodeReaderConfig();
+  return wrap(readerConfig);
+}
+
+/// Reclaims a reader config that was allocated by
+/// mlirBytecodeReaderConfigCreate.
+void mlirBytecodeReaderConfigDestroy(MlirBytecodeReaderConfig config) {
+  BytecodeReaderConfig *readerConfig = unwrap(config);
+  delete readerConfig;
+}
+
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//
@@ -840,6 +854,30 @@ void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
unwrap(block)->push_back(unwrap(operation));
}
+/// Parses the MLIR bytecode held in `buffer` and appends the resulting
+/// operations to `block`, using a default-constructed reader config.
+/// Returns the result of mlirBlockAppendParseBytecodeWithConfig.
+MlirLogicalResult mlirBlockAppendParseBytecode(MlirContext context,
+                                               MlirBlock block,
+                                               MlirStringRef buffer) {
+  // The default config is only needed for the duration of this call, so keep
+  // it on the stack rather than pairing a heap allocation with a manual
+  // destroy call.
+  BytecodeReaderConfig defaultConfig;
+  return mlirBlockAppendParseBytecodeWithConfig(context, block, buffer,
+                                                wrap(&defaultConfig));
+}
+
+/// Parses the MLIR bytecode held in `buffer` and appends the resulting
+/// operations to `block`.
+///
+/// `config` is copied into the parser configuration, so the caller keeps
+/// ownership of it. A null `config` handle is treated as "use the default
+/// reader configuration" instead of being dereferenced.
+MlirLogicalResult
+mlirBlockAppendParseBytecodeWithConfig(MlirContext context, MlirBlock block,
+                                       MlirStringRef buffer,
+                                       MlirBytecodeReaderConfig config) {
+  // View the caller's bytes as a memory buffer; bytecode is binary data, so no
+  // null terminator is required.
+  auto memBuffer = llvm::MemoryBuffer::getMemBuffer(
+      unwrap(buffer), /*BufferName=*/__func__,
+      /*RequiresNullTerminator=*/false);
+
+  // Guard against a null handle: fall back to a default config rather than
+  // dereferencing a null pointer.
+  BytecodeReaderConfig *userConfig = unwrap(config);
+  BytecodeReaderConfig readerConfig =
+      userConfig ? *userConfig : BytecodeReaderConfig();
+
+  mlir::ParserConfig parserConfig(unwrap(context), /*verifyAfterParse=*/true,
+                                  /*fallbackResourceMap=*/nullptr,
+                                  std::move(readerConfig));
+  return wrap(mlir::readBytecodeFile(*memBuffer, unwrap(block), parserConfig));
+}
+
void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
MlirOperation operation) {
auto &opList = unwrap(block)->getOperations();
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index a9850c0a132e75..586eb2443d32e9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2381,6 +2381,68 @@ void testDiagnostics(void) {
mlirContextDestroy(ctx);
}
+// Stream callback that accumulates every chunk it receives into a single
+// heap buffer. `userData` must point to an MlirStringRef that starts out as
+// {NULL, 0}; on return it describes the concatenation of all chunks seen so
+// far. The caller owns the buffer and must free() it.
+void callbackCollectStreamData(MlirStringRef input, void *userData) {
+  MlirStringRef *out = (MlirStringRef *)userData;
+  // Nothing to append; also avoids memcpy with a possibly-null source.
+  if (input.length == 0)
+    return;
+
+  // As in the original, the stored length is ignored until a buffer exists.
+  size_t used = out->data ? out->length : 0;
+
+  // realloc(NULL, n) behaves like malloc(n), so the first chunk needs no
+  // special case. Keep the result in a temporary so the old buffer is not
+  // lost if the allocation fails.
+  char *grown = (char *)realloc((void *)out->data, used + input.length);
+  if (grown == NULL) {
+    fprintf(stderr, "callbackCollectStreamData: out of memory\n");
+    abort();
+  }
+
+  memcpy(grown + used, input.data, input.length);
+  out->data = grown;
+  out->length = used + input.length;
+}
+
+// Round-trips a small module through MLIR bytecode: parses it from text,
+// serializes it to bytecode, parses the bytecode back into a fresh block and
+// dumps both modules so FileCheck can compare them.
+void testMlirBytecodeReadWrite(MlirContext ctx) {
+  const char *moduleString = "module {\n"
+                             " func.func @mlirbc_test() {\n"
+                             " %1 = arith.constant 114: i32\n"
+                             " %2 = arith.constant 514: i32\n"
+                             " arith.addi %1, %2: i32\n"
+                             " return\n"
+                             " }\n"
+                             "}";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+
+  // Serialize the module; the callback grows a heap buffer owned by this
+  // function.
+  MlirStringRef bytecode = mlirStringRefCreate(NULL, 0);
+  mlirOperationWriteBytecode(mlirModuleGetOperation(module),
+                             callbackCollectStreamData, &bytecode);
+
+  // Parse the bytecode back into a detached block.
+  MlirBlock blockBc = mlirBlockCreate(0, NULL, NULL);
+  MlirLogicalResult result =
+      mlirBlockAppendParseBytecode(ctx, blockBc, bytecode);
+  assert(mlirLogicalResultIsSuccess(result));
+  (void)result; // Silence the unused-variable warning when asserts are off.
+
+  MlirModule moduleBc =
+      mlirModuleFromOperation(mlirBlockGetFirstOperation(blockBc));
+
+  fprintf(stderr, "===== mlirbc-test manually =====\n");
+  mlirOperationDump(mlirModuleGetOperation(module));
+  fprintf(stderr, "===== mlirbc-test parsed =====\n");
+  mlirOperationDump(mlirModuleGetOperation(moduleBc));
+
+  // CHECK: ===== mlirbc-test manually =====
+  // CHECK-NEXT: module {
+  // CHECK-NEXT: func.func @mlirbc_test() {
+  // CHECK-NEXT: %c114_i32 = arith.constant 114 : i32
+  // CHECK-NEXT: %c514_i32 = arith.constant 514 : i32
+  // CHECK-NEXT: %0 = arith.addi %c114_i32, %c514_i32 : i32
+  // CHECK-NEXT: return
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: ===== mlirbc-test parsed =====
+  // CHECK-NEXT: module {
+  // CHECK-NEXT: func.func @mlirbc_test() {
+  // CHECK-NEXT: %c114_i32 = arith.constant 114 : i32
+  // CHECK-NEXT: %c514_i32 = arith.constant 514 : i32
+  // CHECK-NEXT: %0 = arith.addi %c114_i32, %c514_i32 : i32
+  // CHECK-NEXT: return
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+
+  // Release everything this test created. The parsed module lives inside
+  // blockBc, so destroying the block is expected to release it as well; it
+  // must not be destroyed separately. The bytecode buffer is freed last, after
+  // all IR that was read from it is gone.
+  mlirBlockDestroy(blockBc);
+  mlirModuleDestroy(module);
+  free((void *)bytecode.data);
+}
+
int main(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
@@ -2426,6 +2488,7 @@ int main(void) {
testExplicitThreadPools();
testDiagnostics();
+ testMlirBytecodeReadWrite(ctx);
// CHECK: DESTROY MAIN CONTEXT
// CHECK: reportResourceDelete: resource_i64_blob
More information about the Mlir-commits
mailing list