[Mlir-commits] [mlir] 23c9e8b - [mlir][tensors] Introduce attribute interface/attribute for tensor encoding
Aart Bik
llvmlistbot at llvm.org
Mon Apr 26 18:32:09 PDT 2021
Author: Aart Bik
Date: 2021-04-26T18:31:54-07:00
New Revision: 23c9e8bc25795b69e16d39b674c19c79a2bb107b
URL: https://github.com/llvm/llvm-project/commit/23c9e8bc25795b69e16d39b674c19c79a2bb107b
DIFF: https://github.com/llvm/llvm-project/commit/23c9e8bc25795b69e16d39b674c19c79a2bb107b.diff
LOG: [mlir][tensors] Introduce attribute interface/attribute for tensor encoding
The new "encoding" field in tensor types so far had no meaning. This revision introduces:
1. an encoding attribute interface in IR, used to verify an encoding against the tensor type it annotates
2. an attribute in the Tensor dialect, #tensor.sparse<dict>, together with an API for concrete sparse tensors
Active discussion:
https://llvm.discourse.group/t/rfc-introduce-a-sparse-tensor-type-to-core-mlir/2944/
Reviewed By: silvas, penpornk, bixia
Differential Revision: https://reviews.llvm.org/D101008
Added:
mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td
mlir/include/mlir/IR/TensorEncoding.h
mlir/include/mlir/IR/TensorEncoding.td
mlir/lib/IR/TensorEncoding.cpp
mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir
mlir/test/Dialect/Tensor/valid_sparse.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/IR/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/Parser/TypeParser.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
index cd14fe5c04561..2f373aaab643b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
@@ -1,2 +1,7 @@
add_mlir_dialect(TensorOps tensor)
add_mlir_doc(TensorOps TensorOps Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS TensorAttrDefs.td)
+mlir_tablegen(TensorAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TensorAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTensorAttrDefsIncGen)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 830b682c602b6..8fa9a79feacfc 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TensorEncoding.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -23,6 +24,13 @@
#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
+//===----------------------------------------------------------------------===//
+// Tensor Dialect Attributes
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.h.inc"
+
//===----------------------------------------------------------------------===//
// Tensor Dialect Operations
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td
new file mode 100644
index 0000000000000..8103878407d13
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td
@@ -0,0 +1,82 @@
+//===-- TensorAttrDefs.td - Tensor Attributes Definitions --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TENSOR_ATTRDEFS
+#define TENSOR_ATTRDEFS
+
+include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/IR/TensorEncoding.td"
+
+// All of the Tensor attributes will extend this class.
+class Tensor_Attr<string name, list<Trait> traits = []>
+    : AttrDef<Tensor_Dialect, name, traits>;
+
+// Sparse tensor encoding attribute.
+def SparseTensorEncodingAttr : Tensor_Attr<"SparseTensorEncoding",
+    [ DeclareAttrInterfaceMethods<VerifiableTensorEncoding> ] > {
+  let mnemonic = "sparse";
+
+  let description = [{
+    An attribute to encode "TACO"-style information (see tensor-compiler.org)
+    on the sparsity of tensors. The semantics are defined by means of the
+    methods getDimLevelType(), getDimOrdering(), getPointerBitWidth(), and
+    getIndexBitWidth(), documented below. The encoding is eventually used by
+    a `sparse compiler` pass to generate sparse code fully automatically
+    for all tensor expressions that involve tensors with a sparse encoding.
+    Compiler passes that run before this sparse compiler pass need to be
+    aware of the semantics of tensor types with such an encoding.
+
+    All information is kept in a single dictionary with the optional keys
+    `sparseDimLevelType` (one of "dense", "compressed", or "singleton" per
+    dimension), `sparseDimOrdering` (a permutation affine map over all
+    dimensions), `sparsePointerBitWidth`, and `sparseIndexBitWidth` (each
+    one of 0, 8, 16, 32, or 64). Any other key is rejected by the verifier.
+
+    Example:
+
+    ```mlir
+    #CSR = #tensor.sparse<{
+      sparseDimLevelType = [ "dense", "compressed" ],
+      sparseDimOrdering = affine_map<(i,j) -> (i,j)>
+    }>
+    ```
+  }];
+
+  // All data is stored in a dictionary, interpreted by the methods below.
+  let parameters = (
+    ins
+    "DictionaryAttr":$dict
+  );
+
+  let extraClassDeclaration = [{
+    // Dimension level types that define sparse tensors:
+    //   Dense      - dimension is dense, every entry is stored
+    //   Compressed - dimension is sparse, only nonzeros are stored
+    //   Singleton  - dimension contains single coordinate, no siblings
+    enum class DimLevelType {
+      Dense, Compressed, Singleton
+    };
+
+    // Returns the dimension level type in the given dimension `dim`
+    // of this tensor type. The choices, defined by the `DimLevelType`
+    // enum, are `dense` (the dimension should be stored in its entirety),
+    // `compressed` (only non-zero regions or elements should be stored),
+    // or `singleton` (no sibling elements for parent). Every dimension is
+    // `dense` when the encoding specifies no level types.
+    DimLevelType getDimLevelType(unsigned dim) const;
+
+    // Returns the dimension order of this tensor type as an AffineMap.
+    // Unlike dense storage, most sparse storage schemes do not provide
+    // fast random access. This affine map specifies the order of
+    // dimensions that should be supported by the sparse storage scheme
+    // (e.g. (i,j) -> (i,j) requests 2-d row-wise and (i,j) -> (j,i)
+    // requests 2-d column-wise storage). Returns a null map when the
+    // encoding specifies no ordering.
+    // TODO: block structure with higher-dim inputs
+    AffineMap getDimOrdering() const;
+
+    // Returns the required bit width for pointer storage. A narrow width
+    // reduces the memory footprint of overhead storage, as long as the
+    // width suffices to define the total required range (viz. the maximum
+    // number of stored entries over all indirection dimensions). The choices
+    // are `8`, `16`, `32`, `64`, or `0` for a native width. Returns `0` when
+    // the encoding specifies no pointer bit width.
+    unsigned getPointerBitWidth() const;
+
+    // Returns the required bit width for index storage. A narrow width
+    // reduces the memory footprint of overhead storage, as long as the
+    // width suffices to define the total required range (viz. the maximum
+    // value of each tensor index over all dimensions). The choices are `8`,
+    // `16`, `32`, `64`, or `0` for a native width. Returns `0` when the
+    // encoding specifies no index bit width.
+    unsigned getIndexBitWidth() const;
+  }];
+}
+
+#endif // TENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a0e473873d27a..bf0890fea49be 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -10,6 +10,7 @@
#define TENSOR_OPS
include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/Dialect/Tensor/IR/TensorAttrDefs.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 8ce969777981f..42e07811a4a54 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -26,6 +26,11 @@ mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
+set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
+mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRTensorEncodingIncGen)
+
add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/IR/TensorEncoding.h b/mlir/include/mlir/IR/TensorEncoding.h
new file mode 100644
index 0000000000000..5d98aefaeb92e
--- /dev/null
+++ b/mlir/include/mlir/IR/TensorEncoding.h
@@ -0,0 +1,21 @@
+//===- TensorEncoding.h - MLIR Tensor Encoding Declarations------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_TENSORENCODING_H
+#define MLIR_IR_TENSORENCODING_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpDefinition.h"
+
+//===----------------------------------------------------------------------===//
+// Tablegen Interface Declarations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TensorEncInterfaces.h.inc"
+
+#endif // MLIR_IR_TENSORENCODING_H
diff --git a/mlir/include/mlir/IR/TensorEncoding.td b/mlir/include/mlir/IR/TensorEncoding.td
new file mode 100644
index 0000000000000..3991520d72a5f
--- /dev/null
+++ b/mlir/include/mlir/IR/TensorEncoding.td
@@ -0,0 +1,44 @@
+//===- TensorEncoding.td - Tensor encoding interfaces ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interfaces associated with tensor encoding attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_TENSORINTERFACES
+#define MLIR_IR_TENSORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interface to verify a tensor encoding.
+//===----------------------------------------------------------------------===//
+
+// Implemented by encoding attributes that can check themselves against the
+// tensor type they annotate. RankedTensorType::verify and the tensor type
+// parser call verifyEncoding() on any encoding that implements it, and treat
+// a failure() result as an invalid tensor type.
+def VerifiableTensorEncoding : AttrInterface<"VerifiableTensorEncoding"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Verifies an encoding attribute for a tensor.
+ }];
+ let methods = [
+ // verifyEncoding(shape, elementType, emitError) -> LogicalResult
+ InterfaceMethod<
+ /*desc=*/[{
+ Verifies the encoding is valid for a tensor type with the
+ given shape and element type. Generates a diagnostic using
+ the supplied callback on failure.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"verifyEncoding",
+ /*args=*/(ins
+ "ArrayRef<int64_t>":$shape,
+ "Type":$elementType,
+ "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+ >,
+ ];
+}
+
+#endif // MLIR_IR_TENSORINTERFACES
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index de650995ebb60..108b7f2470cb2 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTensor
DEPENDS
MLIRTensorOpsIncGen
+ MLIRTensorAttrDefsIncGen
LINK_COMPONENTS
Core
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index da76560fe85ff..bdc2fe345b9e5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -7,11 +7,142 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::tensor;
+//===----------------------------------------------------------------------===//
+// TableGen'd Attributes Methods
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
+
+// Keys recognized in the #tensor.sparse dictionary.
+static constexpr auto getSparseDimLevelTypeAttrName() -> StringRef {
+  return "sparseDimLevelType";
+}
+static constexpr auto getSparseDimOrderingAttrName() -> StringRef {
+  return "sparseDimOrdering";
+}
+static constexpr auto getSparsePointerBitWidthAttrName() -> StringRef {
+  return "sparsePointerBitWidth";
+}
+static constexpr auto getSparseIndexBitWidthAttrName() -> StringRef {
+  return "sparseIndexBitWidth";
+}
+
+// String values accepted inside the `sparseDimLevelType` array.
+static constexpr auto getDenseDimLevelTypeVal() -> StringRef {
+  return "dense";
+}
+static constexpr auto getCompressedDimLevelTypeVal() -> StringRef {
+  return "compressed";
+}
+static constexpr auto getSingletonDimLevelTypeVal() -> StringRef {
+  return "singleton";
+}
+
+// Parses the attribute body `<` dictionary-attribute `>` that follows the
+// `sparse` mnemonic. Returns a null attribute if any piece fails to parse.
+Attribute SparseTensorEncodingAttr::parse(MLIRContext *context,
+                                          DialectAsmParser &parser, Type type) {
+  DictionaryAttr dict;
+  // Short-circuit evaluation keeps the original parse order and stops at
+  // the first failure.
+  bool malformed = failed(parser.parseLess()) ||
+                   failed(parser.parseAttribute(dict)) ||
+                   failed(parser.parseGreater());
+  if (malformed)
+    return {};
+  return SparseTensorEncodingAttr::get(context, dict);
+}
+
+// Prints the attribute as `sparse<{...}>`, with the dictionary in the middle.
+void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
+  DictionaryAttr dict = getDict();
+  printer << "sparse<";
+  printer << dict;
+  printer << ">";
+}
+
+// Verifies the encoding dictionary against a tensor with the given `shape`.
+// Every key must be one of the four recognized keys, and each value must have
+// the form that the corresponding accessor relies on. The element type is
+// currently not inspected. Emits a diagnostic through `emitError` and returns
+// failure() on the first violation.
+LogicalResult SparseTensorEncodingAttr::verifyEncoding(
+    llvm::ArrayRef<int64_t> shape, Type elementType,
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+  unsigned size = shape.size();
+  for (const NamedAttribute &attr : getDict()) {
+    if (attr.first == getSparseDimLevelTypeAttrName()) {
+      // One known level-type string per tensor dimension.
+      auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
+      if (!arrayAttr || arrayAttr.size() != size)
+        return emitError() << "expected an array of size " << size
+                           << " for dimension level types";
+      for (Attribute elem : arrayAttr) {
+        auto strAttr = elem.dyn_cast<StringAttr>();
+        if (!strAttr)
+          return emitError()
+                 << "expected string value in dimension level types";
+        StringRef strVal = strAttr.getValue();
+        if (strVal != getDenseDimLevelTypeVal() &&
+            strVal != getCompressedDimLevelTypeVal() &&
+            strVal != getSingletonDimLevelTypeVal())
+          return emitError() << "unexpected dimension level type: " << strAttr;
+      }
+    } else if (attr.first == getSparseDimOrderingAttrName()) {
+      // The ordering must be a permutation over all tensor dimensions.
+      auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
+      if (!affineAttr)
+        return emitError() << "expected an affine map for dimension ordering";
+      AffineMap map = affineAttr.getValue();
+      if (size != map.getNumResults() || !map.isPermutation())
+        return emitError() << "expected a permutation affine map of size "
+                           << size << " for dimension ordering";
+    } else if (attr.first == getSparsePointerBitWidthAttrName() ||
+               attr.first == getSparseIndexBitWidthAttrName()) {
+      // Pointer or index bitwidth verification.
+      auto intAttr = attr.second.dyn_cast<IntegerAttr>();
+      if (!intAttr)
+        return emitError() << "expected an integral bitwidth";
+      // IntegerAttr::getInt() (used below and by the bitwidth accessors)
+      // asserts on signed/unsigned integer types and on values that do not
+      // fit in 64 bits. Reject such attributes with a diagnostic instead of
+      // crashing.
+      Type intType = intAttr.getType();
+      if (!(intType.isSignlessInteger() || intType.isIndex()) ||
+          intAttr.getValue().getMinSignedBits() > 64)
+        return emitError() << "expected a signless integral bitwidth";
+      switch (intAttr.getInt()) {
+      case 0:
+      case 8:
+      case 16:
+      case 32:
+      case 64:
+        continue;
+      default:
+        return emitError() << "unexpected bitwidth: " << intAttr.getInt();
+      }
+    } else {
+      return emitError() << "unexpected key: " << attr.first.str();
+    }
+  }
+  return success();
+}
+
+// Returns the level type of dimension `dim`. Any string other than
+// "compressed" or "singleton", and a missing `sparseDimLevelType` entry,
+// yield Dense. Assumes the encoding has passed verifyEncoding(), so the
+// entry, if present, is an array of strings with one element per dimension.
+SparseTensorEncodingAttr::DimLevelType
+SparseTensorEncodingAttr::getDimLevelType(unsigned dim) const {
+  if (auto value = getDict().get(getSparseDimLevelTypeAttrName())) {
+    // cast<> rather than dyn_cast<>: a malformed entry should trip an
+    // assertion instead of indexing into a null ArrayAttr.
+    StringRef strVal =
+        value.cast<ArrayAttr>()[dim].cast<StringAttr>().getValue();
+    if (strVal == getCompressedDimLevelTypeVal())
+      return DimLevelType::Compressed;
+    if (strVal == getSingletonDimLevelTypeVal())
+      return DimLevelType::Singleton;
+  }
+  return DimLevelType::Dense;
+}
+
+// Returns the requested dimension ordering, or a null AffineMap when the
+// dictionary has no `sparseDimOrdering` entry.
+AffineMap SparseTensorEncodingAttr::getDimOrdering() const {
+  Attribute ordering = getDict().get(getSparseDimOrderingAttrName());
+  return ordering ? ordering.cast<AffineMapAttr>().getValue() : AffineMap();
+}
+
+// Returns the integer stored under `key` in `dict`, or 0 (native width) when
+// the key is absent.
+static unsigned getBitWidthOrNative(DictionaryAttr dict, StringRef key) {
+  Attribute width = dict.get(key);
+  if (!width)
+    return 0;
+  return width.cast<IntegerAttr>().getInt();
+}
+
+unsigned SparseTensorEncodingAttr::getPointerBitWidth() const {
+  return getBitWidthOrNative(getDict(), getSparsePointerBitWidthAttrName());
+}
+
+unsigned SparseTensorEncodingAttr::getIndexBitWidth() const {
+  return getBitWidthOrNative(getDict(), getSparseIndexBitWidthAttrName());
+}
+
//===----------------------------------------------------------------------===//
// TensorDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
@@ -30,10 +161,38 @@ struct TensorInlinerInterface : public DialectInlinerInterface {
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// TensorDialect Methods
+//===----------------------------------------------------------------------===//
+
+// Registers the dialect's TableGen-generated attributes and operations, plus
+// its inliner interface.
void TensorDialect::initialize() {
+ // Attribute classes generated from TensorAttrDefs.td.
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
+ >();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
>();
addInterfaces<TensorInlinerInterface>();
}
+
+// Parses a tensor dialect attribute: a mnemonic keyword, then whatever the
+// matching TableGen-generated parser consumes.
+Attribute TensorDialect::parseAttribute(DialectAsmParser &parser,
+                                        Type type) const {
+  StringRef mnemonic;
+  if (failed(parser.parseKeyword(&mnemonic)))
+    return Attribute();
+
+  // If a generated parser handled the mnemonic (hasValue()), its result is
+  // returned as-is; otherwise the mnemonic is unknown to this dialect.
+  Attribute result;
+  auto claimed =
+      generatedAttributeParser(getContext(), parser, mnemonic, type, result);
+  if (!claimed.hasValue()) {
+    parser.emitError(parser.getNameLoc(), "unknown tensor attribute");
+    return Attribute();
+  }
+  return result;
+}
+
+// Prints a tensor dialect attribute via the TableGen-generated printer.
+void TensorDialect::printAttribute(::mlir::Attribute attr,
+                                   ::mlir::DialectAsmPrinter &printer) const {
+  if (succeeded(generatedAttributePrinter(attr, printer)))
+    return;
+  // initialize() registers only TableGen-generated attributes, so the
+  // generated printer should handle every attribute of this dialect.
+  // Previously this path returned silently and left the attribute body
+  // unprinted; make the broken invariant explicit instead.
+  llvm_unreachable("unexpected tensor attribute");
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 4e2e2310ca01e..baadd8d0433cc 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
@@ -446,7 +447,9 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
for (int64_t s : shape)
if (s < -1)
return emitError() << "invalid tensor dimension size";
- // TODO: verify contents of encoding attribute.
+ if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
+ if (failed(v.verifyEncoding(shape, elementType, emitError)))
+ return failure();
return checkTensorElementType(emitError, elementType);
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 68367d69b68a8..06f1985b53283 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_library(MLIRIR
Region.cpp
RegionKindInterface.cpp
SymbolTable.cpp
+ TensorEncoding.cpp
Types.cpp
TypeRange.cpp
TypeUtilities.cpp
@@ -45,6 +46,7 @@ add_mlir_library(MLIRIR
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
+ MLIRTensorEncodingIncGen
LINK_LIBS PUBLIC
MLIRSupport
diff --git a/mlir/lib/IR/TensorEncoding.cpp b/mlir/lib/IR/TensorEncoding.cpp
new file mode 100644
index 0000000000000..2ab6788306049
--- /dev/null
+++ b/mlir/lib/IR/TensorEncoding.cpp
@@ -0,0 +1,17 @@
+//===- TensorEncoding.cpp - MLIR Tensor Encoding --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TensorEncoding.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Tensor Encoding Interfaces Methods
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TensorEncInterfaces.cpp.inc"
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 0ec36c6085da1..b523d14a547da 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -13,6 +13,7 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TensorEncoding.h"
using namespace mlir;
using namespace mlir::detail;
@@ -412,8 +413,14 @@ Type Parser::parseTensorType() {
// Parse an optional encoding attribute.
Attribute encoding;
- if (consumeIf(Token::comma))
+ if (consumeIf(Token::comma)) {
encoding = parseAttribute();
+ if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+ if (failed(v.verifyEncoding(dimensions, elementType,
+ [&] { return emitError(); })))
+ return nullptr;
+ }
+ }
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr;
diff --git a/mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir b/mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir
new file mode 100644
index 0000000000000..b317c3b684e99
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
+
+// Each split declares a function whose argument type carries a malformed
+// #tensor.sparse encoding. The trailing annotation on that line names the
+// diagnostic the encoding verifier is expected to report.
+
+// -----
+
+#a = #tensor.sparse<{sparseDimLevelType = [1,2]}>
+func private @tensor_size_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{expected an array of size 1 for dimension level types}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimLevelType = [1]}>
+func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{expected string value in dimension level types}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimLevelType = ["strange"]}>
+func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{unexpected dimension level type: "strange"}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimOrdering = "wrong"}>
+func private @tensor_order_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{expected an affine map for dimension ordering}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimOrdering = affine_map<(i,j) -> (i,i)>}>
+func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{expected a permutation affine map of size 2 for dimension ordering}}
+
+// -----
+
+#a = #tensor.sparse<{sparsePointerBitWidth = 42}>
+func private @tensor_invalid_int_ptr(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{unexpected bitwidth: 42}}
+
+// -----
+
+#a = #tensor.sparse<{sparseIndexBitWidth = "not really"}>
+func private @tensor_no_int_index(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{expected an integral bitwidth}}
+
+// -----
+
+#a = #tensor.sparse<{sparseIndexBitWidth = 128}>
+func private @tensor_invalid_int_index(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{unexpected bitwidth: 128}}
+
+// -----
+
+#a = #tensor.sparse<{key = 1}>
+func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{unexpected key: key}}
diff --git a/mlir/test/Dialect/Tensor/valid_sparse.mlir b/mlir/test/Dialect/Tensor/valid_sparse.mlir
new file mode 100644
index 0000000000000..0f010e5ad9b70
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/valid_sparse.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
+
+// Round-trip test: each encoding is parsed and printed, then re-parsed and
+// printed again; the final output must match the expected output lines.
+
+// CHECK: func private @sparse_1d_tensor(tensor<32xf64, #tensor.sparse<{sparseDimLevelType = ["compressed"]}>>)
+func private @sparse_1d_tensor(tensor<32xf64, #tensor.sparse<{sparseDimLevelType = ["compressed"]}>>)
+
+// A 2-d encoding that sets every recognized key.
+#CSR = #tensor.sparse<{
+ sparseDimLevelType = [ "dense", "compressed" ],
+ sparseDimOrdering = affine_map<(i,j) -> (i,j)>,
+ sparseIndexBitWidth = 64,
+ sparsePointerBitWidth = 64
+}>
+
+// Integer entries are printed with their type (here `: i64`).
+// CHECK: func private @sparse_2d_tensor(tensor<?x?xf32, #tensor.sparse<{sparseDimLevelType = ["dense", "compressed"], sparseDimOrdering = affine_map<(d0, d1) -> (d0, d1)>, sparseIndexBitWidth = 64 : i64, sparsePointerBitWidth = 64 : i64}>>)
+func private @sparse_2d_tensor(tensor<?x?xf32, #CSR>)
More information about the Mlir-commits
mailing list