[Mlir-commits] [mlir] 23c9e8b - [mlir][tensors] Introduce attribute interface/attribute for tensor encoding

Aart Bik llvmlistbot at llvm.org
Mon Apr 26 18:32:09 PDT 2021


Author: Aart Bik
Date: 2021-04-26T18:31:54-07:00
New Revision: 23c9e8bc25795b69e16d39b674c19c79a2bb107b

URL: https://github.com/llvm/llvm-project/commit/23c9e8bc25795b69e16d39b674c19c79a2bb107b
DIFF: https://github.com/llvm/llvm-project/commit/23c9e8bc25795b69e16d39b674c19c79a2bb107b.diff

LOG: [mlir][tensors] Introduce attribute interface/attribute for tensor encoding

The new "encoding" field in tensor types so far had no meaning. This revision introduces:

1. an encoding attribute interface in IR: for verification between tensors and encodings in general
2. an attribute in Tensor dialect; #tensor.sparse<dict> + concrete sparse tensors API

Active discussion:
https://llvm.discourse.group/t/rfc-introduce-a-sparse-tensor-type-to-core-mlir/2944/

Reviewed By: silvas, penpornk, bixia

Differential Revision: https://reviews.llvm.org/D101008

Added: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td
    mlir/include/mlir/IR/TensorEncoding.h
    mlir/include/mlir/IR/TensorEncoding.td
    mlir/lib/IR/TensorEncoding.cpp
    mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir
    mlir/test/Dialect/Tensor/valid_sparse.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
    mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/Parser/TypeParser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
index cd14fe5c04561..2f373aaab643b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt
@@ -1,2 +1,7 @@
 add_mlir_dialect(TensorOps tensor)
 add_mlir_doc(TensorOps TensorOps Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS TensorAttrDefs.td)
