[Mlir-commits] [mlir] 8574405 - [mlir] Add bytecode encoding for the remaining builtin types

River Riddle llvmlistbot at llvm.org
Fri Aug 26 13:33:47 PDT 2022


Author: River Riddle
Date: 2022-08-26T13:31:06-07:00
New Revision: 857440590baf22aa4d3767b39d2d2c246d230922

URL: https://github.com/llvm/llvm-project/commit/857440590baf22aa4d3767b39d2d2c246d230922
DIFF: https://github.com/llvm/llvm-project/commit/857440590baf22aa4d3767b39d2d2c246d230922.diff

LOG: [mlir] Add bytecode encoding for the remaining builtin types

After this commit we will have an efficient bytecode representation for all
of the builtin types.

Differential Revision: https://reviews.llvm.org/D132604

Added: 
    

Modified: 
    mlir/include/mlir/Bytecode/BytecodeImplementation.h
    mlir/lib/IR/BuiltinDialectBytecode.cpp
    mlir/test/Dialect/Builtin/Bytecode/types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 921d795068fbd..845c790661411 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -38,10 +38,6 @@ class DialectBytecodeReader {
   /// Emit an error to the reader.
   virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
 
-  //===--------------------------------------------------------------------===//
-  // IR
-  //===--------------------------------------------------------------------===//
-
   /// Read out a list of elements, invoking the provided callback for each
   /// element. The callback function may be in any of the following forms:
   ///   * LogicalResult(T &)
@@ -71,6 +67,10 @@ class DialectBytecodeReader {
     return success();
   }
 
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
   /// Read a reference to the given attribute.
   virtual LogicalResult readAttribute(Attribute &result) = 0;
   template <typename T>
@@ -114,6 +114,10 @@ class DialectBytecodeReader {
 
   /// Read a signed variable width integer.
   virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
+  LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
+    return readList(
+        result, [this](int64_t &elt) { return readSignedVarInt(elt); });
+  }
 
   /// Read an APInt that is known to have been encoded with the given width.
   virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
@@ -178,6 +182,9 @@ class DialectBytecodeWriter {
   /// Write a signed variable width integer to the output stream. This should be
   /// the preferred method for emitting signed integers whenever possible.
   virtual void writeSignedVarInt(int64_t value) = 0;
+  void writeSignedVarInts(ArrayRef<int64_t> values) {
+    writeList(values, [this](int64_t value) { writeSignedVarInt(value); });
+  }
 
   /// Write an APInt to the bytecode stream whose bitwidth will be known
   /// externally at read time. This method is useful for encoding APInt values

diff  --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 43288a917fc24..ce6120c77ffbf 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -140,6 +140,118 @@ enum TypeCode {
   ///   }
   ///
   kFunctionType = 2,
+
+  ///   BFloat16Type {
+  ///   }
+  ///
+  kBFloat16Type = 3,
+
+  ///   Float16Type {
+  ///   }
+  ///
+  kFloat16Type = 4,
+
+  ///   Float32Type {
+  ///   }
+  ///
+  kFloat32Type = 5,
+
+  ///   Float64Type {
+  ///   }
+  ///
+  kFloat64Type = 6,
+
+  ///   Float80Type {
+  ///   }
+  ///
+  kFloat80Type = 7,
+
+  ///   Float128Type {
+  ///   }
+  ///
+  kFloat128Type = 8,
+
+  ///   ComplexType {
+  ///     elementType: Type
+  ///   }
+  ///
+  kComplexType = 9,
+
+  ///   MemRefType {
+  ///     shape: svarint[],
+  ///     elementType: Type,
+  ///     layout: Attribute
+  ///   }
+  ///
+  kMemRefType = 10,
+
+  ///   MemRefTypeWithMemSpace {
+  ///     memorySpace: Attribute,
+  ///     shape: svarint[],
+  ///     elementType: Type,
+  ///     layout: Attribute
+  ///   }
+  /// Variant of MemRefType with non-default memory space.
+  kMemRefTypeWithMemSpace = 11,
+
+  ///   NoneType {
+  ///   }
+  ///
+  kNoneType = 12,
+
+  ///   RankedTensorType {
+  ///     shape: svarint[],
+  ///     elementType: Type,
+  ///   }
+  ///
+  kRankedTensorType = 13,
+
+  ///   RankedTensorTypeWithEncoding {
+  ///     encoding: Attribute,
+  ///     shape: svarint[],
+  ///     elementType: Type
+  ///   }
+  /// Variant of RankedTensorType with an encoding.
+  kRankedTensorTypeWithEncoding = 14,
+
+  ///   TupleType {
+  ///     elementTypes: Type[]
+  ///   }
+  kTupleType = 15,
+
+  ///   UnrankedMemRefType {
+  ///     elementType: Type
+  ///   }
+  ///
+  kUnrankedMemRefType = 16,
+
+  ///   UnrankedMemRefTypeWithMemSpace {
+  ///     memorySpace: Attribute,
+  ///     elementType: Type
+  ///   }
+  /// Variant of UnrankedMemRefType with non-default memory space.
+  kUnrankedMemRefTypeWithMemSpace = 17,
+
+  ///   UnrankedTensorType {
+  ///     elementType: Type
+  ///   }
+  ///
+  kUnrankedTensorType = 18,
+
+  ///   VectorType {
+  ///     shape: svarint[],
+  ///     elementType: Type
+  ///   }
+  ///
+  kVectorType = 19,
+
+  ///   VectorTypeWithScalableDims {
+  ///     numScalableDims: varint,
+  ///     shape: svarint[],
+  ///     elementType: Type
+  ///   }
+  /// Variant of VectorType with scalable dimensions.
+  kVectorTypeWithScalableDims = 20,
 };
 
 } // namespace builtin_encoding
@@ -194,13 +306,32 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
   // Types
 
   Type readType(DialectBytecodeReader &reader) const override;
+  ComplexType readComplexType(DialectBytecodeReader &reader) const;
   IntegerType readIntegerType(DialectBytecodeReader &reader) const;
   FunctionType readFunctionType(DialectBytecodeReader &reader) const;
+  MemRefType readMemRefType(DialectBytecodeReader &reader,
+                            bool hasMemSpace) const;
+  RankedTensorType readRankedTensorType(DialectBytecodeReader &reader,
+                                        bool hasEncoding) const;
+  TupleType readTupleType(DialectBytecodeReader &reader) const;
+  UnrankedMemRefType readUnrankedMemRefType(DialectBytecodeReader &reader,
+                                            bool hasMemSpace) const;
+  UnrankedTensorType
+  readUnrankedTensorType(DialectBytecodeReader &reader) const;
+  VectorType readVectorType(DialectBytecodeReader &reader,
+                            bool hasScalableDims) const;
 
   LogicalResult writeType(Type type,
                           DialectBytecodeWriter &writer) const override;
+  void write(ComplexType type, DialectBytecodeWriter &writer) const;
   void write(IntegerType type, DialectBytecodeWriter &writer) const;
   void write(FunctionType type, DialectBytecodeWriter &writer) const;
+  void write(MemRefType type, DialectBytecodeWriter &writer) const;
+  void write(RankedTensorType type, DialectBytecodeWriter &writer) const;
+  void write(TupleType type, DialectBytecodeWriter &writer) const;
+  void write(UnrankedMemRefType type, DialectBytecodeWriter &writer) const;
+  void write(UnrankedTensorType type, DialectBytecodeWriter &writer) const;
+  void write(VectorType type, DialectBytecodeWriter &writer) const;
 };
 } // namespace
 
