[Mlir-commits] [mlir] 4f21152 - [mlir] Tighten verification of SparseElementsAttr
River Riddle
llvmlistbot at llvm.org
Mon Sep 20 18:58:10 PDT 2021
Author: River Riddle
Date: 2021-09-21T01:57:42Z
New Revision: 4f21152af12b21ea8f04b322a29dc6ad9e79ef16
URL: https://github.com/llvm/llvm-project/commit/4f21152af12b21ea8f04b322a29dc6ad9e79ef16
DIFF: https://github.com/llvm/llvm-project/commit/4f21152af12b21ea8f04b322a29dc6ad9e79ef16.diff
LOG: [mlir] Tighten verification of SparseElementsAttr
SparseElementsAttr currently does not perform any verification on construction, with the only verification existing within the parser. This revision moves the parser verification to SparseElementsAttr, and also adds additional verification for when a sparse index is not valid.
Differential Revision: https://reviews.llvm.org/D109189
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.h
mlir/lib/Parser/TypeParser.cpp
mlir/test/CAPI/ir.c
mlir/test/Dialect/Quant/convert-const.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/test/IR/pretty-attributes.mlir
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index e0ede99b19af6..d332718bb9b17 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -70,6 +70,12 @@ class ElementsAttr : public Attribute {
/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
+ static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
+
+ /// Returns the 1-dimensional flattened row-major index from the given
+ /// multi-dimensional index.
+ uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
+ static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index);
/// Returns the number of elements held by this attribute.
int64_t getNumElements() const;
@@ -94,11 +100,6 @@ class ElementsAttr : public Attribute {
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr);
-
-protected:
- /// Returns the 1 dimensional flattened row-major index from the given
- /// multi-dimensional index.
- uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
};
namespace detail {
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 228f1c6dca992..25e54cbfd68c9 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -791,6 +791,7 @@ def Builtin_SparseElementsAttr
public:
}];
+ let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 0ecc3d426af48..d906adc3e8151 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -405,25 +405,45 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
return cast<SparseElementsAttr>().getValue(index);
}
-/// Return if the given 'index' refers to a valid element in this attribute.
bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
- auto type = getType();
-
+ return isValidIndex(getType(), index);
+}
+bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
// Verify that the rank of the indices matches the held type.
- auto rank = type.getRank();
+ int64_t rank = type.getRank();
if (rank == 0 && index.size() == 1 && index[0] == 0)
return true;
if (rank != static_cast<int64_t>(index.size()))
return false;
// Verify that all of the indices are within the shape dimensions.
- auto shape = type.getShape();
+ ArrayRef<int64_t> shape = type.getShape();
return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
int64_t dim = static_cast<int64_t>(index[i]);
return 0 <= dim && dim < shape[i];
});
}
+uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
+ return getFlattenedIndex(getType(), index);
+}
+uint64_t ElementsAttr::getFlattenedIndex(ShapedType type,
+ ArrayRef<uint64_t> index) {
+ assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
+
+ // Reduce the provided multidimensional index into a flattened 1D row-major
+ // index (the last dimension varies fastest, so its stride is 1).
+ auto rank = type.getRank();
+ auto shape = type.getShape();
+ uint64_t valueIndex = 0;
+ uint64_t dimMultiplier = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ valueIndex += index[i] * dimMultiplier;
+ dimMultiplier *= shape[i];
+ }
+ return valueIndex;
+}
+
ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
@@ -446,25 +466,6 @@ bool ElementsAttr::classof(Attribute attr) {
OpaqueElementsAttr, SparseElementsAttr>();
}
-/// Returns the 1 dimensional flattened row-major index from the given
-/// multi-dimensional index.
-uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
- assert(isValidIndex(index) && "expected valid multi-dimensional index");
- auto type = getType();
-
- // Reduce the provided multidimensional index into a flattended 1D row-major
- // index.
- auto rank = type.getRank();
- auto shape = type.getShape();
- uint64_t valueIndex = 0;
- uint64_t dimMultiplier = 1;
- for (int i = rank - 1; i >= 0; --i) {
- valueIndex += index[i] * dimMultiplier;
- dimMultiplier *= shape[i];
- }
- return valueIndex;
-}
-
//===----------------------------------------------------------------------===//
// DenseElementsAttr Utilities
//===----------------------------------------------------------------------===//
@@ -1421,6 +1422,64 @@ std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
return flatSparseIndices;
}
+LogicalResult
+SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                           ShapedType type, DenseIntElementsAttr sparseIndices,
+                           DenseElementsAttr values) {
+  ShapedType valuesType = values.getType();
+  if (valuesType.getRank() != 1)
+    return emitError() << "expected 1-d tensor for sparse element values";
+
+  // Verify the indices and values shape.
+  ShapedType indicesType = sparseIndices.getType();
+  auto emitShapeError = [&]() {
+    return emitError() << "expected shape ([" << type.getShape()
+                       << "]); inferred shape of indices literal (["
+                       << indicesType.getShape()
+                       << "]); inferred shape of values literal (["
+                       << valuesType.getShape() << "])";
+  };
+  // Verify indices shape; dim sizes are int64_t, so compare with signed rank.
+  size_t rank = type.getRank(), indicesRank = indicesType.getRank();
+  if (indicesRank == 2) {
+    if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
+      return emitShapeError();
+  } else if (indicesRank != 1 || rank != 1) {
+    return emitShapeError();
+  }
+  // Verify the values shape.
+  int64_t numSparseIndices = indicesType.getDimSize(0);
+  if (numSparseIndices != valuesType.getDimSize(0))
+    return emitShapeError();
+
+  // Verify that the sparse indices are within the value shape.
+  auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
+    return emitError()
+           << "sparse index #" << indexNum
+           << " is not contained within the value shape, with index=[" << index
+           << "], and type=" << type;
+  };
+
+  // Handle the case where the (non-empty) index values are a splat.
+  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
+  if (numSparseIndices != 0 && sparseIndices.isSplat()) {
+    SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
+    if (!ElementsAttr::isValidIndex(type, indices))
+      return emitIndexError(0, indices);
+    return success();
+  }
+
+  // Otherwise, reinterpret each index as an ArrayRef.
+  for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
+    ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
+                             rank);
+    if (!ElementsAttr::isValidIndex(type, index))
+      return emitIndexError(i, index);
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TypeAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 1e9e87bdb7e5a..6e512bdf9747c 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -893,6 +893,7 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
/// Parse a sparse elements attribute.
Attribute Parser::parseSparseElementsAttr(Type attrType) {
+ llvm::SMLoc loc = getToken().getLoc();
consumeToken(Token::kw_sparse);
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
@@ -911,8 +912,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
ShapedType indicesType =
RankedTensorType::get({0, type.getRank()}, indiceEltType);
ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
- return SparseElementsAttr::get(
- type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
+ return getChecked<SparseElementsAttr>(
+ loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
}
@@ -963,22 +964,6 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
auto values = valuesParser.getAttr(valuesLoc, valuesType);
- /// Sanity check.
- if (valuesType.getRank() != 1)
- return (emitError("expected 1-d tensor for values"), nullptr);
-
- auto sameShape = (indicesType.getRank() == 1) ||
- (type.getRank() == indicesType.getDimSize(1));
- auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
- if (!sameShape || !sameElementNum) {
- emitError() << "expected shape ([" << type.getShape()
- << "]); inferred shape of indices literal (["
- << indicesType.getShape()
- << "]); inferred shape of values literal (["
- << valuesType.getShape() << "])";
- return nullptr;
- }
-
// Build the sparse elements attribute by the indices and values.
- return SparseElementsAttr::get(type, indices, values);
+ return getChecked<SparseElementsAttr>(loc, type, indices, values);
}
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 2f41f99a51859..f8a2c74e455f4 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -140,6 +140,16 @@ class Parser {
// Type Parsing
//===--------------------------------------------------------------------===//
+ /// Invoke the `getChecked` method of the given Attribute or Type class, using
+ /// the provided location to emit errors in the case of failure. Note that
+ /// unlike `OpBuilder::getType`, this method does not implicitly insert a
+ /// context parameter.
+ template <typename T, typename... ParamsT>
+ T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
+ return T::getChecked([&] { return emitError(loc); },
+ std::forward<ParamsT>(params)...);
+ }
+
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index b523d14a547da..66ec3b4237081 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -193,6 +193,7 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
/// memory-space ::= integer-literal /* | TODO: address-space-id */
///
Type Parser::parseMemRefType() {
+ llvm::SMLoc loc = getToken().getLoc();
consumeToken(Token::kw_memref);
if (parseToken(Token::less, "expected '<' in memref type"))
@@ -283,15 +284,11 @@ Type Parser::parseMemRefType() {
}
}
- if (isUnranked) {
- return UnrankedMemRefType::getChecked(
- [&]() -> InFlightDiagnostic { return emitError(); }, elementType,
- memorySpace);
- }
+ if (isUnranked)
+ return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
- return MemRefType::getChecked(
- [&]() -> InFlightDiagnostic { return emitError(); }, dimensions,
- elementType, affineMapComposition, memorySpace);
+ return getChecked<MemRefType>(loc, dimensions, elementType,
+ affineMapComposition, memorySpace);
}
/// Parse any type except the function type.
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index ebb8b0f26e542..d85af8fb6b700 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1087,19 +1087,19 @@ int printBuiltinAttributes(MlirContext ctx) {
// CHECK: 1.000000e+00 : f32
// CHECK: 1.000000e+00 : f64
- int64_t indices[] = {4, 7};
- int64_t two = 2;
+ int64_t indices[] = {0, 1};
+ int64_t one = 1;
MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
- mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
2, indices);
MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
- mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), 2,
+ mlirRankedTensorTypeGet(1, &one, mlirF32TypeGet(ctx), encoding), 1,
floats);
MlirAttribute sparseAttr = mlirSparseElementsAttribute(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
indicesAttr, valuesAttr);
mlirAttributeDump(sparseAttr);
- // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
+ // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>
return 0;
}
diff --git a/mlir/test/Dialect/Quant/convert-const.mlir b/mlir/test/Dialect/Quant/convert-const.mlir
index fb6baa25ba4cc..293638e73f070 100644
--- a/mlir/test/Dialect/Quant/convert-const.mlir
+++ b/mlir/test/Dialect/Quant/convert-const.mlir
@@ -68,15 +68,15 @@ func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> {
// -----
// Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values.
// CHECK-LABEL: const_sparse_tensor_i8_fixedpoint
-func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> {
+func @const_sparse_tensor_i8_fixedpoint() -> tensor<2x7xf32> {
// NOTE: Ugly regex match pattern for opening "[[" of indices tensor.
- // CHECK: %cst = constant sparse<{{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<7x2xi8>
+ // CHECK: %cst = constant sparse<{{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<2x7xi8>
%cst = constant sparse<
[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
- [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant.uniform<i8:f32, 7.812500e-03>>
- %2 = "quant.dcast"(%1) : (tensor<7x2x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<7x2xf32>)
- return %2 : tensor<7x2xf32>
+ [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<2x7xf32>
+ %1 = "quant.qcast"(%cst) : (tensor<2x7xf32>) -> tensor<2x7x!quant.uniform<i8:f32, 7.812500e-03>>
+ %2 = "quant.dcast"(%1) : (tensor<2x7x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<2x7xf32>)
+ return %2 : tensor<2x7xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7ef93fbe1b10f..d5171a6358637 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -83,8 +83,8 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
%ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
// Fold an extract into a sparse with a non sparse index.
- %2 = constant sparse<[[1, 1, 1]], [-2.0]> : tensor<1x1x1xf16>
- %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<1x1x1xf16>
+ %2 = constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
+ %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
// Fold an extract into a dense tensor.
%3 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index fc3eb9c3fbe87..ee4291606e42d 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -897,7 +897,7 @@ func @mi() {
// -----
func @invalid_tensor_literal() {
- // expected-error @+1 {{expected 1-d tensor for values}}
+ // expected-error @+1 {{expected 1-d tensor for sparse element values}}
"foof16"(){bar = sparse<[[0, 0, 0]], [[-2.0]]> : vector<1x1x1xf16>} : () -> ()
// -----
@@ -908,6 +908,12 @@ func @invalid_tensor_literal() {
// -----
+func @invalid_tensor_literal() {
+ // expected-error @+1 {{sparse index #0 is not contained within the value shape, with index=[1, 1], and type='tensor<1x1xi16>'}}
+ "fooi16"(){bar = sparse<1, 10> : tensor<1x1xi16>} : () -> ()
+
+// -----
+
func @invalid_affine_structure() {
%c0 = constant 0 : index
%idx = affine.apply affine_map<(d0, d1)> (%c0, %c0) // expected-error {{expected '->' or ':'}}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bcc26ceebffb0..959f1e2af4abd 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -810,7 +810,7 @@ func @sparsetensorattr() -> () {
// CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> ()
"fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> ()
// CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> ()
- "fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> ()
+ "fooi64"(){bar = sparse<[0], [-1]> : tensor<1xi64>} : () -> ()
// CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> ()
"foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
// CHECK: "foo3"() {bar = sparse<> : tensor<i32>} : () -> ()
diff --git a/mlir/test/IR/pretty-attributes.mlir b/mlir/test/IR/pretty-attributes.mlir
index 280e32672ea5f..d1d43a17b8625 100644
--- a/mlir/test/IR/pretty-attributes.mlir
+++ b/mlir/test/IR/pretty-attributes.mlir
@@ -11,8 +11,8 @@
// CHECK: dense<[1, 2]> : tensor<2xi32>
"test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()
-// CHECK: opaque<"_", "0xDEADBEEF"> : vector<1x1x1xf16>
-"test.sparse_attr"() {foo.sparse_attr = sparse<[[1, 2, 3]], -2.0> : vector<1x1x1xf16>} : () -> ()
+// CHECK: opaque<"_", "0xDEADBEEF"> : vector<1x1x10xf16>
+"test.sparse_attr"() {foo.sparse_attr = sparse<[[0, 0, 5]], -2.0> : vector<1x1x10xf16>} : () -> ()
// CHECK: opaque<"_", "0xDEADBEEF"> : tensor<100xf32>
"test.opaque_attr"() {foo.opaque_attr = opaque<"_", "0xEBFE"> : tensor<100xf32> } : () -> ()
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 86549afbc9a83..205308754cc92 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1157,7 +1157,7 @@ llvm.func @alloca(%size : i64) {
// CHECK-LABEL: @constants
llvm.func @constants() -> vector<4xf32> {
// CHECK: ret <4 x float> <float 4.2{{0*}}e+01, float 0.{{0*}}e+00, float 0.{{0*}}e+00, float 0.{{0*}}e+00>
- %0 = llvm.mlir.constant(sparse<[[0]], [4.2e+01]> : vector<4xf32>) : vector<4xf32>
+ %0 = llvm.mlir.constant(sparse<[0], [4.2e+01]> : vector<4xf32>) : vector<4xf32>
llvm.return %0 : vector<4xf32>
}
More information about the Mlir-commits
mailing list