+mlir_tablegen(TensorAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TensorAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTensorAttrDefsIncGen)

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 830b682c602b6..8fa9a79feacfc 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -23,6 +24,13 @@
 
 #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Tensor Dialect Attributes
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.h.inc"
+
 //===----------------------------------------------------------------------===//
 // Tensor Dialect Operations
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td
new file mode 100644
index 0000000000000..8103878407d13
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorAttrDefs.td
@@ -0,0 +1,82 @@
+//===-- TensorAttrDefs.td - Tensor Attributes Definitions --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TENSOR_ATTRDEFS
+#define TENSOR_ATTRDEFS
+
+include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/IR/TensorEncoding.td"
+
+// Base class for every attribute defined by the Tensor dialect; it binds
+// the attribute definition to `Tensor_Dialect`.
+class Tensor_Attr<string name, list<Trait> traits = []>
+    : AttrDef<Tensor_Dialect, name, traits>;
+
+// Sparse tensor encoding attribute.
+def SparseTensorEncodingAttr : Tensor_Attr<"SparseTensorEncoding",
+         [ DeclareAttrInterfaceMethods<VerifiableTensorEncoding> ] > {
+  let mnemonic = "sparse";
+
+  let description = [{
+    An attribute to encode "TACO"-style information (see tensor-compiler.org)
+    on the sparsity of tensors. The semantics are defined by means of the
+    methods getDimLevelType(), getDimOrdering(), getPointerBitWidth(), and
+    getIndexBitWidth(), documented below. The encoding is eventually used by
+    a `sparse compiler` pass to generate sparse code fully automatically
+    for all tensor expressions that involve tensors with a sparse encoding.
+    Compiler passes that run before this sparse compiler pass need to be
+    aware of the semantics of tensor types with such an encoding.
+
+    The dictionary may contain the following keys; any other key is
+    rejected by the verifier:
+      `sparseDimLevelType`    - array with one "dense", "compressed", or
+                                "singleton" string per tensor dimension
+      `sparseDimOrdering`     - permutation affine map over all dimensions
+      `sparsePointerBitWidth` - one of 0, 8, 16, 32, or 64
+      `sparseIndexBitWidth`   - one of 0, 8, 16, 32, or 64
+
+    Example:
+
+    ```mlir
+    #CSR = #tensor.sparse<{
+      sparseDimLevelType = [ "dense", "compressed" ],
+      sparseDimOrdering = affine_map<(i,j) -> (i,j)>,
+      sparseIndexBitWidth = 64,
+      sparsePointerBitWidth = 64
+    }>
+    ```
+  }];
+
+  // All data is stored in a dictionary, interpreted by the methods below.
+  let parameters = (
+    ins
+    "DictionaryAttr":$dict
+  );
+
+  let extraClassDeclaration = [{
+    // Dimension level types that define sparse tensors:
+    //   Dense      - dimension is dense, every entry is stored
+    //   Compressed - dimension is sparse, only nonzeros are stored
+    //   Singleton  - dimension contains single coordinate, no siblings
+    enum class DimLevelType {
+      Dense, Compressed, Singleton
+    };
+
+    // Returns the dimension level type in the given dimension `dim`
+    // of this tensor type. The choices, defined by the `DimLevelType`
+    // enum, are `dense` (the dimension should be stored in its entirety),
+    // `compressed` (only non-zero regions or elements should be stored),
+    // or `singleton` (no sibling elements for parent). Every dimension is
+    // `dense` when no level types are specified.
+    DimLevelType getDimLevelType(unsigned dim) const;
+
+    // Returns the dimension order of this tensor type as an AffineMap.
+    // Unlike dense storage, most sparse storage schemes do not provide
+    // fast random access. This affine map specifies the order of
+    // dimensions that should be supported by the sparse storage scheme
+    // (e.g. (i,j) -> (i,j) requests 2-d row-wise and (i,j) -> (j,i)
+    // requests 2-d column-wise storage). Returns a null AffineMap when
+    // no ordering is specified.
+    // TODO: block structure with higher-dim inputs
+    AffineMap getDimOrdering() const;
+
+    // Returns the required bit width for pointer storage. A narrow width
+    // reduces the memory footprint of overhead storage, as long as the
+    // width suffices to define the total required range (viz. the maximum
+    // number of stored entries over all indirection dimensions). The choices
+    // are `8`, `16`, `32`, `64`, or `0` for a native width. Returns `0`
+    // when no width is specified.
+    unsigned getPointerBitWidth() const;
+
+    // Returns the required bit width for index storage. A narrow width
+    // reduces the memory footprint of overhead storage, as long as the
+    // width suffices to define the total required range (viz. the maximum
+    // value of each tensor index over all dimensions). The choices are `8`,
+    // `16`, `32`, `64`, or `0` for a native width. Returns `0` when no
+    // width is specified.
+    unsigned getIndexBitWidth() const;
+  }];
+}
+
+#endif // TENSOR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a0e473873d27a..bf0890fea49be 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -10,6 +10,7 @@
 #define TENSOR_OPS
 
 include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/Dialect/Tensor/IR/TensorAttrDefs.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 8ce969777981f..42e07811a4a54 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -26,6 +26,11 @@ mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
 mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRBuiltinTypesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
+mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRTensorEncodingIncGen)
+
 add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
 add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
 add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)

diff  --git a/mlir/include/mlir/IR/TensorEncoding.h b/mlir/include/mlir/IR/TensorEncoding.h
new file mode 100644
index 0000000000000..5d98aefaeb92e
--- /dev/null
+++ b/mlir/include/mlir/IR/TensorEncoding.h
@@ -0,0 +1,21 @@
+//===- TensorEncoding.h - MLIR Tensor Encoding Declarations------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_TENSORENCODING_H
+#define MLIR_IR_TENSORENCODING_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpDefinition.h"
+
+//===----------------------------------------------------------------------===//
+// Tablegen Attribute Interface Declarations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TensorEncInterfaces.h.inc"
+
+#endif // MLIR_IR_TENSORENCODING_H

