[Mlir-commits] [mlir] 5fb1bbe - [mlir] Add bytecode encodings for the builtin ElementsAttr attributes

River Riddle llvmlistbot at llvm.org
Tue Sep 13 11:39:47 PDT 2022


Author: River Riddle
Date: 2022-09-13T11:39:20-07:00
New Revision: 5fb1bbe6d49e08af5d47ce4d9498d650e6628dad

URL: https://github.com/llvm/llvm-project/commit/5fb1bbe6d49e08af5d47ce4d9498d650e6628dad
DIFF: https://github.com/llvm/llvm-project/commit/5fb1bbe6d49e08af5d47ce4d9498d650e6628dad.diff

LOG: [mlir] Add bytecode encodings for the builtin ElementsAttr attributes

This adds bytecode support for DenseArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, and SparseElementsAttr.

Differential Revision: https://reviews.llvm.org/D133744

Added: 
    

Modified: 
    mlir/include/mlir/Bytecode/BytecodeImplementation.h
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.cpp
    mlir/lib/IR/BuiltinDialectBytecode.cpp
    mlir/test/Dialect/Builtin/Bytecode/attrs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 7607f2a4036be..1ae839d5f40e7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -143,6 +143,9 @@ class DialectBytecodeReader {
   /// Read a string from the bytecode.
   virtual LogicalResult readString(StringRef &result) = 0;
 
+  /// Read a blob from the bytecode.
+  virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
+
 private:
   /// Read a handle to a dialect resource.
   virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
@@ -225,6 +228,11 @@ class DialectBytecodeWriter {
   /// only be called if such a guarantee can be made, such as when the string is
   /// owned by an attribute or type.
   virtual void writeOwnedString(StringRef str) = 0;
+
+  /// Write a blob to the bytecode, which is owned by the caller and is
+  /// guaranteed to not die before the end of the bytecode process. The blob is
+  /// written as-is, with no additional compression or compaction.
+  virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 5e10dfad355a3..2f36de523ff32 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -887,6 +887,17 @@ class DialectReader : public DialectBytecodeReader {
     return stringReader.parseString(reader, result);
   }
 
+  LogicalResult readBlob(ArrayRef<char> &result) override {
+    uint64_t dataSize;
+    ArrayRef<uint8_t> data;
+    if (failed(reader.parseVarInt(dataSize)) ||
+        failed(reader.parseBytes(dataSize, data)))
+      return failure();
+    result = llvm::makeArrayRef(reinterpret_cast<const char *>(data.data()),
+                                data.size());
+    return success();
+  }
+
 private:
   AttrTypeReader &attrTypeReader;
   StringSectionReader &stringReader;

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 7bcc1a841c245..1fd313453fb24 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -543,6 +543,12 @@ class DialectWriter : public DialectBytecodeWriter {
     emitter.emitVarInt(stringSection.insert(str));
   }
 
+  void writeOwnedBlob(ArrayRef<char> blob) override {
+    emitter.emitVarInt(blob.size());
+    emitter.emitOwnedBlob(ArrayRef<uint8_t>(
+        reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
+  }
+
 private:
   EncodingEmitter &emitter;
   IRNumberingState &numberingState;

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index d5a2ef5f2bf14..9fed3236af588 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -39,6 +39,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
     // references. This could potentially be useful for optimizing things like
     // file locations.
   }
+  void writeOwnedBlob(ArrayRef<char> blob) override {}
 
   /// The parent numbering state that is populated by this writer.
   IRNumberingState &state;

diff  --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 260282821ac51..18f3ea8f2382e 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -123,6 +123,32 @@ enum AttributeCode {
   ///     handle: ResourceHandle
   ///   }
   kDenseResourceElementsAttr = 16,
+
+  ///   DenseArrayAttr {
+  ///     type: RankedTensorType,
+  ///     data: blob
+  ///   }
+  kDenseArrayAttr = 17,
+
+  ///   DenseIntOrFPElementsAttr {
+  ///     type: ShapedType,
+  ///     data: blob
+  ///   }
+  kDenseIntOrFPElementsAttr = 18,
+
+  ///   DenseStringElementsAttr {
+  ///     type: ShapedType,
+  ///     isSplat: varint,
+  ///     data: string[]
+  ///   }
+  kDenseStringElementsAttr = 19,
+
+  ///   SparseElementsAttr {
+  ///     type: ShapedType,
+  ///     indices: DenseIntElementsAttr,
+  ///     values: DenseElementsAttr
+  ///   }
+  kSparseElementsAttr = 20,
 };
 
 /// This enum contains marker codes used to indicate which type is currently
@@ -279,11 +305,18 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
 
   Attribute readAttribute(DialectBytecodeReader &reader) const override;
   ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+  DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const;
+  DenseElementsAttr
+  readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const;
+  DenseStringElementsAttr
+  readDenseStringElementsAttr(DialectBytecodeReader &reader) const;
   DenseResourceElementsAttr
   readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
   DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
   FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
   IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
+  SparseElementsAttr
+  readSparseElementsAttr(DialectBytecodeReader &reader) const;
   StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const;
   SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader,
                                   bool hasNestedRefs) const;
@@ -298,11 +331,16 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
   LogicalResult writeAttribute(Attribute attr,
                                DialectBytecodeWriter &writer) const override;
   void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+  void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const;
+  void write(DenseIntOrFPElementsAttr attr,
+             DialectBytecodeWriter &writer) const;
+  void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const;
   void write(DenseResourceElementsAttr attr,
              DialectBytecodeWriter &writer) const;
   void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
   void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
   void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
+  void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const;
   void write(StringAttr attr, DialectBytecodeWriter &writer) const;
   void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const;
   void write(TypeAttr attr, DialectBytecodeWriter &writer) const;
@@ -394,6 +432,14 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
     return UnknownLoc::get(getContext());
   case builtin_encoding::kDenseResourceElementsAttr:
     return readDenseResourceElementsAttr(reader);
+  case builtin_encoding::kDenseArrayAttr:
+    return readDenseArrayAttr(reader);
+  case builtin_encoding::kDenseIntOrFPElementsAttr:
+    return readDenseIntOrFPElementsAttr(reader);
+  case builtin_encoding::kDenseStringElementsAttr:
+    return readDenseStringElementsAttr(reader);
+  case builtin_encoding::kSparseElementsAttr:
+    return readSparseElementsAttr(reader);
   default:
     reader.emitError() << "unknown builtin attribute code: " << code;
     return Attribute();
@@ -403,8 +449,10 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
 LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
     Attribute attr, DialectBytecodeWriter &writer) const {
   return TypeSwitch<Attribute, LogicalResult>(attr)
-      .Case<ArrayAttr, DenseResourceElementsAttr, DictionaryAttr, FloatAttr,
-            IntegerAttr, StringAttr, SymbolRefAttr, TypeAttr>([&](auto attr) {
+      .Case<ArrayAttr, DenseArrayAttr, DenseIntOrFPElementsAttr,
+            DenseStringElementsAttr, DenseResourceElementsAttr, DictionaryAttr,
+            FloatAttr, IntegerAttr, SparseElementsAttr, StringAttr,
+            SymbolRefAttr, TypeAttr>([&](auto attr) {
         write(attr, writer);
         return success();
       })
@@ -441,6 +489,78 @@ void BuiltinDialectBytecodeInterface::write(
   writer.writeAttributes(attr.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+
+DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
+    DialectBytecodeReader &reader) const {
+  RankedTensorType type;
+  ArrayRef<char> blob;
+  if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+    return DenseArrayAttr();
+  return DenseArrayAttr::get(type, blob);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
+  writer.writeType(attr.getType());
+  writer.writeOwnedBlob(attr.getRawData());
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+
+DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr(
+    DialectBytecodeReader &reader) const {
+  ShapedType type;
+  ArrayRef<char> blob;
+  if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+    return DenseIntOrFPElementsAttr();
+  return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr);
+  writer.writeType(attr.getType());
+  writer.writeOwnedBlob(attr.getRawData());
+}
+
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+
+DenseStringElementsAttr
+BuiltinDialectBytecodeInterface::readDenseStringElementsAttr(
+    DialectBytecodeReader &reader) const {
+  ShapedType type;
+  uint64_t isSplat;
+  if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat)))
+    return DenseStringElementsAttr();
+
+  SmallVector<StringRef> values(isSplat ? 1 : type.getNumElements());
+  for (StringRef &value : values)
+    if (failed(reader.readString(value)))
+      return DenseStringElementsAttr();
+  return DenseStringElementsAttr::get(type, values);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr);
+  writer.writeType(attr.getType());
+
+  bool isSplat = attr.isSplat();
+  writer.writeVarInt(isSplat);
+
+  // If the attribute is a splat, only write out the single value.
+  if (isSplat)
+    return writer.writeOwnedString(attr.getRawStringData().front());
+
+  for (StringRef str : attr.getRawStringData())
+    writer.writeOwnedString(str);
+}
+
 //===----------------------------------------------------------------------===//
 // DenseResourceElementsAttr
 
@@ -550,6 +670,28 @@ void BuiltinDialectBytecodeInterface::write(
   writer.writeAPIntWithKnownWidth(attr.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+
+SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr(
+    DialectBytecodeReader &reader) const {
+  ShapedType type;
+  DenseIntElementsAttr indices;
+  DenseElementsAttr values;
+  if (failed(reader.readType(type)) || failed(reader.readAttribute(indices)) ||
+      failed(reader.readAttribute(values)))
+    return SparseElementsAttr();
+  return SparseElementsAttr::get(type, indices, values);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    SparseElementsAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kSparseElementsAttr);
+  writer.writeType(attr.getType());
+  writer.writeAttribute(attr.getIndices());
+  writer.writeAttribute(attr.getValues());
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 

diff  --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
index f865ff36b5060..97fc3001b64c3 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -emit-bytecode %s | mlir-opt -mlir-print-local-scope | FileCheck %s
+// RUN: mlir-opt -emit-bytecode -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect -mlir-print-local-scope | FileCheck %s
 
 // Bytecode currently does not support big-endian platforms
 // UNSUPPORTED: s390x-
@@ -13,6 +13,44 @@ module @TestArray attributes {
   bytecode.array = [unit]
 } {}
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestDenseArray
+module @TestDenseArray attributes {
+  // CHECK: bytecode.test1 = array<i1: true, false, true, false, false>
+  // CHECK: bytecode.test2 = array<i8: 10, 32, -1>
+  // CHECK: bytecode.test3 = array<f64: 1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03
+  bytecode.test1 = array<i1: true, false, true, false, false>,
+  bytecode.test2 = array<i8: 10, 32, 255>,
+  bytecode.test3 = array<f64: 10.0, 32.0, 1809.0>
+} {}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestDenseIntOrFPElements
+// CHECK: bytecode.test1 = dense<true> : tensor<256xi1>
+// CHECK: bytecode.test2 = dense<[10, 32, -1]> : tensor<3xi8>
+// CHECK: bytecode.test3 = dense<[1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03]> : tensor<3xf64>
+module @TestDenseIntOrFPElements attributes {
+  bytecode.test1 = dense<true> : tensor<256xi1>,
+  bytecode.test2 = dense<[10, 32, 255]> : tensor<3xi8>,
+  bytecode.test3 = dense<[10.0, 32.0, 1809.0]> : tensor<3xf64>
+} {}
+
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestDenseStringElementsAttr
+module @TestDenseStringElementsAttr attributes {
+  bytecode.test1 = dense<"splat"> : tensor<256x!bytecode.string>,
+  bytecode.test2 = dense<["foo", "bar", "baz"]> : tensor<3x!bytecode.string>
+} {}
+
 //===----------------------------------------------------------------------===//
 // FloatAttr
 //===----------------------------------------------------------------------===//
@@ -45,6 +83,17 @@ module @TestInt attributes {
   bytecode.int3 = 90000000000000000300000000000000000001 : i128
 } {}
 
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestSparseElements
+module @TestSparseElements attributes {
+  // CHECK-LITERAL: bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>
+  bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>
+} {}
+
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list