[Mlir-commits] [mlir] 5fb1bbe - [mlir] Add bytecode encodings for the builtin ElementsAttr attributes
River Riddle
llvmlistbot at llvm.org
Tue Sep 13 11:39:47 PDT 2022
Author: River Riddle
Date: 2022-09-13T11:39:20-07:00
New Revision: 5fb1bbe6d49e08af5d47ce4d9498d650e6628dad
URL: https://github.com/llvm/llvm-project/commit/5fb1bbe6d49e08af5d47ce4d9498d650e6628dad
DIFF: https://github.com/llvm/llvm-project/commit/5fb1bbe6d49e08af5d47ce4d9498d650e6628dad.diff
LOG: [mlir] Add bytecode encodings for the builtin ElementsAttr attributes
This adds bytecode support for DenseArrayAttr, DenseIntOrFpElementsAttr,
DenseStringElementsAttr, and SparseElementsAttr.
Differential Revision: https://reviews.llvm.org/D133744
Added:
Modified:
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/lib/Bytecode/Writer/IRNumbering.cpp
mlir/lib/IR/BuiltinDialectBytecode.cpp
mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 7607f2a4036be..1ae839d5f40e7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -143,6 +143,9 @@ class DialectBytecodeReader {
/// Read a string from the bytecode.
virtual LogicalResult readString(StringRef &result) = 0;
+ /// Read a blob from the bytecode.
+ virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
+
private:
/// Read a handle to a dialect resource.
virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
@@ -225,6 +228,11 @@ class DialectBytecodeWriter {
/// only be called if such a guarantee can be made, such as when the string is
/// owned by an attribute or type.
virtual void writeOwnedString(StringRef str) = 0;
+
+ /// Write a blob to the bytecode, which is owned by the caller and is
+ /// guaranteed to not die before the end of the bytecode process. The blob is
+ /// written as-is, with no additional compression or compaction.
+ virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 5e10dfad355a3..2f36de523ff32 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -887,6 +887,17 @@ class DialectReader : public DialectBytecodeReader {
return stringReader.parseString(reader, result);
}
+ LogicalResult readBlob(ArrayRef<char> &result) override {
+ uint64_t dataSize;
+ ArrayRef<uint8_t> data;
+ if (failed(reader.parseVarInt(dataSize)) ||
+ failed(reader.parseBytes(dataSize, data)))
+ return failure();
+ result = llvm::makeArrayRef(reinterpret_cast<const char *>(data.data()),
+ data.size());
+ return success();
+ }
+
private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 7bcc1a841c245..1fd313453fb24 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -543,6 +543,12 @@ class DialectWriter : public DialectBytecodeWriter {
emitter.emitVarInt(stringSection.insert(str));
}
+ void writeOwnedBlob(ArrayRef<char> blob) override {
+ emitter.emitVarInt(blob.size());
+ emitter.emitOwnedBlob(ArrayRef<uint8_t>(
+ reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
+ }
+
private:
EncodingEmitter &emitter;
IRNumberingState &numberingState;
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index d5a2ef5f2bf14..9fed3236af588 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -39,6 +39,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
// references. This could potentially be useful for optimizing things like
// file locations.
}
+ void writeOwnedBlob(ArrayRef<char> blob) override {}
/// The parent numbering state that is populated by this writer.
IRNumberingState &state;
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 260282821ac51..18f3ea8f2382e 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -123,6 +123,32 @@ enum AttributeCode {
/// handle: ResourceHandle
/// }
kDenseResourceElementsAttr = 16,
+
+ /// DenseArrayAttr {
+ /// type: RankedTensorType,
+ /// data: blob
+ /// }
+ kDenseArrayAttr = 17,
+
+ /// DenseIntOrFPElementsAttr {
+ /// type: ShapedType,
+ /// data: blob
+ /// }
+ kDenseIntOrFPElementsAttr = 18,
+
+ /// DenseStringElementsAttr {
+ /// type: ShapedType,
+ /// isSplat: varint,
+ /// data: string[]
+ /// }
+ kDenseStringElementsAttr = 19,
+
+ /// SparseElementsAttr {
+ /// type: ShapedType,
+ /// indices: DenseIntElementsAttr,
+ /// values: DenseElementsAttr
+ /// }
+ kSparseElementsAttr = 20,
};
/// This enum contains marker codes used to indicate which type is currently
@@ -279,11 +305,18 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
Attribute readAttribute(DialectBytecodeReader &reader) const override;
ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+ DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const;
+ DenseElementsAttr
+ readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const;
+ DenseStringElementsAttr
+ readDenseStringElementsAttr(DialectBytecodeReader &reader) const;
DenseResourceElementsAttr
readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
+ SparseElementsAttr
+ readSparseElementsAttr(DialectBytecodeReader &reader) const;
StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const;
SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader,
bool hasNestedRefs) const;
@@ -298,11 +331,16 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const override;
void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+ void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const;
+ void write(DenseIntOrFPElementsAttr attr,
+ DialectBytecodeWriter &writer) const;
+ void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const;
void write(DenseResourceElementsAttr attr,
DialectBytecodeWriter &writer) const;
void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
+ void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const;
void write(StringAttr attr, DialectBytecodeWriter &writer) const;
void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const;
void write(TypeAttr attr, DialectBytecodeWriter &writer) const;
@@ -394,6 +432,14 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
return UnknownLoc::get(getContext());
case builtin_encoding::kDenseResourceElementsAttr:
return readDenseResourceElementsAttr(reader);
+ case builtin_encoding::kDenseArrayAttr:
+ return readDenseArrayAttr(reader);
+ case builtin_encoding::kDenseIntOrFPElementsAttr:
+ return readDenseIntOrFPElementsAttr(reader);
+ case builtin_encoding::kDenseStringElementsAttr:
+ return readDenseStringElementsAttr(reader);
+ case builtin_encoding::kSparseElementsAttr:
+ return readSparseElementsAttr(reader);
default:
reader.emitError() << "unknown builtin attribute code: " << code;
return Attribute();
@@ -403,8 +449,10 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
Attribute attr, DialectBytecodeWriter &writer) const {
return TypeSwitch<Attribute, LogicalResult>(attr)
- .Case<ArrayAttr, DenseResourceElementsAttr, DictionaryAttr, FloatAttr,
- IntegerAttr, StringAttr, SymbolRefAttr, TypeAttr>([&](auto attr) {
+ .Case<ArrayAttr, DenseArrayAttr, DenseIntOrFPElementsAttr,
+ DenseStringElementsAttr, DenseResourceElementsAttr, DictionaryAttr,
+ FloatAttr, IntegerAttr, SparseElementsAttr, StringAttr,
+ SymbolRefAttr, TypeAttr>([&](auto attr) {
write(attr, writer);
return success();
})
@@ -441,6 +489,78 @@ void BuiltinDialectBytecodeInterface::write(
writer.writeAttributes(attr.getValue());
}
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+
+DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
+ DialectBytecodeReader &reader) const {
+ RankedTensorType type;
+ ArrayRef<char> blob;
+ if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+ return DenseArrayAttr();
+ return DenseArrayAttr::get(type, blob);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
+ writer.writeType(attr.getType());
+ writer.writeOwnedBlob(attr.getRawData());
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+
+DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr(
+ DialectBytecodeReader &reader) const {
+ ShapedType type;
+ ArrayRef<char> blob;
+ if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+ return DenseIntOrFPElementsAttr();
+ return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr);
+ writer.writeType(attr.getType());
+ writer.writeOwnedBlob(attr.getRawData());
+}
+
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+
+DenseStringElementsAttr
+BuiltinDialectBytecodeInterface::readDenseStringElementsAttr(
+ DialectBytecodeReader &reader) const {
+ ShapedType type;
+ uint64_t isSplat;
+ if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat)))
+ return DenseStringElementsAttr();
+
+ SmallVector<StringRef> values(isSplat ? 1 : type.getNumElements());
+ for (StringRef &value : values)
+ if (failed(reader.readString(value)))
+ return DenseStringElementsAttr();
+ return DenseStringElementsAttr::get(type, values);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr);
+ writer.writeType(attr.getType());
+
+ bool isSplat = attr.isSplat();
+ writer.writeVarInt(isSplat);
+
+ // If the attribute is a splat, only write out the single value.
+ if (isSplat)
+ return writer.writeOwnedString(attr.getRawStringData().front());
+
+ for (StringRef str : attr.getRawStringData())
+ writer.writeOwnedString(str);
+}
+
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
@@ -550,6 +670,28 @@ void BuiltinDialectBytecodeInterface::write(
writer.writeAPIntWithKnownWidth(attr.getValue());
}
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+
+SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr(
+ DialectBytecodeReader &reader) const {
+ ShapedType type;
+ DenseIntElementsAttr indices;
+ DenseElementsAttr values;
+ if (failed(reader.readType(type)) || failed(reader.readAttribute(indices)) ||
+ failed(reader.readAttribute(values)))
+ return SparseElementsAttr();
+ return SparseElementsAttr::get(type, indices, values);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ SparseElementsAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kSparseElementsAttr);
+ writer.writeType(attr.getType());
+ writer.writeAttribute(attr.getIndices());
+ writer.writeAttribute(attr.getValues());
+}
+
//===----------------------------------------------------------------------===//
// StringAttr
diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
index f865ff36b5060..97fc3001b64c3 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -emit-bytecode %s | mlir-opt -mlir-print-local-scope | FileCheck %s
+// RUN: mlir-opt -emit-bytecode -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect -mlir-print-local-scope | FileCheck %s
// Bytecode currently does not support big-endian platforms
// UNSUPPORTED: s390x-
@@ -13,6 +13,44 @@ module @TestArray attributes {
bytecode.array = [unit]
} {}
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestDenseArray
+module @TestDenseArray attributes {
+ // CHECK: bytecode.test1 = array<i1: true, false, true, false, false>
+ // CHECK: bytecode.test2 = array<i8: 10, 32, -1>
+ // CHECK: bytecode.test3 = array<f64: 1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03
+ bytecode.test1 = array<i1: true, false, true, false, false>,
+ bytecode.test2 = array<i8: 10, 32, 255>,
+ bytecode.test3 = array<f64: 10.0, 32.0, 1809.0>
+} {}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestDenseIntOrFPElements
+// CHECK: bytecode.test1 = dense<true> : tensor<256xi1>
+// CHECK: bytecode.test2 = dense<[10, 32, -1]> : tensor<3xi8>
+// CHECK: bytecode.test3 = dense<[1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03]> : tensor<3xf64>
+module @TestDenseIntOrFPElements attributes {
+ bytecode.test1 = dense<true> : tensor<256xi1>,
+ bytecode.test2 = dense<[10, 32, 255]> : tensor<3xi8>,
+ bytecode.test3 = dense<[10.0, 32.0, 1809.0]> : tensor<3xf64>
+} {}
+
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestDenseStringElementsAttr
+module @TestDenseStringElementsAttr attributes {
+ bytecode.test1 = dense<"splat"> : tensor<256x!bytecode.string>,
+ bytecode.test2 = dense<["foo", "bar", "baz"]> : tensor<3x!bytecode.string>
+} {}
+
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
@@ -45,6 +83,17 @@ module @TestInt attributes {
bytecode.int3 = 90000000000000000300000000000000000001 : i128
} {}
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestSparseElements
+module @TestSparseElements attributes {
+ // CHECK-LITERAL: bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>
+ bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>
+} {}
+
+
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list