diff  --git a/mlir/include/mlir/IR/TensorEncoding.td b/mlir/include/mlir/IR/TensorEncoding.td
new file mode 100644
index 0000000000000..3991520d72a5f
--- /dev/null
+++ b/mlir/include/mlir/IR/TensorEncoding.td
@@ -0,0 +1,44 @@
+//===- TensorEncoding.td - Tensor encoding interfaces ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interfaces associated with tensor encoding attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_TENSORINTERFACES
+#define MLIR_IR_TENSORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interface to verify a tensor encoding.
+//===----------------------------------------------------------------------===//
+
+def VerifiableTensorEncoding : AttrInterface<"VerifiableTensorEncoding"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Verifies an encoding attribute for a tensor. Ranked tensor types invoke
+    this interface on their encoding attribute, if it implements it.
+  }];
+  // All argument and result types are fully qualified because
+  // DeclareAttrInterfaceMethods emits this signature into the implementing
+  // attribute's class, which may live outside namespace `mlir`.
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Verifies the encoding is valid for a tensor type with the
+        given shape and element type. Generates a diagnostic using
+        the supplied callback on failure.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"verifyEncoding",
+      /*args=*/(ins
+        "::llvm::ArrayRef<int64_t>":$shape,
+        "::mlir::Type":$elementType,
+        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >,
+  ];
+}
+
+#endif // MLIR_IR_TENSORINTERFACES

diff  --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index de650995ebb60..108b7f2470cb2 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTensor
 
   DEPENDS
   MLIRTensorOpsIncGen
