[Mlir-commits] [mlir] [MLIR] Add C-API for parsing bytecode (PR #83825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 09:47:24 PST 2024


https://github.com/SpriteOvO updated https://github.com/llvm/llvm-project/pull/83825

From 7b0954b98d5ae4a3cecd9b61aeec740be21e3921 Mon Sep 17 00:00:00 2001
From: Asuna <SpriteOvO at gmail.com>
Date: Mon, 4 Mar 2024 12:12:45 +0100
Subject: [PATCH] [MLIR] Add C-API for parsing bytecode

---
 mlir/include/mlir-c/IR.h                    |  21 +++
 mlir/include/mlir/CAPI/IR.h                 |   2 +
 mlir/include/mlir/IR/AsmState.h             |  15 +-
 mlir/lib/Bytecode/Reader/BytecodeReader.cpp |  54 +++---
 mlir/lib/CAPI/IR/IR.cpp                     |  42 +++++
 mlir/test/CAPI/ir.c                         |  63 +++++++
 mlir/test/lib/IR/TestBytecodeRoundtrip.cpp  | 190 ++++++++++----------
 7 files changed, 264 insertions(+), 123 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..67b0b7dfa533ef 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
 
 DEFINE_C_API_STRUCT(MlirAsmState, void);
 DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
+DEFINE_C_API_STRUCT(MlirBytecodeReaderConfig, void);
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
@@ -468,6 +469,16 @@ MLIR_CAPI_EXPORTED void
 mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
                                            int64_t version);
 
+//===----------------------------------------------------------------------===//
+// Bytecode parsing flags API.
+//===----------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED MlirBytecodeReaderConfig
+mlirBytecodeReaderConfigCreate(void);
+
+MLIR_CAPI_EXPORTED void
+mlirBytecodeReaderConfigDestroy(MlirBytecodeReaderConfig config);
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
@@ -820,6 +831,16 @@ MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetTerminator(MlirBlock block);
 MLIR_CAPI_EXPORTED void mlirBlockAppendOwnedOperation(MlirBlock block,
                                                       MlirOperation operation);
 
+/// Parses the MLIR bytecode in `buffer` (need not be null-terminated) and
+/// appends the resulting operations to `block`; returns failure on error.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirBlockAppendParseBytecode(
+    MlirContext context, MlirBlock block, MlirStringRef buffer);
+
+/// Same as mlirBlockAppendParseBytecode, with a caller-owned reader config.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirBlockAppendParseBytecodeWithConfig(
+    MlirContext context, MlirBlock block, MlirStringRef buffer,
+    MlirBytecodeReaderConfig config);
+
 /// Takes an operation owned by the caller and inserts it as `pos` to the block.
 /// This is an expensive operation that scans the block linearly, prefer
 /// insertBefore/After instead.
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e7..c191004040c568 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_CAPI_IR_H
 #define MLIR_CAPI_IR_H
 
+#include "mlir/Bytecode/BytecodeReader.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/CAPI/Wrap.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -23,6 +24,7 @@
 
 DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
 DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
+DEFINE_C_API_PTR_METHODS(MlirBytecodeReaderConfig, mlir::BytecodeReaderConfig)
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
 DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f8837..415a089bb024b5 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -464,9 +464,11 @@ class ParserConfig {
   /// `fallbackResourceMap` is an optional fallback handler that can be used to
   /// parse external resources not explicitly handled by another parser.
   ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
-               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+               BytecodeReaderConfig *bytecodeReaderConfig = nullptr)
       : context(context), verifyAfterParse(verifyAfterParse),