@@ -576,9 +707,45 @@ Type BuiltinDialectBytecodeInterface::readType(
     return readIntegerType(reader);
   case builtin_encoding::kIndexType:
     return IndexType::get(getContext());
-
   case builtin_encoding::kFunctionType:
     return readFunctionType(reader);
+  case builtin_encoding::kBFloat16Type:
+    return BFloat16Type::get(getContext());
+  case builtin_encoding::kFloat16Type:
+    return Float16Type::get(getContext());
+  case builtin_encoding::kFloat32Type:
+    return Float32Type::get(getContext());
+  case builtin_encoding::kFloat64Type:
+    return Float64Type::get(getContext());
+  case builtin_encoding::kFloat80Type:
+    return Float80Type::get(getContext());
+  case builtin_encoding::kFloat128Type:
+    return Float128Type::get(getContext());
+  case builtin_encoding::kComplexType:
+    return readComplexType(reader);
+  case builtin_encoding::kMemRefType:
+    return readMemRefType(reader, /*hasMemSpace=*/false);
+  case builtin_encoding::kMemRefTypeWithMemSpace:
+    return readMemRefType(reader, /*hasMemSpace=*/true);
+  case builtin_encoding::kNoneType:
+    return NoneType::get(getContext());
+  case builtin_encoding::kRankedTensorType:
+    return readRankedTensorType(reader, /*hasEncoding=*/false);
+  case builtin_encoding::kRankedTensorTypeWithEncoding:
+    return readRankedTensorType(reader, /*hasEncoding=*/true);
+  case builtin_encoding::kTupleType:
+    return readTupleType(reader);
+  case builtin_encoding::kUnrankedMemRefType:
+    return readUnrankedMemRefType(reader, /*hasMemSpace=*/false);
+  case builtin_encoding::kUnrankedMemRefTypeWithMemSpace:
+    return readUnrankedMemRefType(reader, /*hasMemSpace=*/true);
+  case builtin_encoding::kUnrankedTensorType:
+    return readUnrankedTensorType(reader);
+  case builtin_encoding::kVectorType:
+    return readVectorType(reader, /*hasScalableDims=*/false);
+  case builtin_encoding::kVectorTypeWithScalableDims:
+    return readVectorType(reader, /*hasScalableDims=*/true);
+
   default:
     reader.emitError() << "unknown builtin type code: " << code;
     return Type();
@@ -588,16 +755,56 @@ Type BuiltinDialectBytecodeInterface::readType(
 LogicalResult BuiltinDialectBytecodeInterface::writeType(
     Type type, DialectBytecodeWriter &writer) const {
   return TypeSwitch<Type, LogicalResult>(type)
-      .Case<IntegerType, FunctionType>([&](auto type) {
+      .Case<ComplexType, IntegerType, FunctionType, MemRefType,
+            RankedTensorType, TupleType, UnrankedMemRefType, UnrankedTensorType,
+            VectorType>([&](auto type) { // Code plus payload via write().
         write(type, writer);
         return success();
       })
       .Case([&](IndexType) {
         return writer.writeVarInt(builtin_encoding::kIndexType), success();
       })
+      .Case([&](BFloat16Type) { // Parameterless: the type code alone.
+        return writer.writeVarInt(builtin_encoding::kBFloat16Type), success();
+      })
+      .Case([&](Float16Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat16Type), success();
+      })
+      .Case([&](Float32Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat32Type), success();
+      })
+      .Case([&](Float64Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat64Type), success();
+      })
+      .Case([&](Float80Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat80Type), success();
+      })
+      .Case([&](Float128Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat128Type), success();
+      })
+      .Case([&](NoneType) {
+        return writer.writeVarInt(builtin_encoding::kNoneType), success();
+      })
       .Default([&](Type) { return failure(); });
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexType
+
+ComplexType BuiltinDialectBytecodeInterface::readComplexType(
+    DialectBytecodeReader &reader) const {
+  Type elementType;
+  if (succeeded(reader.readType(elementType)))
+    return ComplexType::get(elementType);
+  return ComplexType();
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    ComplexType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kComplexType);
+  writer.writeType(type.getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // IntegerType
 
@@ -634,3 +841,151 @@ void BuiltinDialectBytecodeInterface::write(
   writer.writeTypes(type.getInputs());
   writer.writeTypes(type.getResults());
 }
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+
+MemRefType BuiltinDialectBytecodeInterface::readMemRefType(
+    DialectBytecodeReader &reader, bool hasMemSpace) const {
+  Attribute memorySpace;
+  if (hasMemSpace && failed(reader.readAttribute(memorySpace)))
+    return MemRefType();
+  SmallVector<int64_t> shape;
+  Type elementType;
+  MemRefLayoutAttrInterface layout;
+  bool readOk = succeeded(reader.readSignedVarInts(shape)) &&
+                succeeded(reader.readType(elementType)) &&
+                succeeded(reader.readAttribute(layout));
+  if (!readOk)
+    return MemRefType();
+  return MemRefType::get(shape, elementType, layout, memorySpace);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    MemRefType type, DialectBytecodeWriter &writer) const {
+  Attribute memSpace = type.getMemorySpace();
+  writer.writeVarInt(memSpace ? builtin_encoding::kMemRefTypeWithMemSpace
+                              : builtin_encoding::kMemRefType);
+  if (memSpace)
+    writer.writeAttribute(memSpace);
+
+  writer.writeSignedVarInts(type.getShape());
+  writer.writeType(type.getElementType());
+  writer.writeAttribute(type.getLayout());
+}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+
+RankedTensorType BuiltinDialectBytecodeInterface::readRankedTensorType(
+    DialectBytecodeReader &reader, bool hasEncoding) const {
+  Attribute encoding;
+  SmallVector<int64_t> shape;
+  Type elementType;
+  if (hasEncoding && failed(reader.readAttribute(encoding)))
+    return RankedTensorType();
+  if (succeeded(reader.readSignedVarInts(shape)) &&
+      succeeded(reader.readType(elementType)))
+    return RankedTensorType::get(shape, elementType, encoding);
+  return RankedTensorType();
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    RankedTensorType type, DialectBytecodeWriter &writer) const {
+  Attribute encoding = type.getEncoding();
+  writer.writeVarInt(encoding
+                         ? builtin_encoding::kRankedTensorTypeWithEncoding
+                         : builtin_encoding::kRankedTensorType);
+  if (encoding)
+    writer.writeAttribute(encoding);
+  writer.writeSignedVarInts(type.getShape());
+  writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// TupleType
+
+TupleType BuiltinDialectBytecodeInterface::readTupleType(
+    DialectBytecodeReader &reader) const {
+  SmallVector<Type> elementTypes;
+  if (succeeded(reader.readTypes(elementTypes)))
+    return TupleType::get(getContext(), elementTypes);
+  return TupleType();
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    TupleType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kTupleType);
+  writer.writeTypes(type.getTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+
+UnrankedMemRefType BuiltinDialectBytecodeInterface::readUnrankedMemRefType(
+    DialectBytecodeReader &reader, bool hasMemSpace) const {
+  Attribute memorySpace;
+  Type elementType;
+  if (hasMemSpace && failed(reader.readAttribute(memorySpace)))
+    return UnrankedMemRefType();
+  if (succeeded(reader.readType(elementType)))
+    return UnrankedMemRefType::get(elementType, memorySpace);
+  return UnrankedMemRefType();
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    UnrankedMemRefType type, DialectBytecodeWriter &writer) const {
+  Attribute memSpace = type.getMemorySpace();
+  writer.writeVarInt(memSpace
+                         ? builtin_encoding::kUnrankedMemRefTypeWithMemSpace
+                         : builtin_encoding::kUnrankedMemRefType);
+  if (memSpace)
+    writer.writeAttribute(memSpace);
+  writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+
+UnrankedTensorType BuiltinDialectBytecodeInterface::readUnrankedTensorType(
+    DialectBytecodeReader &reader) const {
+  Type elementType;
+  if (succeeded(reader.readType(elementType)))
+    return UnrankedTensorType::get(elementType);
+  return UnrankedTensorType();
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    UnrankedTensorType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kUnrankedTensorType);
+  writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+
+VectorType BuiltinDialectBytecodeInterface::readVectorType(
+    DialectBytecodeReader &reader, bool hasScalableDims) const {
+  uint64_t numScalableDims = 0;
+  SmallVector<int64_t> shape;
+  Type elementType;
+  if ((hasScalableDims && failed(reader.readVarInt(numScalableDims))) ||
+      failed(reader.readSignedVarInts(shape)) ||
+      failed(reader.readType(elementType)))
+    return VectorType();
+  if (numScalableDims > shape.size())
+    return reader.emitError("invalid number of scalable dims"), VectorType();
+  return VectorType::get(shape, elementType, numScalableDims);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    VectorType type, DialectBytecodeWriter &writer) const {
+  if (unsigned numScalableDims = type.getNumScalableDims()) {
+    writer.writeVarInt(builtin_encoding::kVectorTypeWithScalableDims);
+    writer.writeVarInt(numScalableDims);
+  } else {
+    writer.writeVarInt(builtin_encoding::kVectorType);
+  }
+  writer.writeSignedVarInts(type.getShape());
+  writer.writeType(type.getElementType());
+}

diff  --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
index bb311aff4ae0f..9f16bcf495cec 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -3,6 +3,40 @@
 // Bytecode currently does not support big-endian platforms
 // UNSUPPORTED: s390x-
 
+//===----------------------------------------------------------------------===//
+// ComplexType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestComplex
+module @TestComplex attributes {
+  // CHECK: bytecode.test = complex<i32>
+  bytecode.test = complex<i32>
+} {}
+
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestFloat
+module @TestFloat attributes {
+  // CHECK: bytecode.test = bf16,
+  // CHECK: bytecode.test1 = f16,
+  // CHECK: bytecode.test2 = f32,
+  // CHECK: bytecode.test3 = f64,
+  // CHECK: bytecode.test4 = f80,
+  // CHECK: bytecode.test5 = f128
+  bytecode.test = bf16,
+  bytecode.test1 = f16,
+  bytecode.test2 = f32,
+  bytecode.test3 = f64,
+  bytecode.test4 = f80,
+  bytecode.test5 = f128
+} {}
+
+//===----------------------------------------------------------------------===//
+// IntegerType
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: @TestInteger
 module @TestInteger attributes {
   // CHECK: bytecode.int = i1024,
@@ -13,12 +47,20 @@ module @TestInteger attributes {
   bytecode.int2 = ui512
 } {}
 
+//===----------------------------------------------------------------------===//
+// IndexType
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: @TestIndex
 module @TestIndex attributes {
   // CHECK: bytecode.index = index
   bytecode.index = index
 } {}
 
+//===----------------------------------------------------------------------===//
+// FunctionType
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: @TestFunc
 module @TestFunc attributes {
   // CHECK: bytecode.func = () -> (),
@@ -26,3 +68,83 @@ module @TestFunc attributes {
   bytecode.func = () -> (),
   bytecode.func1 = (i1) -> (i32)
 } {}
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestMemRef
+module @TestMemRef attributes {
+  // CHECK: bytecode.test = memref<2xi8>,
+  // CHECK: bytecode.test1 = memref<2xi8, 1>
+  bytecode.test = memref<2xi8>,
+  bytecode.test1 = memref<2xi8, 1>
+} {}
+
+//===----------------------------------------------------------------------===//
+// NoneType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestNone
+module @TestNone attributes {
+  // CHECK: bytecode.test = none
+  bytecode.test = none
+} {}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestRankedTensor
+module @TestRankedTensor attributes {
+  // CHECK: bytecode.test = tensor<16x32x?xf64>,
+  // CHECK: bytecode.test1 = tensor<16xf64, "sparse">
+  bytecode.test = tensor<16x32x?xf64>,
+  bytecode.test1 = tensor<16xf64, "sparse">
+} {}
+
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestTuple
+module @TestTuple attributes {
+  // CHECK: bytecode.test = tuple<>,
+  // CHECK: bytecode.test1 = tuple<i32, i1, f32>
+  bytecode.test = tuple<>,
+  bytecode.test1 = tuple<i32, i1, f32>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestUnrankedMemRef
+module @TestUnrankedMemRef attributes {
+  // CHECK: bytecode.test = memref<*xi8>,
+  // CHECK: bytecode.test1 = memref<*xi8, 1>
+  bytecode.test = memref<*xi8>,
+  bytecode.test1 = memref<*xi8, 1>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestUnrankedTensor
+module @TestUnrankedTensor attributes {
+  // CHECK: bytecode.test = tensor<*xi8>
+  bytecode.test = tensor<*xi8>
+} {}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestVector
+module @TestVector attributes {
+  // CHECK: bytecode.test = vector<8x8x128xi8>,
+  // CHECK: bytecode.test1 = vector<8x[8]xf32>
+  bytecode.test = vector<8x8x128xi8>,
+  bytecode.test1 = vector<8x[8]xf32>
+} {}


        


More information about the Mlir-commits mailing list