+  MLIRTensorAttrDefsIncGen
 
   LINK_COMPONENTS
   Core

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index da76560fe85ff..bdc2fe345b9e5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -7,11 +7,142 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
+//===----------------------------------------------------------------------===//
+// TableGen'd Attributes Methods
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
+
+// Keys that may appear in the encoding dictionary of #tensor.sparse.
+static constexpr StringRef getSparseDimLevelTypeAttrName() {
+  return StringRef("sparseDimLevelType");
+}
+static constexpr StringRef getSparseDimOrderingAttrName() {
+  return StringRef("sparseDimOrdering");
+}
+static constexpr StringRef getSparsePointerBitWidthAttrName() {
+  return StringRef("sparsePointerBitWidth");
+}
+static constexpr StringRef getSparseIndexBitWidthAttrName() {
+  return StringRef("sparseIndexBitWidth");
+}
+
+// String spellings accepted for each entry of `sparseDimLevelType`.
+static constexpr StringRef getDenseDimLevelTypeVal() {
+  return StringRef("dense");
+}
+static constexpr StringRef getCompressedDimLevelTypeVal() {
+  return StringRef("compressed");
+}
+static constexpr StringRef getSingletonDimLevelTypeVal() {
+  return StringRef("singleton");
+}
+
+// Parses `<` dictionary-attribute `>`; the `sparse` mnemonic itself has
+// already been consumed by TensorDialect::parseAttribute. Returns a null
+// attribute if any of the three pieces fails to parse.
+Attribute SparseTensorEncodingAttr::parse(MLIRContext *context,
+                                          DialectAsmParser &parser, Type type) {
+  DictionaryAttr dict;
+  bool malformed = failed(parser.parseLess()) ||
+                   failed(parser.parseAttribute(dict)) ||
+                   failed(parser.parseGreater());
+  if (malformed)
+    return {};
+  return SparseTensorEncodingAttr::get(context, dict);
+}
+
+// Prints the attribute as `sparse<` dictionary `>`.
+void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
+  DictionaryAttr dict = getDict();
+  printer << "sparse<" << dict << ">";
+}
+
+// Checks that `attr` is an array holding one valid dimension level type
+// string ("dense", "compressed", or "singleton") per tensor dimension.
+static LogicalResult
+verifySparseDimLevelTypes(Attribute attr, unsigned rank,
+                          llvm::function_ref<InFlightDiagnostic()> emitError) {
+  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+  if (!arrayAttr || arrayAttr.size() != rank)
+    return emitError() << "expected an array of size " << rank
+                       << " for dimension level types";
+  for (Attribute elem : arrayAttr) {
+    auto strAttr = elem.dyn_cast<StringAttr>();
+    if (!strAttr)
+      return emitError() << "expected string value in dimension level types";
+    StringRef strVal = strAttr.getValue();
+    if (strVal != getDenseDimLevelTypeVal() &&
+        strVal != getCompressedDimLevelTypeVal() &&
+        strVal != getSingletonDimLevelTypeVal())
+      return emitError() << "unexpected dimension level type: " << strAttr;
+  }
+  return success();
+}
+
+// Checks that `attr` is a permutation affine map with `rank` results.
+static LogicalResult
+verifySparseDimOrdering(Attribute attr, unsigned rank,
+                        llvm::function_ref<InFlightDiagnostic()> emitError) {
+  auto affineAttr = attr.dyn_cast<AffineMapAttr>();
+  if (!affineAttr)
+    return emitError() << "expected an affine map for dimension ordering";
+  AffineMap map = affineAttr.getValue();
+  if (map.getNumResults() != rank || !map.isPermutation())
+    return emitError() << "expected a permutation affine map of size " << rank
+                       << " for dimension ordering";
+  return success();
+}
+
+// Checks that `attr` is an integer bit width in {0, 8, 16, 32, 64}.
+static LogicalResult
+verifySparseBitWidth(Attribute attr,
+                     llvm::function_ref<InFlightDiagnostic()> emitError) {
+  auto intAttr = attr.dyn_cast<IntegerAttr>();
+  // IntegerAttr::getInt() (used below and by the bit-width getters) only
+  // supports signless integer and index types, so reject values such as
+  // `8 : ui8` here instead of tripping its assertion.
+  if (!intAttr || !(intAttr.getType().isSignlessInteger() ||
+                    intAttr.getType().isIndex()))
+    return emitError() << "expected an integral bitwidth";
+  switch (intAttr.getInt()) {
+  case 0:
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    return success();
+  default:
+    return emitError() << "unexpected bitwidth: " << intAttr.getInt();
+  }
+}
+
+// Verifies every entry of the encoding dictionary against the tensor shape;
+// the first invalid entry (or unknown key) produces a diagnostic through
+// `emitError`. The element type is currently not inspected.
+LogicalResult SparseTensorEncodingAttr::verifyEncoding(
+    llvm::ArrayRef<int64_t> shape, Type elementType,
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+  unsigned rank = shape.size();
+  for (const NamedAttribute &attr : getDict()) {
+    LogicalResult result = success();
+    if (attr.first == getSparseDimLevelTypeAttrName())
+      result = verifySparseDimLevelTypes(attr.second, rank, emitError);
+    else if (attr.first == getSparseDimOrderingAttrName())
+      result = verifySparseDimOrdering(attr.second, rank, emitError);
+    else if (attr.first == getSparsePointerBitWidthAttrName() ||
+             attr.first == getSparseIndexBitWidthAttrName())
+      result = verifySparseBitWidth(attr.second, emitError);
+    else
+      return emitError() << "unexpected key: " << attr.first.str();
+    if (failed(result))
+      return failure();
+  }
+  return success();
+}
+
+// Returns the level type of dimension `dim`, decoded from the
+// `sparseDimLevelType` entry. Every dimension is Dense when the entry is
+// absent. Assumes the encoding has been verified, so a present entry is an
+// array of valid level-type strings with one element per dimension.
+SparseTensorEncodingAttr::DimLevelType
+SparseTensorEncodingAttr::getDimLevelType(unsigned dim) const {
+  if (auto value = getDict().get(getSparseDimLevelTypeAttrName())) {
+    // `cast` rather than `dyn_cast`: the result is used unconditionally, so a
+    // malformed entry should trip an assertion instead of a null dereference.
+    auto strVal =
+        value.cast<ArrayAttr>()[dim].cast<StringAttr>().getValue();
+    if (strVal == getCompressedDimLevelTypeVal())
+      return DimLevelType::Compressed;
+    if (strVal == getSingletonDimLevelTypeVal())
+      return DimLevelType::Singleton;
+  }
+  return DimLevelType::Dense;
+}
+
+// Returns the `sparseDimOrdering` map, or a null AffineMap when the
+// dictionary has no such entry.
+AffineMap SparseTensorEncodingAttr::getDimOrdering() const {
+  Attribute value = getDict().get(getSparseDimOrderingAttrName());
+  return value ? value.cast<AffineMapAttr>().getValue() : AffineMap();
+}
+
+// Reads the optional integer entry `key` from `dict`, yielding 0 (the
+// "native width" choice) when the entry is absent.
+static unsigned getSparseBitWidthOrZero(DictionaryAttr dict, StringRef key) {
+  Attribute value = dict.get(key);
+  if (!value)
+    return 0;
+  return value.cast<IntegerAttr>().getInt();
+}
+
+unsigned SparseTensorEncodingAttr::getPointerBitWidth() const {
+  return getSparseBitWidthOrZero(getDict(),
+                                 getSparsePointerBitWidthAttrName());
+}
+
+unsigned SparseTensorEncodingAttr::getIndexBitWidth() const {
+  return getSparseBitWidthOrZero(getDict(), getSparseIndexBitWidthAttrName());
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Dialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -30,10 +161,38 @@ struct TensorInlinerInterface : public DialectInlinerInterface {
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// TensorDialect Methods
+//===----------------------------------------------------------------------===//
+
 void TensorDialect::initialize() {
+  // Register the attributes generated from TensorAttrDefs.td
+  // (e.g. SparseTensorEncodingAttr) alongside the dialect's operations.
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
       >();
   addInterfaces<TensorInlinerInterface>();
 }
+
+// Parses a Tensor dialect attribute. It reads the mnemonic keyword and
+// dispatches to the TableGen-generated parser. An unknown mnemonic reports
+// an error at the attribute's name location and yields a null attribute.
+Attribute TensorDialect::parseAttribute(DialectAsmParser &parser,
+                                        Type type) const {
+  StringRef mnemonic;
+  if (failed(parser.parseKeyword(&mnemonic)))
+    return Attribute();
+  Attribute result;
+  auto status =
+      generatedAttributeParser(getContext(), parser, mnemonic, type, result);
+  if (!status.hasValue()) {
+    parser.emitError(parser.getNameLoc(), "unknown tensor attribute");
+    return Attribute();
+  }
+  return result;
+}
+
+// Prints a Tensor dialect attribute through the TableGen-generated printer.
+// NOTE(review): if the generated printer does not handle `attr`, nothing is
+// printed and no error is reported; consider asserting instead.
+void TensorDialect::printAttribute(::mlir::Attribute attr,
+                                   ::mlir::DialectAsmPrinter &printer) const {
+  if (succeeded(generatedAttributePrinter(attr, printer)))
+    return;
+}

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 4e2e2310ca01e..baadd8d0433cc 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/Sequence.h"
@@ -446,7 +447,9 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
   for (int64_t s : shape)
     if (s < -1)
       return emitError() << "invalid tensor dimension size";
-  // TODO: verify contents of encoding attribute.
+  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
+    if (failed(v.verifyEncoding(shape, elementType, emitError)))
+      return failure();
   return checkTensorElementType(emitError, elementType);
 }
 

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 68367d69b68a8..06f1985b53283 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_library(MLIRIR
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
+  TensorEncoding.cpp
   Types.cpp
   TypeRange.cpp
   TypeUtilities.cpp
@@ -45,6 +46,7 @@ add_mlir_library(MLIRIR
   MLIRRegionKindInterfaceIncGen
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
+  MLIRTensorEncodingIncGen
 
   LINK_LIBS PUBLIC
   MLIRSupport

diff  --git a/mlir/lib/IR/TensorEncoding.cpp b/mlir/lib/IR/TensorEncoding.cpp
new file mode 100644
index 0000000000000..2ab6788306049
--- /dev/null
+++ b/mlir/lib/IR/TensorEncoding.cpp
@@ -0,0 +1,17 @@
+//===- TensorEncoding.cpp - MLIR Tensor Encoding --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TensorEncoding.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Tensor Encoding Interfaces Methods
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TensorEncInterfaces.cpp.inc"

diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 0ec36c6085da1..b523d14a547da 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -13,6 +13,7 @@
 #include "Parser.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TensorEncoding.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -412,8 +413,14 @@ Type Parser::parseTensorType() {
 
   // Parse an optional encoding attribute.
   Attribute encoding;
-  if (consumeIf(Token::comma))
+  if (consumeIf(Token::comma)) {
     encoding = parseAttribute();
+    if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+      if (failed(v.verifyEncoding(dimensions, elementType,
+                                  [&] { return emitError(); })))
+        return nullptr;
+    }
+  }
 
   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;

diff  --git a/mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir b/mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir
new file mode 100644
index 0000000000000..b317c3b684e99
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/invalid_sparse_tensor.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
+
+// -----
+
+#a = #tensor.sparse<{sparseDimLevelType = [1,2]}>
+func private @tensor_size_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{expected an array of size 1 for dimension level types}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimLevelType = [1]}>
+func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{expected string value in dimension level types}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimLevelType = ["strange"]}>
+func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{unexpected dimension level type: "strange"}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimOrdering = "wrong"}>
+func private @tensor_order_mismatch(%arg0: tensor<8xi32, #a>) -> () // expected-error {{expected an affine map for dimension ordering}}
+
+// -----
+
+#a = #tensor.sparse<{sparseDimOrdering = affine_map<(i,j) -> (i,i)>}>
+func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{expected a permutation affine map of size 2 for dimension ordering}}
+
+// -----
+
+#a = #tensor.sparse<{sparsePointerBitWidth = 42}>
+func private @tensor_invalid_int_ptr(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{unexpected bitwidth: 42}}
+
+// -----
+
+#a = #tensor.sparse<{sparseIndexBitWidth = "not really"}>
+func private @tensor_no_int_index(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{expected an integral bitwidth}}
+
+// -----
+
+#a = #tensor.sparse<{sparseIndexBitWidth = 128}>
+func private @tensor_invalid_int_index(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{unexpected bitwidth: 128}}
+
+// -----
+
+#a = #tensor.sparse<{key = 1}>
+func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> () // expected-error {{unexpected key: key}}