-        fallbackResourceMap(fallbackResourceMap) {
+        fallbackResourceMap(fallbackResourceMap),
+        bytecodeReaderConfig(bytecodeReaderConfig) {
     assert(context && "expected valid MLIR context");
   }
 
@@ -476,9 +478,10 @@ class ParserConfig {
   /// Returns if the parser should verify the IR after parsing.
   bool shouldVerifyAfterParse() const { return verifyAfterParse; }
 
-  /// Returns the parsing configurations associated to the bytecode read.
-  BytecodeReaderConfig &getBytecodeReaderConfig() const {
-    return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
+  /// Returns the bytecode reader config, or nullptr if none was provided.
+  /// It is held by pointer, so it must outlive this ParserConfig.
+  BytecodeReaderConfig *getBytecodeReaderConfig() const {
+    return bytecodeReaderConfig;
   }
 
   /// Return the resource parser registered to the given name, or nullptr if no
@@ -515,7 +518,7 @@ class ParserConfig {
   bool verifyAfterParse;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
   FallbackAsmResourceMap *fallbackResourceMap;
-  BytecodeReaderConfig bytecodeReaderConfig;
+  BytecodeReaderConfig *bytecodeReaderConfig;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index dd1e4abaea1664..726dc34c960367 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1257,35 +1257,35 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
   if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
     return failure();
 
-  if constexpr (std::is_same_v<T, Type>) {
-    // Try parsing with callbacks first if available.
-    for (const auto &callback :
-         parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
-      if (failed(
-              callback->read(dialectReader, entry.dialect->name, entry.entry)))
-        return failure();
-      // Early return if parsing was successful.
-      if (!!entry.entry)
-        return success();
+  if (auto *config = parserConfig.getBytecodeReaderConfig()) {
+    if constexpr (std::is_same_v<T, Type>) {
+      // Try the type callbacks from the reader config first, if any.
+      for (const auto &callback : config->getTypeCallbacks()) {
+        if (failed(callback->read(dialectReader, entry.dialect->name,
+                                  entry.entry)))
+          return failure();
+        // Early return if parsing was successful.
+        if (!!entry.entry)
+          return success();
 
-      // Reset the reader if we failed to parse, so we can fall through the
-      // other parsing functions.
-      reader = EncodingReader(entry.data, reader.getLoc());
-    }
-  } else {
-    // Try parsing with callbacks first if available.
-    for (const auto &callback :
-         parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
-      if (failed(
-              callback->read(dialectReader, entry.dialect->name, entry.entry)))
-        return failure();
-      // Early return if parsing was successful.
-      if (!!entry.entry)
-        return success();
+        // Reset the reader if the callback produced nothing, so we can fall
+        // through to the other parsing functions.
+        reader = EncodingReader(entry.data, reader.getLoc());
+      }
+    } else {
+      // Try the attribute callbacks from the reader config first, if any.
+      for (const auto &callback : config->getAttributeCallbacks()) {
+        if (failed(callback->read(dialectReader, entry.dialect->name,
+                                  entry.entry)))
+          return failure();
+        // Early return if parsing was successful.
+        if (!!entry.entry)
+          return success();
 
-      // Reset the reader if we failed to parse, so we can fall through the
-      // other parsing functions.
-      reader = EncodingReader(entry.data, reader.getLoc());
+        // Reset the reader if the callback produced nothing, so we can fall
+        // through to the other parsing functions.
+        reader = EncodingReader(entry.data, reader.getLoc());
+      }
     }
   }
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..2e9206077ac52b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -10,6 +10,7 @@
 #include "mlir-c/Support.h"
 
 #include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeReader.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
@@ -28,6 +29,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/Support/MemoryBufferRef.h"
 #include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
@@ -236,6 +238,18 @@ void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
   unwrap(flags)->setDesiredBytecodeVersion(version);
 }
 
+//===----------------------------------------------------------------------===//
+// Bytecode parsing flags API.
+//===----------------------------------------------------------------------===//
+
+MlirBytecodeReaderConfig mlirBytecodeReaderConfigCreate() {
+  return wrap(new BytecodeReaderConfig());
+}
+
+void mlirBytecodeReaderConfigDestroy(MlirBytecodeReaderConfig config) {
+  delete unwrap(config);
+}
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//
@@ -840,6 +854,34 @@ void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
   unwrap(block)->push_back(unwrap(operation));
 }
 
+static MlirLogicalResult
+mlirBlockAppendParseBytecodeImpl(MlirContext context, MlirBlock block,
+                                 MlirStringRef buffer,
+                                 BytecodeReaderConfig *config) {
+  // Wrap the caller's bytes without copying or requiring a null terminator;
+  // `memBuffer` is only used for the duration of this call.
+  llvm::MemoryBufferRef memBuffer(unwrap(buffer), /*Identifier=*/__func__);
+  // A null `config` parses without any custom attribute/type callbacks.
+  mlir::ParserConfig parserConfig(unwrap(context), /*verifyAfterParse=*/true,
+                                  /*fallbackResourceMap=*/nullptr,
+                                  /*bytecodeReaderConfig=*/config);
+  return wrap(mlir::readBytecodeFile(memBuffer, unwrap(block), parserConfig));
+}
+
+MlirLogicalResult mlirBlockAppendParseBytecode(MlirContext context,
+                                               MlirBlock block,
+                                               MlirStringRef buffer) {
+  return mlirBlockAppendParseBytecodeImpl(context, block, buffer, nullptr);
+}
+
+MlirLogicalResult
+mlirBlockAppendParseBytecodeWithConfig(MlirContext context, MlirBlock block,
+                                       MlirStringRef buffer,
+                                       MlirBytecodeReaderConfig config) {
+  return mlirBlockAppendParseBytecodeImpl(context, block, buffer,
+                                          unwrap(config));
+}
+
 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
                                    MlirOperation operation) {
   auto &opList = unwrap(block)->getOperations();
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index a9850c0a132e75..586eb2443d32e9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2381,6 +2381,68 @@ void testDiagnostics(void) {
   mlirContextDestroy(ctx);
 }
 
+void callbackCollectStreamData(MlirStringRef input, void *userData) {
+  // Append `input` to the heap buffer tracked by the MlirStringRef userData.
+  MlirStringRef *out = (MlirStringRef *)userData;
+  if (input.length == 0)
+    return;
+  // `data` is const in MlirStringRef, so grow it through a mutable pointer.
+  char *data = realloc((void *)out->data, out->length + input.length);
+  assert(data && "failed to grow the bytecode buffer");
+  memcpy(data + out->length, input.data, input.length);
+  out->data = data;
+  out->length += input.length;
+}
+
+void testMlirBytecodeReadWrite(MlirContext ctx) {
+  const char *moduleString = "module {\n"
+                             "  func.func @mlirbc_test() {\n"
+                             "    %1 = arith.constant 114: i32\n"
+                             "    %2 = arith.constant 514: i32\n"
+                             "    arith.addi %1, %2: i32\n"
+                             "    return\n"
+                             "  }\n"
+                             "}";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  MlirStringRef bytecode = mlirStringRefCreate(NULL, 0);
+  mlirOperationWriteBytecode(mlirModuleGetOperation(module),
+                             callbackCollectStreamData, &bytecode);
+  MlirBlock blockBc = mlirBlockCreate(0, NULL, NULL);
+  MlirLogicalResult result =
+      mlirBlockAppendParseBytecode(ctx, blockBc, bytecode);
+  assert(mlirLogicalResultIsSuccess(result));
+
+  fprintf(stderr, "===== mlirbc-test manually =====\n");
+  mlirOperationDump(mlirModuleGetOperation(module));
+  fprintf(stderr, "===== mlirbc-test parsed =====\n");
+  mlirOperationDump(mlirBlockGetFirstOperation(blockBc));
+
+  // CHECK:      ===== mlirbc-test manually =====
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   func.func @mlirbc_test() {
+  // CHECK-NEXT:     %c114_i32 = arith.constant 114 : i32
+  // CHECK-NEXT:     %c514_i32 = arith.constant 514 : i32
+  // CHECK-NEXT:     %0 = arith.addi %c114_i32, %c514_i32 : i32
+  // CHECK-NEXT:     return
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: ===== mlirbc-test parsed =====
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   func.func @mlirbc_test() {
+  // CHECK-NEXT:     %c114_i32 = arith.constant 114 : i32
+  // CHECK-NEXT:     %c514_i32 = arith.constant 514 : i32
+  // CHECK-NEXT:     %0 = arith.addi %c114_i32, %c514_i32 : i32
+  // CHECK-NEXT:     return
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+
+  // Destroying the block also destroys the module parsed into it.
+  free((void *)bytecode.data);
+  mlirBlockDestroy(blockBc);
+  mlirModuleDestroy(module);
+}
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   registerAllUpstreamDialects(ctx);
@@ -2426,6 +2488,7 @@ int main(void) {
 
   testExplicitThreadPools();
   testDiagnostics();
+  testMlirBytecodeReadWrite(ctx);
 
   // CHECK: DESTROY MAIN CONTEXT
   // CHECK: reportResourceDelete: resource_i64_blob
diff --git a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index e668224d343234..4e9d8c0a5f636d 100644
--- a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -159,40 +159,43 @@ struct TestBytecodeRoundtripPass
         });
     newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
     newCtx->allowUnregisteredDialects();
-    ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
-    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
-        [&](DialectBytecodeReader &reader, StringRef dialectName,
-            Type &entry) -> LogicalResult {
-          // Get test dialect version from the version map.
-          auto versionOr = reader.getDialectVersion<test::TestDialect>();
-          assert(succeeded(versionOr) && "expected reader to be able to access "
-                                         "the version for test dialect");
-          const auto *version =
-              reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
-          if (version->major_ >= 2)
-            return success();
+    BytecodeReaderConfig readConfig;
+    readConfig.attachTypeCallback([&](DialectBytecodeReader &reader,
+                                      StringRef dialectName,
+                                      Type &entry) -> LogicalResult {
+      // Get test dialect version from the version map.
+      auto versionOr = reader.getDialectVersion<test::TestDialect>();
+      assert(succeeded(versionOr) && "expected reader to be able to access "
+                                     "the version for test dialect");
+      const auto *version =
+          reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
+      if (version->major_ >= 2)
+        return success();
 
-          // `dialectName` is the name of the group we have the opportunity to
-          // override. In this case, override only the dialect group "funky",
-          // for which does not exist in memory.
-          if (dialectName != StringLiteral("funky"))
-            return success();
+      // `dialectName` is the name of the group we have the opportunity to
+      // override. In this case, override only the dialect group "funky",
+      // which does not exist in memory.
+      if (dialectName != StringLiteral("funky"))
+        return success();
 
-          uint64_t encoding;
-          if (failed(reader.readVarInt(encoding)) || encoding != 999)
-            return success();
-          llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
-          uint64_t widthAndSignedness, width;
-          IntegerType::SignednessSemantics signedness;
-          if (succeeded(reader.readVarInt(widthAndSignedness)) &&
-              ((width = widthAndSignedness >> 2), true) &&
-              ((signedness = static_cast<IntegerType::SignednessSemantics>(
-                    widthAndSignedness & 0x3)),
-               true))
-            entry = IntegerType::get(reader.getContext(), width, signedness);
-          // Return nullopt to fall through the rest of the parsing code path.
-          return success();
-        });
+      uint64_t encoding;
+      if (failed(reader.readVarInt(encoding)) || encoding != 999)
+        return success();
+      llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
+      uint64_t widthAndSignedness, width;
+      IntegerType::SignednessSemantics signedness;
+      if (succeeded(reader.readVarInt(widthAndSignedness)) &&
+          ((width = widthAndSignedness >> 2), true) &&
+          ((signedness = static_cast<IntegerType::SignednessSemantics>(
+                widthAndSignedness & 0x3)),
+           true))
+        entry = IntegerType::get(reader.getContext(), width, signedness);
+      // A null `entry` makes parsing fall through to the default code path.
+      return success();
+    });
+    ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true,
+                             /*fallbackResourceMap=*/nullptr,
+                             /*bytecodeReaderConfig=*/&readConfig);
     doRoundtripWithConfigs(op, writeConfig, parseConfig);
   }
 
@@ -235,22 +238,24 @@ struct TestBytecodeRoundtripPass
     BytecodeDialectInterface *iface =
         builtin->getRegisteredInterface<BytecodeDialectInterface>();
     BytecodeWriterConfig writeConfig;
-    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
-    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
-        [&](DialectBytecodeReader &reader, StringRef dialectName,
-            Type &entry) -> LogicalResult {
-          if (dialectName != StringLiteral("builtin"))
-            return success();
-          Type builtinAttr = iface->readType(reader);
-          if (auto integerType =
-                  llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
-            if (integerType.getWidth() == 32 && integerType.isSignless()) {
-              llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
-              entry = test::TestI32Type::get(reader.getContext());
-            }
-          }
-          return success();
-        });
+    BytecodeReaderConfig readConfig;
+    readConfig.attachTypeCallback([&](DialectBytecodeReader &reader,
+                                      StringRef dialectName,
+                                      Type &entry) -> LogicalResult {
+      if (dialectName != StringLiteral("builtin"))
+        return success();
+      Type builtinAttr = iface->readType(reader);
+      if (auto integerType = llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
+        if (integerType.getWidth() == 32 && integerType.isSignless()) {
+          llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
+          entry = test::TestI32Type::get(reader.getContext());
+        }
+      }
+      return success();
+    });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true,
+                             /*fallbackResourceMap=*/nullptr,
+                             /*bytecodeReaderConfig=*/&readConfig);
     doRoundtripWithConfigs(op, writeConfig, parseConfig);
   }
 
