[Mlir-commits] [mlir] [mlir][RFC] Bytecode: op fallback path (PR #129784)
Nikhil Kalra
llvmlistbot at llvm.org
Wed Mar 5 11:16:58 PST 2025
https://github.com/nikalra updated https://github.com/llvm/llvm-project/pull/129784
>From fab9e00c6393919b4e066ef68e6fdc7f407748ab Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 26 Feb 2025 16:53:00 -0800
Subject: [PATCH 1/9] first shim
---
.../mlir/Bytecode/BytecodeImplementation.h | 9 +++++++
.../mlir/Bytecode/BytecodeOpInterface.td | 24 ++++++++++++++++++
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 25 ++++++++++++++-----
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 7 ++++++
4 files changed, 59 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 0ddc531073e23..36fa010f7e11e 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -21,6 +21,7 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
//===--------------------------------------------------------------------===//
@@ -445,6 +446,14 @@ class BytecodeDialectInterface
return Type();
}
+ /// Fall back to an operation of this type if parsing an op from bytecode
+ /// fails for any reason. This can be used to handle new ops emitted from a
+ /// different version of the dialect, that cannot be read by an older version
+ /// of the dialect.
+ virtual FailureOr<OperationName> getFallbackOperationName() const {
+ return failure();
+ }
+
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index 54fb03e34ec51..7069933ab666b 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -40,4 +40,28 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> {
];
}
+// `FallbackBytecodeOpInterface`
+def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
+ let description = [{
+ This interface allows fallback operations direct access to the bytecode
+ property streams.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<[{
+ Read the properties blob for this operation from the bytecode and populate the state.
+ }],
+ "LogicalResult", "readPropertiesBlob", (ins
+ "ArrayRef<char>":$blob,
+ "::mlir::OperationState &":$state)
+ >,
+ InterfaceMethod<[{
+ Get the properties blob for this operation to be emitted into the bytecode.
+ }],
+ "ArrayRef<char>", "getPropertiesBlob", (ins)
+ >,
+ ];
+}
+
#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1204f1c069b1e..bcf5446013f00 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1106,13 +1106,18 @@ class PropertiesSectionReader {
dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
return failure();
}
+
+ // If the op is a "fallback" op, give it a handle to the raw properties
+ // buffer.
+ if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>())
+ return iface->readPropertiesBlob(rawProperties, opState);
+
// Setup a new reader to read from the `rawProperties` sub-buffer.
EncodingReader reader(
StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
DialectReader propReader = dialectReader.withEncodingReader(reader);
- auto *iface = opName->getInterface<BytecodeOpInterface>();
- if (iface)
+ if (auto *iface = opName->getInterface<BytecodeOpInterface>())
return iface->readProperties(propReader, opState);
if (opName->isRegistered())
return propReader.emitError(
@@ -1506,7 +1511,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
- isIndexPairEncoding(isIndexPairEncoding){};
+ isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
@@ -1863,10 +1868,18 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
dialectsMap, reader, version);
- if (failed(opName->dialect->load(dialectReader, getContext())))
+ if (succeeded(opName->dialect->load(dialectReader, getContext()))) {
+ opName->opName.emplace(
+ (opName->dialect->name + "." + opName->name).str(), getContext());
+ } else if (auto fallbackOp =
+ opName->dialect->interface->getFallbackOperationName();
+ succeeded(fallbackOp)) {
+ // If the dialect's bytecode interface specifies a fallback op, we want
+ // to use that instead of an unregistered op.
+ opName->opName.emplace(*fallbackOp);
+ } else {
return failure();
- opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
- getContext());
+ }
}
}
return *opName->opName;
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index cc5aaed416512..d3877f1d868ea 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -518,6 +518,13 @@ class PropertiesSectionBuilder {
return emit(scratch);
}
+ if (auto iface = dyn_cast<FallbackBytecodeOpInterface>(op)) {
+ // Fallback ops should write the properties payload to the bytecode buffer
+ // directly.
+ ArrayRef<char> encodedProperties = iface.getPropertiesBlob();
+ return emit(encodedProperties);
+ }
+
EncodingEmitter emitter;
DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
numberingState, stringSection,
>From 34587b886d756f8afee77449e012e7f33af1e4f5 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Fri, 28 Feb 2025 12:12:42 -0800
Subject: [PATCH 2/9] reading working
---
.../mlir/Bytecode/BytecodeOpInterface.td | 13 +++--
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 46 +++++++++++------
.../Dialect/Test/TestDialectInterfaces.cpp | 6 +++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 50 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 21 +++++++-
5 files changed, 116 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index 7069933ab666b..f85f7746441aa 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -49,17 +49,24 @@ def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
let cppNamespace = "::mlir";
let methods = [
+ StaticInterfaceMethod<[{
+ Set the original name for this operation from the bytecode.
+ }],
+ "void", "setOriginalOperationName", (ins
+ "::mlir::StringRef":$opName,
+ "::mlir::OperationState &":$state)
+ >,
StaticInterfaceMethod<[{
Read the properties blob for this operation from the bytecode and populate the state.
}],
- "LogicalResult", "readPropertiesBlob", (ins
- "ArrayRef<char>":$blob,
+ "::mlir::LogicalResult", "readPropertiesBlob", (ins
+ "::mlir::ArrayRef<char>":$blob,
"::mlir::OperationState &":$state)
>,
InterfaceMethod<[{
Get the properties blob for this operation to be emitted into the bytecode.
}],
- "ArrayRef<char>", "getPropertiesBlob", (ins)
+ "::mlir::ArrayRef<char>", "getPropertiesBlob", (ins)
>,
];
}
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index bcf5446013f00..29b3e27c1fd42 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1415,8 +1415,8 @@ class mlir::BytecodeReader::Impl {
/// Parse an operation name reference using the given reader, and set the
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
- FailureOr<OperationName> parseOpName(EncodingReader &reader,
- std::optional<bool> &wasRegistered);
+ FailureOr<BytecodeOperationName *>
+ parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered);
//===--------------------------------------------------------------------===//
// Attribute/Type Section
@@ -1848,7 +1848,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}
-FailureOr<OperationName>
+FailureOr<BytecodeOperationName *>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered) {
BytecodeOperationName *opName = nullptr;
@@ -1868,21 +1868,28 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
dialectsMap, reader, version);
- if (succeeded(opName->dialect->load(dialectReader, getContext()))) {
- opName->opName.emplace(
- (opName->dialect->name + "." + opName->name).str(), getContext());
- } else if (auto fallbackOp =
- opName->dialect->interface->getFallbackOperationName();
- succeeded(fallbackOp)) {
- // If the dialect's bytecode interface specifies a fallback op, we want
- // to use that instead of an unregistered op.
- opName->opName.emplace(*fallbackOp);
- } else {
+ if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
+
+ opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
+ getContext());
+
+ // If the op is unregistered now, but was not marked as unregistered, try
+ // to parse it as a fallback op if the dialect's bytecode interface
+ // specifies one.
+ // We don't treat this condition as an error because we may still be able
+ // to parse the op as an unregistered op if it doesn't use custom
+ // properties encoding.
+ if (wasRegistered && !opName->opName->isRegistered()) {
+ if (auto fallbackOp =
+ opName->dialect->interface->getFallbackOperationName();
+ succeeded(fallbackOp)) {
+ opName->opName.emplace(*fallbackOp);
+ }
}
}
}
- return *opName->opName;
+ return opName;
}
//===----------------------------------------------------------------------===//
@@ -2227,8 +2234,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
bool &isIsolatedFromAbove) {
// Parse the name of the operation.
std::optional<bool> wasRegistered;
- FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
- if (failed(opName))
+ FailureOr<BytecodeOperationName *> bytecodeOp =
+ parseOpName(reader, wasRegistered);
+ if (failed(bytecodeOp))
+ return failure();
+ auto opName = (*bytecodeOp)->opName;
+ if (!opName)
return failure();
// Parse the operation mask, which indicates which components of the operation
@@ -2245,6 +2256,9 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// With the location and name resolved, we can start building the operation
// state.
OperationState opState(opLoc, *opName);
+ // If this is a fallback op, provide the original name of the operation.
+ if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>())
+ iface->setOriginalOperationName((*bytecodeOp)->name, opState);
// Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..24e31ec44a85b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -92,6 +93,11 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return Attribute();
}
+ FailureOr<OperationName> getFallbackOperationName() const final {
+ return OperationName(TestBytecodeFallbackOp::getOperationName(),
+ getContext());
+ }
+
// Emit a specific version of the dialect.
void writeVersion(DialectBytecodeWriter &writer) const final {
// Construct the current dialect version.
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index f6b8a0005f285..747ccf30361e7 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -9,9 +9,12 @@
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
using namespace mlir;
using namespace test;
@@ -1230,6 +1233,53 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
writer.writeAttribute(prop.modifier);
}
+//===----------------------------------------------------------------------===//
+// TestVersionedOpD
+//===----------------------------------------------------------------------===//
+
+// LogicalResult
+// TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader,
+// mlir::OperationState &state) {
+// auto &prop = state.getOrAddProperties<Properties>();
+// StringRef res;
+// if (failed(reader.readString(res)))
+// return failure();
+// if (failed(reader.readAttribute(prop.attribute)))
+// return failure();
+
+// return success();
+// }
+
+// void TestVersionedOpD::writeProperties(mlir::DialectBytecodeWriter &writer) {
+// auto &prop = getProperties();
+// writer.writeOwnedString("version 1");
+// writer.writeAttribute(prop.attribute);
+// }
+
+//===----------------------------------------------------------------------===//
+// TestBytecodeFallbackOp
+//===----------------------------------------------------------------------===//
+
+void TestBytecodeFallbackOp::setOriginalOperationName(StringRef name,
+ OperationState &state) {
+ state.getOrAddProperties<Properties>().setOpname(
+ StringAttr::get(state.getContext(), name));
+}
+
+LogicalResult
+TestBytecodeFallbackOp::readPropertiesBlob(ArrayRef<char> blob,
+ OperationState &state) {
+ state.getOrAddProperties<Properties>().bytecodeProperties =
+ DenseI8ArrayAttr::get(state.getContext(),
+ ArrayRef((const int8_t *)blob.data(), blob.size()));
+ return success();
+}
+
+ArrayRef<char> TestBytecodeFallbackOp::getPropertiesBlob() {
+ return ArrayRef((const char *)getBytecodeProperties().data(),
+ getBytecodeProperties().size());
+}
+
//===----------------------------------------------------------------------===//
// TestOpWithVersionedProperties
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cdc1237ec8c5a..52ca3a81fe2c7 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -20,6 +20,7 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/PatternBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Bytecode/BytecodeOpInterface.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
@@ -31,7 +32,6 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-
// Include the attribute definitions.
include "TestAttrDefs.td"
// Include the type definitions.
@@ -3030,6 +3030,25 @@ def TestVersionedOpC : TEST_Op<"versionedC"> {
);
}
+// def TestVersionedOpD : TEST_Op<"versionedD"> {
+// let arguments = (ins AnyAttrOf<[TestAttrParams,
+// I32ElementsAttr]>:$attribute
+// );
+
+// let useCustomPropertiesEncoding = 1;
+// }
+
+def TestBytecodeFallbackOp : TEST_Op<"bytecode.fallback", [
+ DeclareOpInterfaceMethods<FallbackBytecodeOpInterface, ["setOriginalOperationName", "readPropertiesBlob", "getPropertiesBlob"]>
+]> {
+ let arguments = (ins
+ StrAttr:$opname,
+ DenseI8ArrayAttr:$bytecodeProperties,
+ Variadic<AnyType>:$operands);
+ let regions = (region VariadicRegion<AnyRegion>:$bodyRegions);
+ let results = (outs Variadic<AnyType>:$results);
+}
+
//===----------------------------------------------------------------------===//
// Test Properties
//===----------------------------------------------------------------------===//
>From d76f47b2ad26255677c0c8a611913b0d6d3c28f3 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 3 Mar 2025 12:38:35 -0800
Subject: [PATCH 3/9] roundtrip working
---
.../mlir/Bytecode/BytecodeOpInterface.td | 5 +++++
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 18 ++++++++++++++++--
mlir/lib/Bytecode/Writer/IRNumbering.cpp | 12 +++++++++++-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 5 +++++
4 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index f85f7746441aa..b7bffd20d9329 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -56,6 +56,11 @@ def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
"::mlir::StringRef":$opName,
"::mlir::OperationState &":$state)
>,
+ InterfaceMethod<[{
+ Get the original name for this operation from the bytecode.
+ }],
+ "::mlir::StringRef", "getOriginalOperationName", (ins)
+ >,
StaticInterfaceMethod<[{
Read the properties blob for this operation from the bytecode and populate the state.
}],
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index d3877f1d868ea..75ed5f998fa73 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -848,11 +848,16 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
+ bool isRegistered = name.name.isRegistered();
+ // If we're writing a fallback op, write it as if it were a registered op.
+ if (name.name.hasInterface<FallbackBytecodeOpInterface>())
+ isRegistered = true;
+
size_t stringId = stringSection.insert(name.name.stripDialect());
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
dialectEmitter.emitVarInt(stringId, "dialect op name");
else
- dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(),
+ dialectEmitter.emitVarIntWithFlag(stringId, isRegistered,
"dialect op name");
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
@@ -991,7 +996,16 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
}
LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
- emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID");
+ OperationName opName = op->getName();
+ // For fallback ops, create a new operation name referencing the original op
+ // instead.
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
+ opName = OperationName((fallback->getDialect()->getNamespace() + "." +
+ fallback.getOriginalOperationName())
+ .str(),
+ op->getContext());
+
+ emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");
// Emit a mask for the operation components. We need to fill this in later
// (when we actually know what needs to be emitted), so emit a placeholder for
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 1bc02e1721573..aa71c692a8e0f 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -419,7 +419,17 @@ void IRNumberingState::number(Region &region) {
void IRNumberingState::number(Operation &op) {
// Number the components of an operation that won't be numbered elsewhere
// (e.g. we don't number operands, regions, or successors here).
- number(op.getName());
+
+ OperationName opName = op.getName();
+ // For fallback ops, create a new operation name referencing the original op
+ // instead.
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
+ opName = OperationName((fallback->getDialect()->getNamespace() + "." +
+ fallback.getOriginalOperationName())
+ .str(),
+ op.getContext());
+ number(opName);
+
for (OpResult result : op.getResults()) {
valueIDs.try_emplace(result, nextValueID++);
number(result.getType());
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 747ccf30361e7..07ae95b951fb7 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Verifier.h"
@@ -1266,6 +1267,10 @@ void TestBytecodeFallbackOp::setOriginalOperationName(StringRef name,
StringAttr::get(state.getContext(), name));
}
+StringRef TestBytecodeFallbackOp::getOriginalOperationName() {
+ return getProperties().getOpname().getValue();
+}
+
LogicalResult
TestBytecodeFallbackOp::readPropertiesBlob(ArrayRef<char> blob,
OperationState &state) {
>From f88ae0de15e234b0bba204742d5da9a49f9a1995 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 3 Mar 2025 13:29:07 -0800
Subject: [PATCH 4/9] refactor / cleanup
---
.../mlir/Bytecode/BytecodeOpInterface.td | 2 +-
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 8 +++++--
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 15 ++++---------
mlir/lib/Bytecode/Writer/IRNumbering.cpp | 21 +++++++++----------
mlir/lib/Bytecode/Writer/IRNumbering.h | 11 ++++++----
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 2 +-
6 files changed, 29 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index b7bffd20d9329..ae352fa5e8f07 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -53,7 +53,7 @@ def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
Set the original name for this operation from the bytecode.
}],
"void", "setOriginalOperationName", (ins
- "::mlir::StringRef":$opName,
+ "const ::mlir::Twine&":$opName,
"::mlir::OperationState &":$state)
>,
InterfaceMethod<[{
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 29b3e27c1fd42..f2aaee35416f1 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
@@ -2257,8 +2258,11 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// state.
OperationState opState(opLoc, *opName);
// If this is a fallback op, provide the original name of the operation.
- if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>())
- iface->setOriginalOperationName((*bytecodeOp)->name, opState);
+ if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>()) {
+ // Build the name inline: a Twine local would outlive its operand temporaries.
+ iface->setOriginalOperationName(
+ opName->getDialect()->getNamespace() + "." + (*bytecodeOp)->name, opState);
+ }
// Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 75ed5f998fa73..f094663f66c24 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -848,17 +848,12 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
- bool isRegistered = name.name.isRegistered();
- // If we're writing a fallback op, write it as if it were a registered op.
- if (name.name.hasInterface<FallbackBytecodeOpInterface>())
- isRegistered = true;
-
+ const bool isKnownOp = name.isOpaqueEntry || name.name.isRegistered();
size_t stringId = stringSection.insert(name.name.stripDialect());
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
dialectEmitter.emitVarInt(stringId, "dialect op name");
else
- dialectEmitter.emitVarIntWithFlag(stringId, isRegistered,
- "dialect op name");
+ dialectEmitter.emitVarIntWithFlag(stringId, isKnownOp, "dialect op name");
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
@@ -1000,10 +995,8 @@ LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
// For fallback ops, create a new operation name referencing the original op
// instead.
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
- opName = OperationName((fallback->getDialect()->getNamespace() + "." +
- fallback.getOriginalOperationName())
- .str(),
- op->getContext());
+ opName =
+ OperationName(fallback.getOriginalOperationName(), op->getContext());
emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index aa71c692a8e0f..60bc6bd5170c5 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -420,15 +420,14 @@ void IRNumberingState::number(Operation &op) {
// Number the components of an operation that won't be numbered elsewhere
// (e.g. we don't number operands, regions, or successors here).
- OperationName opName = op.getName();
- // For fallback ops, create a new operation name referencing the original op
+ // For fallback ops, create a new OperationName referencing the original op
// instead.
- if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
- opName = OperationName((fallback->getDialect()->getNamespace() + "." +
- fallback.getOriginalOperationName())
- .str(),
- op.getContext());
- number(opName);
+ if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op)) {
+ OperationName opName(fallback.getOriginalOperationName(), op.getContext());
+ number(opName, /*isOpaque=*/true);
+ } else {
+ number(op.getName(), /*isOpaque=*/false);
+ }
for (OpResult result : op.getResults()) {
valueIDs.try_emplace(result, nextValueID++);
@@ -467,7 +466,7 @@ void IRNumberingState::number(Operation &op) {
number(op.getLoc());
}
-void IRNumberingState::number(OperationName opName) {
+void IRNumberingState::number(OperationName opName, bool isOpaque) {
OpNameNumbering *&numbering = opNames[opName];
if (numbering) {
++numbering->refCount;
@@ -479,8 +478,8 @@ void IRNumberingState::number(OperationName opName) {
else
dialectNumber = &numberDialect(opName.getDialectNamespace());
- numbering =
- new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
+ numbering = new (opNameAllocator.Allocate())
+ OpNameNumbering(dialectNumber, opName, isOpaque);
orderedOpNames.push_back(numbering);
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index 9b7ac0d3688e3..033b3771b46a3 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -63,8 +63,8 @@ struct TypeNumbering : public AttrTypeNumbering {
/// This class represents the numbering entry of an operation name.
struct OpNameNumbering {
- OpNameNumbering(DialectNumbering *dialect, OperationName name)
- : dialect(dialect), name(name) {}
+ OpNameNumbering(DialectNumbering *dialect, OperationName name, bool isOpaque)
+ : dialect(dialect), name(name), isOpaqueEntry(isOpaque) {}
/// The dialect of this value.
DialectNumbering *dialect;
@@ -72,6 +72,9 @@ struct OpNameNumbering {
/// The concrete name.
OperationName name;
+ /// This entry represents an opaque operation entry.
+ bool isOpaqueEntry = false;
+
/// The number assigned to this name.
unsigned number = 0;
@@ -210,7 +213,7 @@ class IRNumberingState {
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
-
+
private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
@@ -225,7 +228,7 @@ class IRNumberingState {
DialectNumbering &numberDialect(Dialect *dialect);
DialectNumbering &numberDialect(StringRef dialect);
void number(Operation &op);
- void number(OperationName opName);
+ void number(OperationName opName, bool isOpaque);
void number(Region &region);
void number(Type type);
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 07ae95b951fb7..e2e38bb79aad1 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1261,7 +1261,7 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
// TestBytecodeFallbackOp
//===----------------------------------------------------------------------===//
-void TestBytecodeFallbackOp::setOriginalOperationName(StringRef name,
+void TestBytecodeFallbackOp::setOriginalOperationName(const Twine &name,
OperationState &state) {
state.getOrAddProperties<Properties>().setOpname(
StringAttr::get(state.getContext(), name));
>From d2ad94280c4882548aa08224dda5d4e7d1094c08 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 3 Mar 2025 14:02:33 -0800
Subject: [PATCH 5/9] use fallback on parsing failure
---
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 86 ++++++++++++++-------
1 file changed, 58 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index f2aaee35416f1..59774085f3b01 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -293,6 +293,16 @@ class EncodingReader {
Location getLoc() const { return fileLoc; }
+ /// Snapshot the location of the BytecodeReader so that parsing can be rewound
+ /// if needed.
+ struct Snapshot {
+ EncodingReader &reader;
+ const uint8_t *dataIt;
+
+ Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
+ void rewind() { reader.dataIt = dataIt; }
+ };
+
private:
/// Parse a variable length encoded integer from the byte stream. This method
/// is a fallback when the number of bytes used to encode the value is greater
@@ -1417,7 +1427,8 @@ class mlir::BytecodeReader::Impl {
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
FailureOr<BytecodeOperationName *>
- parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered);
+ parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered,
+ bool useDialectFallback);
//===--------------------------------------------------------------------===//
// Attribute/Type Section
@@ -1482,7 +1493,8 @@ class mlir::BytecodeReader::Impl {
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
- bool &isIsolatedFromAbove);
+ bool &isIsolatedFromAbove,
+ bool useDialectFallback);
LogicalResult parseRegion(RegionReadState &readState);
LogicalResult parseBlockHeader(EncodingReader &reader,
@@ -1851,14 +1863,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
FailureOr<BytecodeOperationName *>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
- std::optional<bool> &wasRegistered) {
+ std::optional<bool> &wasRegistered,
+ bool useDialectFallback) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
- if (!opName->opName) {
+ // If `useDialectFallback`, it's likely that parsing previously failed. We'll
+ // need to reset any previously resolved OperationName with that of the
+ // fallback op.
+ if (!opName->opName || useDialectFallback) {
// If the opName is empty, this is because we use to accept names such as
// `foo` without any `.` separator. We shouldn't tolerate this in textual
// format anymore but for now we'll be backward compatible. This can only
@@ -1872,21 +1888,19 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
- opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
- getContext());
-
- // If the op is unregistered now, but was not marked as unregistered, try
- // to parse it as a fallback op if the dialect's bytecode interface
- // specifies one.
- // We don't treat this condition as an error because we may still be able
- // to parse the op as an unregistered op if it doesn't use custom
- // properties encoding.
- if (wasRegistered && !opName->opName->isRegistered()) {
- if (auto fallbackOp =
- opName->dialect->interface->getFallbackOperationName();
- succeeded(fallbackOp)) {
- opName->opName.emplace(*fallbackOp);
- }
+ if (useDialectFallback) {
+ auto fallbackOp =
+ opName->dialect->interface->getFallbackOperationName();
+
+ // If the dialect doesn't have a fallback operation, we can't parse as
+ // instructed.
+ if (failed(fallbackOp))
+ return failure();
+
+ opName->opName.emplace(*fallbackOp);
+ } else {
+ opName->opName.emplace(
+ (opName->dialect->name + "." + opName->name).str(), getContext());
}
}
}
@@ -2164,10 +2178,27 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
// Read in the next operation. We don't read its regions directly, we
// handle those afterwards as necessary.
bool isIsolatedFromAbove = false;
- FailureOr<Operation *> op =
- parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
- if (failed(op))
- return failure();
+ FailureOr<Operation *> op;
+
+ // Parse the bytecode.
+ {
+ EncodingReader::Snapshot snapshot(reader);
+ op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
+ /*useDialectFallback=*/false);
+
+ // If reading fails, try parsing the op again as a dialect fallback
+ // op (if supported).
+ if (failed(op)) {
+ snapshot.rewind();
+ op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
+ /*useDialectFallback=*/true);
+ }
+
+ // If the dialect doesn't have a fallback op, or parsing as a fallback
+ // op fails, we can no longer continue.
+ if (failed(op))
+ return failure();
+ }
// If the op has regions, add it to the stack for processing and return:
// we stop the processing of the current region and resume it after the
@@ -2229,14 +2260,13 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
return success();
}
-FailureOr<Operation *>
-BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
- RegionReadState &readState,
- bool &isIsolatedFromAbove) {
+FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions(
+ EncodingReader &reader, RegionReadState &readState,
+ bool &isIsolatedFromAbove, bool useDialectFallback) {
// Parse the name of the operation.
std::optional<bool> wasRegistered;
FailureOr<BytecodeOperationName *> bytecodeOp =
- parseOpName(reader, wasRegistered);
+ parseOpName(reader, wasRegistered, useDialectFallback);
if (failed(bytecodeOp))
return failure();
auto opName = (*bytecodeOp)->opName;
>From 1de4485ac2d957a2a8280566873b47f142fee4f4 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Tue, 4 Mar 2025 13:23:03 -0800
Subject: [PATCH 6/9] fix parsing & writing to be bitwise exact
---
.../mlir/Bytecode/BytecodeOpInterface.td | 14 +--
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 8 +-
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 7 --
.../versioning/versioning-fallback.mlir | 13 +++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 100 +++++++++++++-----
mlir/test/lib/Dialect/Test/TestOps.td | 19 ++--
6 files changed, 105 insertions(+), 56 deletions(-)
create mode 100644 mlir/test/Bytecode/versioning/versioning-fallback.mlir
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index ae352fa5e8f07..a97cf620def0c 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -60,19 +60,7 @@ def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
Get the original name for this operation from the bytecode.
}],
"::mlir::StringRef", "getOriginalOperationName", (ins)
- >,
- StaticInterfaceMethod<[{
- Read the properties blob for this operation from the bytecode and populate the state.
- }],
- "::mlir::LogicalResult", "readPropertiesBlob", (ins
- "::mlir::ArrayRef<char>":$blob,
- "::mlir::OperationState &":$state)
- >,
- InterfaceMethod<[{
- Get the properties blob for this operation to be emitted into the bytecode.
- }],
- "::mlir::ArrayRef<char>", "getPropertiesBlob", (ins)
- >,
+ >
];
}
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 59774085f3b01..f5e12a14f548f 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1118,11 +1118,6 @@ class PropertiesSectionReader {
return failure();
}
- // If the op is a "fallback" op, give it a handle to the raw properties
- // buffer.
- if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>())
- return iface->readPropertiesBlob(rawProperties, opState);
-
// Setup a new reader to read from the `rawProperties` sub-buffer.
EncodingReader reader(
StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
@@ -2182,6 +2177,9 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
// Parse the bytecode.
{
+ // If the op is registered (and serialized in a compatible manner), or
+ // unregistered but uses standard properties encoding, parsing without
+ // going through the fallback path should work.
EncodingReader::Snapshot snapshot(reader);
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
/*useDialectFallback=*/false);
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index f094663f66c24..526dfb3654492 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -518,13 +518,6 @@ class PropertiesSectionBuilder {
return emit(scratch);
}
- if (auto iface = dyn_cast<FallbackBytecodeOpInterface>(op)) {
- // Fallback ops should write the properties payload to the bytecode buffer
- // directly.
- ArrayRef<char> encodedProperties = iface.getPropertiesBlob();
- return emit(encodedProperties);
- }
-
EncodingEmitter emitter;
DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
numberingState, stringSection,
diff --git a/mlir/test/Bytecode/versioning/versioning-fallback.mlir b/mlir/test/Bytecode/versioning/versioning-fallback.mlir
new file mode 100644
index 0000000000000..a078613360af6
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioning-fallback.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --emit-bytecode > %t.mlirbc
+"test.versionedD"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
+
+// COM: check that versionedD was parsed as a fallback op.
+// RUN: mlir-opt %t.mlirbc | FileCheck %s --check-prefix=CHECK-PARSE
+// CHECK-PARSE: test.bytecode.fallback
+// CHECK-PARSE-SAME: opname = "test.versionedD"
+
+// COM: check that the bytecode roundtrip was successful
+// RUN: mlir-opt %t.mlirbc --verify-roundtrip
+
+// COM: check that the bytecode roundtrip is bitwise exact
+// RUN: mlir-opt %t.mlirbc --emit-bytecode | diff %t.mlirbc -
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index e2e38bb79aad1..77428517f2b12 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -10,10 +10,13 @@
#include "TestOps.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
#include <cstdint>
@@ -1238,24 +1241,60 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
// TestVersionedOpD
//===----------------------------------------------------------------------===//
-// LogicalResult
-// TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader,
-// mlir::OperationState &state) {
-// auto &prop = state.getOrAddProperties<Properties>();
-// StringRef res;
-// if (failed(reader.readString(res)))
-// return failure();
-// if (failed(reader.readAttribute(prop.attribute)))
-// return failure();
+/// Reading this op's properties always fails, which forces the bytecode
+/// reader onto the dialect fallback path for it.
+LogicalResult
+TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader,
+ mlir::OperationState &state) {
+ return failure();
+}
+
+namespace {
+/// Properties layout shared by TestVersionedOpD and TestBytecodeFallbackOp:
+/// a signed varint version, then a list of required attributes, then a list
+/// of optional (possibly null) attributes. Because both ops agree on this
+/// layout, the fallback op can read the original op's properties and write
+/// them back unchanged.
+struct FallbackCompliantPropertiesEncoding {
+ int64_t version = 0;
+ SmallVector<Attribute> requiredAttributes;
+ SmallVector<Attribute> optionalAttributes;
+
+ void writeProperties(DialectBytecodeWriter &writer) const {
+ // Write the op version.
+ writer.writeSignedVarInt(version);
+
+ // Write the required attributes.
+ writer.writeList(requiredAttributes,
+ [&](Attribute attr) { writer.writeAttribute(attr); });
+
+ // Write the optional attributes.
+ writer.writeList(optionalAttributes, [&](Attribute attr) {
+ writer.writeOptionalAttribute(attr);
+ });
+ }
+
+ LogicalResult readProperties(DialectBytecodeReader &reader) {
+ // Read the op version.
+ if (failed(reader.readSignedVarInt(version)))
+ return failure();
-// return success();
-// }
+ // Read the required attributes.
+ if (failed(reader.readList(requiredAttributes, [&](Attribute &attr) {
+ return reader.readAttribute(attr);
+ })))
+ return failure();
-// void TestVersionedOpD::writeProperties(mlir::DialectBytecodeWriter &writer) {
-// auto &prop = getProperties();
-// writer.writeOwnedString("version 1");
-// writer.writeAttribute(prop.attribute);
-// }
+ // Read the optional attributes.
+ if (failed(reader.readList(optionalAttributes, [&](Attribute &attr) {
+ return reader.readOptionalAttribute(attr);
+ })))
+ return failure();
+
+ return success();
+ }
+};
+} // namespace
+
+void TestVersionedOpD::writeProperties(mlir::DialectBytecodeWriter &writer) {
+ // Version 1 of this op stores its attribute as the only required attribute.
+ // Fields are assigned individually: designated initializers are C++20 only.
+ FallbackCompliantPropertiesEncoding encoding;
+ encoding.version = 1;
+ encoding.requiredAttributes.push_back(getAttribute());
+ encoding.writeProperties(writer);
+}
//===----------------------------------------------------------------------===//
// TestBytecodeFallbackOp
@@ -1272,17 +1311,30 @@ StringRef TestBytecodeFallbackOp::getOriginalOperationName() {
}
LogicalResult
-TestBytecodeFallbackOp::readPropertiesBlob(ArrayRef<char> blob,
- OperationState &state) {
- state.getOrAddProperties<Properties>().bytecodeProperties =
- DenseI8ArrayAttr::get(state.getContext(),
- ArrayRef((const int8_t *)blob.data(), blob.size()));
+TestBytecodeFallbackOp::readProperties(DialectBytecodeReader &reader,
+ OperationState &state) {
+ // Decode the original op's properties using the shared layout and store
+ // each field in this op's properties so they can be written back verbatim.
+ FallbackCompliantPropertiesEncoding encoding;
+ if (failed(encoding.readProperties(reader)))
+ return failure();
+
+ auto &props = state.getOrAddProperties<Properties>();
+ props.opversion = encoding.version;
+ props.encodedReqdAttributes =
+ ArrayAttr::get(state.getContext(), encoding.requiredAttributes);
+ props.encodedOptAttributes =
+ ArrayAttr::get(state.getContext(), encoding.optionalAttributes);
+
return success();
}
-ArrayRef<char> TestBytecodeFallbackOp::getPropertiesBlob() {
- return ArrayRef((const char *)getBytecodeProperties().data(),
- getBytecodeProperties().size());
+void TestBytecodeFallbackOp::writeProperties(DialectBytecodeWriter &writer) {
+ // Re-emit the stored fields in the same layout they were read with. Fields
+ // are assigned individually: designated initializers are C++20 only.
+ FallbackCompliantPropertiesEncoding encoding;
+ encoding.version = getOpversion();
+ encoding.requiredAttributes =
+ llvm::to_vector(getEncodedReqdAttributes().getValue());
+ encoding.optionalAttributes =
+ llvm::to_vector(getEncodedOptAttributes().getValue());
+ encoding.writeProperties(writer);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 52ca3a81fe2c7..4ea76b0f970b7 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3030,23 +3030,28 @@ def TestVersionedOpC : TEST_Op<"versionedC"> {
);
}
-// def TestVersionedOpD : TEST_Op<"versionedD"> {
-// let arguments = (ins AnyAttrOf<[TestAttrParams,
-// I32ElementsAttr]>:$attribute
-// );
+// This op is used to generate tests for the bytecode dialect fallback path.
+def TestVersionedOpD : TEST_Op<"versionedD"> {
+ let arguments = (ins AnyAttrOf<[TestAttrParams,
+ I32ElementsAttr]>:$attribute
+ );
-// let useCustomPropertiesEncoding = 1;
-// }
+ // The C++ readProperties for this op always fails (see TestOpDefs.cpp), so
+ // reading it back from bytecode exercises the dialect fallback path.
+ let useCustomPropertiesEncoding = 1;
+}
def TestBytecodeFallbackOp : TEST_Op<"bytecode.fallback", [
DeclareOpInterfaceMethods<FallbackBytecodeOpInterface, ["setOriginalOperationName", "readPropertiesBlob", "getPropertiesBlob"]>
]> {
let arguments = (ins
StrAttr:$opname,
- DenseI8ArrayAttr:$bytecodeProperties,
+ // Fields of the shared FallbackCompliantPropertiesEncoding layout.
+ IntProp<"int64_t">:$opversion,
+ ArrayAttr:$encodedReqdAttributes,
+ ArrayAttr:$encodedOptAttributes,
Variadic<AnyType>:$operands);
let regions = (region VariadicRegion<AnyRegion>:$bodyRegions);
let results = (outs Variadic<AnyType>:$results);
+
+ // Properties are read and written through FallbackCompliantPropertiesEncoding
+ // (see TestOpDefs.cpp) so the original op's properties round-trip unchanged.
+ let useCustomPropertiesEncoding = 1;
}
//===----------------------------------------------------------------------===//
>From 8a5066c9d296a28e883da61dfb9ee0b9fe8d5396 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Tue, 4 Mar 2025 13:28:52 -0800
Subject: [PATCH 7/9] fix tests
---
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index f5e12a14f548f..7668e9aa5cb28 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1883,9 +1883,11 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
+ const BytecodeDialectInterface *dialectIface = opName->dialect->interface;
if (useDialectFallback) {
- auto fallbackOp =
- opName->dialect->interface->getFallbackOperationName();
+ FailureOr<OperationName> fallbackOp =
+ dialectIface ? dialectIface->getFallbackOperationName()
+ : FailureOr<OperationName>{};
// If the dialect doesn't have a fallback operation, we can't parse as
// instructed.
>From 9209d5f249ae0c7f15416cd3c8f82bb2b56050b0 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Tue, 4 Mar 2025 13:37:17 -0800
Subject: [PATCH 8/9] clean up diff
---
mlir/include/mlir/Bytecode/BytecodeOpInterface.td | 4 ++--
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
index a97cf620def0c..87ba27ad6ac27 100644
--- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
+++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td
@@ -43,8 +43,8 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> {
// `FallbackBytecodeOpInterface`
def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
let description = [{
- This interface allows fallback operations direct access to the bytecode
- property streams.
+ This interface allows fallback operations sideband access to the
+ original operation's intrinsic details.
}];
let cppNamespace = "::mlir";
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 7668e9aa5cb28..64fcc4ed7c6dc 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1117,13 +1117,13 @@ class PropertiesSectionReader {
dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
return failure();
}
-
// Setup a new reader to read from the `rawProperties` sub-buffer.
EncodingReader reader(
StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
DialectReader propReader = dialectReader.withEncodingReader(reader);
- if (auto *iface = opName->getInterface<BytecodeOpInterface>())
+ auto *iface = opName->getInterface<BytecodeOpInterface>();
+ if (iface)
return iface->readProperties(propReader, opState);
if (opName->isRegistered())
return propReader.emitError(
>From ed4a22d8d479d321159ca191171337646fa9700c Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 5 Mar 2025 11:16:40 -0800
Subject: [PATCH 9/9] remove old interface methods
---
mlir/test/lib/Dialect/Test/TestOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 4ea76b0f970b7..6c7687e08b633 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3040,7 +3040,7 @@ def TestVersionedOpD : TEST_Op<"versionedD"> {
}
def TestBytecodeFallbackOp : TEST_Op<"bytecode.fallback", [
- DeclareOpInterfaceMethods<FallbackBytecodeOpInterface, ["setOriginalOperationName", "readPropertiesBlob", "getPropertiesBlob"]>
+ DeclareOpInterfaceMethods<FallbackBytecodeOpInterface, ["setOriginalOperationName"]>
]> {
let arguments = (ins
StrAttr:$opname,
More information about the Mlir-commits
mailing list