diff  --git a/mlir/test/Dialect/Tensor/valid_sparse.mlir b/mlir/test/Dialect/Tensor/valid_sparse.mlir
new file mode 100644
index 0000000000000..0f010e5ad9b70
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/valid_sparse.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
+
+// CHECK: func private @sparse_1d_tensor(tensor<32xf64, #tensor.sparse<{sparseDimLevelType = ["compressed"]}>>)
+func private @sparse_1d_tensor(tensor<32xf64, #tensor.sparse<{sparseDimLevelType = ["compressed"]}>>)
+
+#CSR = #tensor.sparse<{
+  sparseDimLevelType = [ "dense", "compressed" ],
+  sparseDimOrdering = affine_map<(i,j) -> (i,j)>,
+  sparseIndexBitWidth = 64,
+  sparsePointerBitWidth = 64
+}>
+
+// CHECK: func private @sparse_2d_tensor(tensor<?x?xf32, #tensor.sparse<{sparseDimLevelType = ["dense", "compressed"], sparseDimOrdering = affine_map<(d0, d1) -> (d0, d1)>, sparseIndexBitWidth = 64 : i64, sparsePointerBitWidth = 64 : i64}>>)
+func private @sparse_2d_tensor(tensor<?x?xf32, #CSR>)


        


More information about the Mlir-commits mailing list