@@ -301,28 +306,30 @@ struct TestBytecodeRoundtripPass
     auto i32Type = IntegerType::get(op->getContext(), 32,
                                     IntegerType::SignednessSemantics::Signless);
     BytecodeWriterConfig writeConfig;
-    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
-    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
-        [&](DialectBytecodeReader &reader, StringRef dialectName,
-            Attribute &entry) -> LogicalResult {
-          // Override only the case where the return type of the builtin reader
-          // is an i32 and fall through on all the other cases, since we want to
-          // still use TestDialect normal codepath to parse the other types.
-          Attribute builtinAttr = iface->readAttribute(reader);
-          if (auto denseAttr =
-                  llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
-            if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
-                denseAttr.getElementType() == i32Type) {
-              llvm::outs()
-                  << "Overriding parsing of TestAttrParamsAttr encoding...\n";
-              int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
-              int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
-              entry =
-                  test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
-            }
-          }
-          return success();
-        });
+    BytecodeReaderConfig readConfig;
+    readConfig.attachAttributeCallback([&](DialectBytecodeReader &reader,
+                                           StringRef dialectName,
+                                           Attribute &entry) -> LogicalResult {
+      // Override only the case where the builtin reader yields a 2-element i32
+      // DenseIntElementsAttr; fall through on all the other cases, since we
+      // still want the normal TestDialect codepath to parse other attributes.
+      Attribute builtinAttr = iface->readAttribute(reader);
+      if (auto denseAttr =
+              llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
+        if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
+            denseAttr.getElementType() == i32Type) {
+          llvm::outs()
+              << "Overriding parsing of TestAttrParamsAttr encoding...\n";
+          int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
+          int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
+          entry = test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
+        }
+      }
+      return success();
+    });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false,
+                             /*fallbackResourceMap=*/nullptr,
+                             /*bytecodeReaderConfig=*/&readConfig);
     doRoundtripWithConfigs(op, writeConfig, parseConfig);
   }
 
