[Mlir-commits] [mlir] 8574405 - [mlir] Add bytecode encoding for the remaining builtin types
River Riddle
llvmlistbot at llvm.org
Fri Aug 26 13:33:47 PDT 2022
Author: River Riddle
Date: 2022-08-26T13:31:06-07:00
New Revision: 857440590baf22aa4d3767b39d2d2c246d230922
URL: https://github.com/llvm/llvm-project/commit/857440590baf22aa4d3767b39d2d2c246d230922
DIFF: https://github.com/llvm/llvm-project/commit/857440590baf22aa4d3767b39d2d2c246d230922.diff
LOG: [mlir] Add bytecode encoding for the remaining builtin types
After this commit we will have an efficient bytecode representation for all
of the builtin types.
Differential Revision: https://reviews.llvm.org/D132604
Added:
Modified:
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/lib/IR/BuiltinDialectBytecode.cpp
mlir/test/Dialect/Builtin/Bytecode/types.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 921d795068fbd..845c790661411 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -38,10 +38,6 @@ class DialectBytecodeReader {
/// Emit an error to the reader.
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
- //===--------------------------------------------------------------------===//
- // IR
- //===--------------------------------------------------------------------===//
-
/// Read out a list of elements, invoking the provided callback for each
/// element. The callback function may be in any of the following forms:
/// * LogicalResult(T &)
@@ -71,6 +67,10 @@ class DialectBytecodeReader {
return success();
}
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
/// Read a reference to the given attribute.
virtual LogicalResult readAttribute(Attribute &result) = 0;
template <typename T>
@@ -114,6 +114,10 @@ class DialectBytecodeReader {
/// Read a signed variable width integer.
virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
+ LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) { // Read a list of signed variable width integers into `result`.
+ return readList(result,
+ [this](int64_t &value) { return readSignedVarInt(value); }); // callback runs once per list element
+ }
/// Read an APInt that is known to have been encoded with the given width.
virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
@@ -178,6 +182,9 @@ class DialectBytecodeWriter {
/// Write a signed variable width integer to the output stream. This should be
/// the preferred method for emitting signed integers whenever possible.
virtual void writeSignedVarInt(int64_t value) = 0;
+ void writeSignedVarInts(ArrayRef<int64_t> value) { // Write a list of signed variable width integers; writer-side analogue of DialectBytecodeReader::readSignedVarInts.
+ writeList(value, [this](int64_t value) { writeSignedVarInt(value); }); // NOTE(review): the lambda parameter shadows the outer `value`.
+ }
/// Write an APInt to the bytecode stream whose bitwidth will be known
/// externally at read time. This method is useful for encoding APInt values
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 43288a917fc24..ce6120c77ffbf 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -140,6 +140,118 @@ enum TypeCode {
/// }
///
kFunctionType = 2,
+
+ /// BFloat16Type {
+ /// }
+ ///
+ kBFloat16Type = 3,
+
+ /// Float16Type {
+ /// }
+ ///
+ kFloat16Type = 4,
+
+ /// Float32Type {
+ /// }
+ ///
+ kFloat32Type = 5,
+
+ /// Float64Type {
+ /// }
+ ///
+ kFloat64Type = 6,
+
+ /// Float80Type {
+ /// }
+ ///
+ kFloat80Type = 7,
+
+ /// Float128Type {
+ /// }
+ ///
+ kFloat128Type = 8,
+
+ /// ComplexType {
+ /// elementType: Type
+ /// }
+ ///
+ kComplexType = 9,
+
+ /// MemRefType {
+ /// shape: svarint[],
+ /// elementType: Type,
+ /// layout: Attribute
+ /// }
+ ///
+ kMemRefType = 10,
+
+ /// MemRefTypeWithMemSpace {
+ /// memorySpace: Attribute,
+ /// shape: svarint[],
+ /// elementType: Type,
+ /// layout: Attribute
+ /// }
+ /// Variant of MemRefType with non-default memory space.
+ kMemRefTypeWithMemSpace = 11,
+
+ /// NoneType {
+ /// }
+ ///
+ kNoneType = 12,
+
+ /// RankedTensorType {
+ /// shape: svarint[],
+ /// elementType: Type
+ /// }
+ ///
+ kRankedTensorType = 13,
+
+ /// RankedTensorTypeWithEncoding {
+ /// encoding: Attribute,
+ /// shape: svarint[],
+ /// elementType: Type
+ /// }
+ /// Variant of RankedTensorType with an encoding.
+ kRankedTensorTypeWithEncoding = 14,
+
+ /// TupleType {
+ /// elementTypes: Type[]
+ /// }
+ kTupleType = 15,
+
+ /// UnrankedMemRefType {
+ /// elementType: Type
+ /// }
+ ///
+ kUnrankedMemRefType = 16,
+
+ /// UnrankedMemRefTypeWithMemSpace {
+ /// memorySpace: Attribute,
+ /// elementType: Type
+ /// }
+ /// Variant of UnrankedMemRefType with non-default memory space.
+ kUnrankedMemRefTypeWithMemSpace = 17,
+
+ /// UnrankedTensorType {
+ /// elementType: Type
+ /// }
+ ///
+ kUnrankedTensorType = 18,
+
+ /// VectorType {
+ /// shape: svarint[],
+ /// elementType: Type
+ /// }
+ ///
+ kVectorType = 19,
+
+ /// VectorTypeWithScalableDims {
+ /// numScalableDims: varint,
+ /// shape: svarint[],
+ /// elementType: Type
+ /// }
+ /// Variant of VectorType with scalable dimensions.
+ kVectorTypeWithScalableDims = 20,
};
} // namespace builtin_encoding
@@ -194,13 +306,32 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
// Types
Type readType(DialectBytecodeReader &reader) const override;
+ ComplexType readComplexType(DialectBytecodeReader &reader) const;
IntegerType readIntegerType(DialectBytecodeReader &reader) const;
FunctionType readFunctionType(DialectBytecodeReader &reader) const;
+ MemRefType readMemRefType(DialectBytecodeReader &reader,
+ bool hasMemSpace) const;
+ RankedTensorType readRankedTensorType(DialectBytecodeReader &reader,
+ bool hasEncoding) const;
+ TupleType readTupleType(DialectBytecodeReader &reader) const;
+ UnrankedMemRefType readUnrankedMemRefType(DialectBytecodeReader &reader,
+ bool hasMemSpace) const;
+ UnrankedTensorType
+ readUnrankedTensorType(DialectBytecodeReader &reader) const;
+ VectorType readVectorType(DialectBytecodeReader &reader,
+ bool hasScalableDims) const;
LogicalResult writeType(Type type,
DialectBytecodeWriter &writer) const override;
+ void write(ComplexType type, DialectBytecodeWriter &writer) const;
void write(IntegerType type, DialectBytecodeWriter &writer) const;
void write(FunctionType type, DialectBytecodeWriter &writer) const;
+ void write(MemRefType type, DialectBytecodeWriter &writer) const;
+ void write(RankedTensorType type, DialectBytecodeWriter &writer) const;
+ void write(TupleType type, DialectBytecodeWriter &writer) const;
+ void write(UnrankedMemRefType type, DialectBytecodeWriter &writer) const;
+ void write(UnrankedTensorType type, DialectBytecodeWriter &writer) const;
+ void write(VectorType type, DialectBytecodeWriter &writer) const;
};
} // namespace
@@ -576,9 +707,45 @@ Type BuiltinDialectBytecodeInterface::readType(
return readIntegerType(reader);
case builtin_encoding::kIndexType:
return IndexType::get(getContext());
-
case builtin_encoding::kFunctionType:
return readFunctionType(reader);
+ case builtin_encoding::kBFloat16Type:
+ return BFloat16Type::get(getContext());
+ case builtin_encoding::kFloat16Type:
+ return Float16Type::get(getContext());
+ case builtin_encoding::kFloat32Type:
+ return Float32Type::get(getContext());
+ case builtin_encoding::kFloat64Type:
+ return Float64Type::get(getContext());
+ case builtin_encoding::kFloat80Type:
+ return Float80Type::get(getContext());
+ case builtin_encoding::kFloat128Type:
+ return Float128Type::get(getContext());
+ case builtin_encoding::kComplexType:
+ return readComplexType(reader);
+ case builtin_encoding::kMemRefType:
+ return readMemRefType(reader, /*hasMemSpace=*/false);
+ case builtin_encoding::kMemRefTypeWithMemSpace:
+ return readMemRefType(reader, /*hasMemSpace=*/true);
+ case builtin_encoding::kNoneType:
+ return NoneType::get(getContext());
+ case builtin_encoding::kRankedTensorType:
+ return readRankedTensorType(reader, /*hasEncoding=*/false);
+ case builtin_encoding::kRankedTensorTypeWithEncoding:
+ return readRankedTensorType(reader, /*hasEncoding=*/true);
+ case builtin_encoding::kTupleType:
+ return readTupleType(reader);
+ case builtin_encoding::kUnrankedMemRefType:
+ return readUnrankedMemRefType(reader, /*hasMemSpace=*/false);
+ case builtin_encoding::kUnrankedMemRefTypeWithMemSpace:
+ return readUnrankedMemRefType(reader, /*hasMemSpace=*/true);
+ case builtin_encoding::kUnrankedTensorType:
+ return readUnrankedTensorType(reader);
+ case builtin_encoding::kVectorType:
+ return readVectorType(reader, /*hasScalableDims=*/false);
+ case builtin_encoding::kVectorTypeWithScalableDims:
+ return readVectorType(reader, /*hasScalableDims=*/true);
+
default:
reader.emitError() << "unknown builtin type code: " << code;
return Type();
@@ -588,16 +755,56 @@ Type BuiltinDialectBytecodeInterface::readType(
LogicalResult BuiltinDialectBytecodeInterface::writeType(
Type type, DialectBytecodeWriter &writer) const {
return TypeSwitch<Type, LogicalResult>(type)
- .Case<IntegerType, FunctionType>([&](auto type) {
+ .Case<ComplexType, IntegerType, FunctionType, MemRefType, // types with a payload dispatch to the matching write() overload
+ RankedTensorType, TupleType, UnrankedMemRefType, UnrankedTensorType,
+ VectorType>([&](auto type) {
write(type, writer);
return success();
})
.Case([&](IndexType) {
return writer.writeVarInt(builtin_encoding::kIndexType), success();
})
+ .Case([&](BFloat16Type) { // payload-free types: only the type code is written
+ return writer.writeVarInt(builtin_encoding::kBFloat16Type), success();
+ })
+ .Case([&](Float16Type) {
+ return writer.writeVarInt(builtin_encoding::kFloat16Type), success();
+ })
+ .Case([&](Float32Type) {
+ return writer.writeVarInt(builtin_encoding::kFloat32Type), success();
+ })
+ .Case([&](Float64Type) {
+ return writer.writeVarInt(builtin_encoding::kFloat64Type), success();
+ })
+ .Case([&](Float80Type) {
+ return writer.writeVarInt(builtin_encoding::kFloat80Type), success();
+ })
+ .Case([&](Float128Type) {
+ return writer.writeVarInt(builtin_encoding::kFloat128Type), success();
+ })
+ .Case([&](NoneType) { // any type not matched above falls through to failure() below
+ return writer.writeVarInt(builtin_encoding::kNoneType), success();
+ })
.Default([&](Type) { return failure(); });
}
+//===----------------------------------------------------------------------===//
+// ComplexType
+// Wire format: type code kComplexType, then { elementType: Type }.
+ComplexType BuiltinDialectBytecodeInterface::readComplexType(
+ DialectBytecodeReader &reader) const {
+ Type elementType;
+ if (failed(reader.readType(elementType)))
+ return ComplexType(); // empty result: the read failure propagates through readType()
+ return ComplexType::get(elementType);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ ComplexType type, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kComplexType); // type code first: readType() dispatches on it
+ writer.writeType(type.getElementType());
+}
+
//===----------------------------------------------------------------------===//
// IntegerType
@@ -634,3 +841,151 @@ void BuiltinDialectBytecodeInterface::write(
writer.writeTypes(type.getInputs());
writer.writeTypes(type.getResults());
}
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+// Wire format: [memorySpace: Attribute,] shape: svarint[], elementType: Type, layout: Attribute.
+MemRefType
+BuiltinDialectBytecodeInterface::readMemRefType(DialectBytecodeReader &reader,
+ bool hasMemSpace) const {
+ Attribute memorySpace;
+ if (hasMemSpace && failed(reader.readAttribute(memorySpace))) // present only for kMemRefTypeWithMemSpace
+ return MemRefType();
+ SmallVector<int64_t> shape;
+ Type elementType;
+ MemRefLayoutAttrInterface layout;
+ if (failed(reader.readSignedVarInts(shape)) || // read order must mirror write(MemRefType)
+ failed(reader.readType(elementType)) ||
+ failed(reader.readAttribute(layout)))
+ return MemRefType();
+ return MemRefType::get(shape, elementType, layout, memorySpace); // NOTE(review): decoded values reach get() unvalidated; consider getChecked with reader.emitError
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ MemRefType type, DialectBytecodeWriter &writer) const {
+ if (Attribute memSpace = type.getMemorySpace()) { // a null memory space selects the plain kMemRefType code
+ writer.writeVarInt(builtin_encoding::kMemRefTypeWithMemSpace);
+ writer.writeAttribute(memSpace);
+ } else {
+ writer.writeVarInt(builtin_encoding::kMemRefType);
+ }
+ writer.writeSignedVarInts(type.getShape());
+ writer.writeType(type.getElementType());
+ writer.writeAttribute(type.getLayout());
+}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+// Wire format: [encoding: Attribute,] shape: svarint[], elementType: Type.
+RankedTensorType BuiltinDialectBytecodeInterface::readRankedTensorType(
+ DialectBytecodeReader &reader, bool hasEncoding) const {
+ Attribute encoding;
+ if (hasEncoding && failed(reader.readAttribute(encoding))) // present only for kRankedTensorTypeWithEncoding
+ return RankedTensorType();
+ SmallVector<int64_t> shape;
+ Type elementType;
+ if (failed(reader.readSignedVarInts(shape)) || // read order must mirror write(RankedTensorType)
+ failed(reader.readType(elementType)))
+ return RankedTensorType();
+ return RankedTensorType::get(shape, elementType, encoding);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ RankedTensorType type, DialectBytecodeWriter &writer) const {
+ if (Attribute encoding = type.getEncoding()) { // a null encoding selects the plain kRankedTensorType code
+ writer.writeVarInt(builtin_encoding::kRankedTensorTypeWithEncoding);
+ writer.writeAttribute(encoding);
+ } else {
+ writer.writeVarInt(builtin_encoding::kRankedTensorType);
+ }
+ writer.writeSignedVarInts(type.getShape());
+ writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// TupleType
+// Wire format: type code kTupleType, then { elementTypes: Type[] } (the list may be empty, e.g. tuple<>).
+TupleType BuiltinDialectBytecodeInterface::readTupleType(
+ DialectBytecodeReader &reader) const {
+ SmallVector<Type> elements;
+ if (failed(reader.readTypes(elements)))
+ return TupleType();
+ return TupleType::get(getContext(), elements);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ TupleType type, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kTupleType);
+ writer.writeTypes(type.getTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+// Wire format: [memorySpace: Attribute,] elementType: Type; no shape is encoded.
+UnrankedMemRefType BuiltinDialectBytecodeInterface::readUnrankedMemRefType(
+ DialectBytecodeReader &reader, bool hasMemSpace) const {
+ Attribute memorySpace;
+ if (hasMemSpace && failed(reader.readAttribute(memorySpace))) // present only for kUnrankedMemRefTypeWithMemSpace
+ return UnrankedMemRefType();
+ Type elementType;
+ if (failed(reader.readType(elementType)))
+ return UnrankedMemRefType();
+ return UnrankedMemRefType::get(elementType, memorySpace);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ UnrankedMemRefType type, DialectBytecodeWriter &writer) const {
+ if (Attribute memSpace = type.getMemorySpace()) { // a null memory space selects the plain kUnrankedMemRefType code
+ writer.writeVarInt(builtin_encoding::kUnrankedMemRefTypeWithMemSpace);
+ writer.writeAttribute(memSpace);
+ } else {
+ writer.writeVarInt(builtin_encoding::kUnrankedMemRefType);
+ }
+ writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+// Wire format: type code kUnrankedTensorType, then { elementType: Type }.
+UnrankedTensorType BuiltinDialectBytecodeInterface::readUnrankedTensorType(
+ DialectBytecodeReader &reader) const {
+ Type elementType;
+ if (failed(reader.readType(elementType)))
+ return UnrankedTensorType();
+ return UnrankedTensorType::get(elementType);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ UnrankedTensorType type, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kUnrankedTensorType);
+ writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+// Wire format: [numScalableDims: varint,] shape: svarint[], elementType: Type.
+VectorType
+BuiltinDialectBytecodeInterface::readVectorType(DialectBytecodeReader &reader,
+ bool hasScalableDims) const {
+ uint64_t numScalableDims = 0; // stays 0 for the plain kVectorType variant
+ if (hasScalableDims && failed(reader.readVarInt(numScalableDims)))
+ return VectorType();
+ SmallVector<int64_t> shape;
+ Type elementType;
+ if (failed(reader.readSignedVarInts(shape)) || // read order must mirror write(VectorType)
+ failed(reader.readType(elementType)))
+ return VectorType();
+ return VectorType::get(shape, elementType, numScalableDims); // NOTE(review): uint64_t here vs. unsigned on the write side, and not range-checked against shape.size() - confirm
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ VectorType type, DialectBytecodeWriter &writer) const {
+ if (unsigned numScalableDims = type.getNumScalableDims()) { // zero scalable dims selects the plain kVectorType code
+ writer.writeVarInt(builtin_encoding::kVectorTypeWithScalableDims);
+ writer.writeVarInt(numScalableDims);
+ } else {
+ writer.writeVarInt(builtin_encoding::kVectorType);
+ }
+ writer.writeSignedVarInts(type.getShape());
+ writer.writeType(type.getElementType());
+}
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
index bb311aff4ae0f..9f16bcf495cec 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -3,6 +3,40 @@
// Bytecode currently does not support big-endian platforms
// UNSUPPORTED: s390x-
+//===----------------------------------------------------------------------===//
+// ComplexType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestComplex
+module @TestComplex attributes {
+ // CHECK: bytecode.test = complex<i32>
+ bytecode.test = complex<i32>
+} {}
+
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestFloat
+module @TestFloat attributes {
+ // CHECK: bytecode.test = bf16,
+ // CHECK: bytecode.test1 = f16,
+ // CHECK: bytecode.test2 = f32,
+ // CHECK: bytecode.test3 = f64,
+ // CHECK: bytecode.test4 = f80,
+ // CHECK: bytecode.test5 = f128
+ bytecode.test = bf16,
+ bytecode.test1 = f16,
+ bytecode.test2 = f32,
+ bytecode.test3 = f64,
+ bytecode.test4 = f80,
+ bytecode.test5 = f128
+} {}
+
+//===----------------------------------------------------------------------===//
+// IntegerType
+//===----------------------------------------------------------------------===//
+
// CHECK-LABEL: @TestInteger
module @TestInteger attributes {
// CHECK: bytecode.int = i1024,
@@ -13,12 +47,20 @@ module @TestInteger attributes {
bytecode.int2 = ui512
} {}
+//===----------------------------------------------------------------------===//
+// IndexType
+//===----------------------------------------------------------------------===//
+
// CHECK-LABEL: @TestIndex
module @TestIndex attributes {
// CHECK: bytecode.index = index
bytecode.index = index
} {}
+//===----------------------------------------------------------------------===//
+// FunctionType
+//===----------------------------------------------------------------------===//
+
// CHECK-LABEL: @TestFunc
module @TestFunc attributes {
// CHECK: bytecode.func = () -> (),
@@ -26,3 +68,83 @@ module @TestFunc attributes {
bytecode.func = () -> (),
bytecode.func1 = (i1) -> (i32)
} {}
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestMemRef
+module @TestMemRef attributes {
+ // CHECK: bytecode.test = memref<2xi8>,
+ // CHECK: bytecode.test1 = memref<2xi8, 1>
+ bytecode.test = memref<2xi8>,
+ bytecode.test1 = memref<2xi8, 1>
+} {}
+
+//===----------------------------------------------------------------------===//
+// NoneType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestNone
+module @TestNone attributes {
+ // CHECK: bytecode.test = none
+ bytecode.test = none
+} {}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestRankedTensor
+module @TestRankedTensor attributes {
+ // CHECK: bytecode.test = tensor<16x32x?xf64>,
+ // CHECK: bytecode.test1 = tensor<16xf64, "sparse">
+ bytecode.test = tensor<16x32x?xf64>,
+ bytecode.test1 = tensor<16xf64, "sparse">
+} {}
+
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestTuple
+module @TestTuple attributes {
+ // CHECK: bytecode.test = tuple<>,
+ // CHECK: bytecode.test1 = tuple<i32, i1, f32>
+ bytecode.test = tuple<>,
+ bytecode.test1 = tuple<i32, i1, f32>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestUnrankedMemRef
+module @TestUnrankedMemRef attributes {
+ // CHECK: bytecode.test = memref<*xi8>,
+ // CHECK: bytecode.test1 = memref<*xi8, 1>
+ bytecode.test = memref<*xi8>,
+ bytecode.test1 = memref<*xi8, 1>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestUnrankedTensor
+module @TestUnrankedTensor attributes {
+ // CHECK: bytecode.test = tensor<*xi8>
+ bytecode.test = tensor<*xi8>
+} {}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestVector
+module @TestVector attributes {
+ // CHECK: bytecode.test = vector<8x8x128xi8>,
+ // CHECK: bytecode.test1 = vector<8x[8]xf32>
+ bytecode.test = vector<8x8x128xi8>,
+ bytecode.test1 = vector<8x[8]xf32>
+} {}
More information about the Mlir-commits
mailing list