@@ -344,26 +351,29 @@ struct TestBytecodeRoundtripPass
             DialectBytecodeWriter &writer) -> LogicalResult {
           return iface->writeType(type, writer);
         });
-    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
-    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
-        [&](DialectBytecodeReader &reader, StringRef dialectName,
-            Attribute &entry) -> LogicalResult {
-          Attribute builtinAttr = iface->readAttribute(reader);
-          if (!builtinAttr)
-            return failure();
-          entry = builtinAttr;
-          return success();
-        });
-    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
-        [&](DialectBytecodeReader &reader, StringRef dialectName,
-            Type &entry) -> LogicalResult {
-          Type builtinType = iface->readType(reader);
-          if (!builtinType) {
-            return failure();
-          }
-          entry = builtinType;
-          return success();
-        });
+    BytecodeReaderConfig readConfig;
+    readConfig.attachAttributeCallback([&](DialectBytecodeReader &reader,
+                                           StringRef dialectName,
+                                           Attribute &entry) -> LogicalResult {
+      Attribute builtinAttr = iface->readAttribute(reader);
+      if (!builtinAttr)
+        return failure();
+      entry = builtinAttr;
+      return success();
+    });
+    readConfig.attachTypeCallback([&](DialectBytecodeReader &reader,
+                                      StringRef dialectName,
+                                      Type &entry) -> LogicalResult {
+      Type builtinType = iface->readType(reader);
+      if (!builtinType) {
+        return failure();
+      }
+      entry = builtinType;
+      return success();
+    });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false,
+                             /*fallbackResourceMap=*/nullptr,
+                             /*bytecodeReaderConfig=*/&readConfig);
     doRoundtripWithConfigs(op, writeConfig, parseConfig);
   }
 



More information about the Mlir-commits mailing list