[Mlir-commits] [mlir] 3cf03f1 - [mlir][sparse] Adding IsSparseTensorPred and updating ops to use it
wren romano
llvmlistbot at llvm.org
Fri Jun 3 17:15:38 PDT 2022
Author: wren romano
Date: 2022-06-03T17:15:31-07:00
New Revision: 3cf03f1c562f25e76309ab904ed830dfaebf74dc
URL: https://github.com/llvm/llvm-project/commit/3cf03f1c562f25e76309ab904ed830dfaebf74dc
DIFF: https://github.com/llvm/llvm-project/commit/3cf03f1c562f25e76309ab904ed830dfaebf74dc.diff
LOG: [mlir][sparse] Adding IsSparseTensorPred and updating ops to use it
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D126994
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 2db552f5fe039..c0c3d4920a1fe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -93,4 +93,28 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
}];
}
+def IsSparseTensorPred
+ : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
+
+// The following four follow the same idiom as `TensorOf`, `AnyTensor`,
+// `RankedTensorOf`, `AnyRankedTensor`.
+
+class SparseTensorOf<list<Type> allowedTypes>
+ : ShapedContainerType<
+ allowedTypes,
+ And<[IsTensorTypePred, IsSparseTensorPred]>,
+ "sparse tensor",
+ "::mlir::TensorType">;
+
+def AnySparseTensor : SparseTensorOf<[AnyType]>;
+
+class RankedSparseTensorOf<list<Type> allowedTypes>
+ : ShapedContainerType<
+ allowedTypes,
+ And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
+ "ranked sparse tensor",
+ "::mlir::TensorType">;
+
+def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
+
#endif // SPARSETENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4f31031b1fe84..bdc27b57fe10f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -27,7 +27,7 @@ class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
Arguments<(ins AnyType:$source)>,
- Results<(outs TensorOf<[AnyType]>:$result)> {
+ Results<(outs AnySparseTensor:$result)> {
string summary = "Materializes a new sparse tensor from given source";
string description = [{
Materializes a sparse tensor with contents taken from an opaque pointer
@@ -46,7 +46,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
```
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
- let hasVerifier = 1;
}
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -92,7 +91,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}
def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
- Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+ Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extracts pointers array at given dimension from a tensor";
let description = [{
@@ -117,7 +116,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
}
def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
- Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+ Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extracts indices array at given dimension from a tensor";
let description = [{
@@ -142,7 +141,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
}
def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
- Arguments<(ins AnyTensor:$tensor)>,
+ Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extracts numerical values array from a tensor";
let description = [{
@@ -173,7 +172,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
//===----------------------------------------------------------------------===//
def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
- Arguments<(ins AnyTensor:$tensor,
+ Arguments<(ins AnySparseTensor:$tensor,
StridedMemRefRankOf<[Index], [1]>:$indices,
AnyType:$value)> {
string summary = "Inserts a value into given sparse tensor in lexicographical index order";
@@ -196,11 +195,10 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
}];
let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
" type($tensor) `,` type($indices) `,` type($value)";
- let hasVerifier = 1;
}
def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
- Arguments<(ins AnyTensor:$tensor)>,
+ Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
StridedMemRefRankOf<[Index],[1]>:$added,
@@ -238,11 +236,10 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
" `,` type($filled) `,` type($added) `,` type($count)";
- let hasVerifier = 1;
}
def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
- Arguments<(ins AnyTensor:$tensor,
+ Arguments<(ins AnySparseTensor:$tensor,
StridedMemRefRankOf<[Index],[1]>:$indices,
AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
@@ -273,11 +270,10 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
" $added `,` $count attr-dict `:` type($tensor) `,`"
" type($indices) `,` type($values) `,` type($filled) `,`"
" type($added) `,` type($count)";
- let hasVerifier = 1;
}
def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
- Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
+ Arguments<(ins AnySparseTensor:$tensor, UnitAttr:$hasInserts)>,
Results<(outs AnyTensor:$result)> {
let summary =
"Rematerializes tensor from underlying sparse storage format";
@@ -306,11 +302,10 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
```
}];
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
- let hasVerifier = 1;
}
def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
- Arguments<(ins AnyTensor:$tensor)> {
+ Arguments<(ins AnySparseTensor:$tensor)> {
string summary = "Releases underlying sparse storage format of given tensor";
string description = [{
Releases the underlying sparse storage format for a tensor that
@@ -332,11 +327,10 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
- let hasVerifier = 1;
}
def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
- Arguments<(ins AnyType:$tensor, AnyType:$dest)> {
+ Arguments<(ins AnySparseTensor:$tensor, AnyType:$dest)> {
string summary = "Outputs a sparse tensor to the given destination";
string description = [{
Outputs the contents of a sparse tensor to the destination defined by an
@@ -353,7 +347,6 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
```
}];
let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b860f07528956..418e7fe3bd822 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -208,12 +208,6 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
return failure();
}
-LogicalResult NewOp::verify() {
- if (!getSparseTensorEncoding(result().getType()))
- return emitError("expected a sparse tensor result");
- return success();
-}
-
LogicalResult ConvertOp::verify() {
if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
@@ -240,30 +234,24 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
}
LogicalResult ToPointersOp::verify() {
- if (auto e = getSparseTensorEncoding(tensor().getType())) {
- if (failed(isInBounds(dim(), tensor())))
- return emitError("requested pointers dimension out of bounds");
- if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
- return emitError("unexpected type for pointers");
- return success();
- }
- return emitError("expected a sparse tensor to get pointers");
+ auto e = getSparseTensorEncoding(tensor().getType());
+ if (failed(isInBounds(dim(), tensor())))
+ return emitError("requested pointers dimension out of bounds");
+ if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
+ return emitError("unexpected type for pointers");
+ return success();
}
LogicalResult ToIndicesOp::verify() {
- if (auto e = getSparseTensorEncoding(tensor().getType())) {
- if (failed(isInBounds(dim(), tensor())))
- return emitError("requested indices dimension out of bounds");
- if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
- return emitError("unexpected type for indices");
- return success();
- }
- return emitError("expected a sparse tensor to get indices");
+ auto e = getSparseTensorEncoding(tensor().getType());
+ if (failed(isInBounds(dim(), tensor())))
+ return emitError("requested indices dimension out of bounds");
+ if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
+ return emitError("unexpected type for indices");
+ return success();
}
LogicalResult ToValuesOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor to get values");
RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
MemRefType mtp = result().getType().cast<MemRefType>();
if (ttp.getElementType() != mtp.getElementType())
@@ -271,46 +259,6 @@ LogicalResult ToValuesOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// TensorDialect Management Operations.
-//===----------------------------------------------------------------------===//
-
-LogicalResult LexInsertOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor for insertion");
- return success();
-}
-
-LogicalResult ExpandOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor for expansion");
- return success();
-}
-
-LogicalResult CompressOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor for compression");
- return success();
-}
-
-LogicalResult LoadOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor to materialize");
- return success();
-}
-
-LogicalResult ReleaseOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor to release");
- return success();
-}
-
-LogicalResult OutOp::verify() {
- if (!getSparseTensorEncoding(tensor().getType()))
- return emitError("expected a sparse tensor for output");
- return success();
-}
-
//===----------------------------------------------------------------------===//
// TensorDialect Linalg.Generic Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 68f3cf3b7c5e6..8df924c0b0404 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
- // expected-error@+1 {{expected a sparse tensor result}}
+ // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32>
return %0 : tensor<32xf32>
}
@@ -9,7 +9,7 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
// -----
func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
- // expected-error@+1 {{expected a sparse tensor to release}}
+ // expected-error@+1 {{'sparse_tensor.release' op operand #0 must be sparse tensor of any type values, but got 'tensor<4xi32>'}}
sparse_tensor.release %arg0 : tensor<4xi32>
return
}
@@ -18,7 +18,7 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
%c = arith.constant 0 : index
- // expected-error@+1 {{expected a sparse tensor to get pointers}}
+ // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -27,7 +27,7 @@ func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
func.func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
%c = arith.constant 0 : index
- // expected-error@+1 {{expected a sparse tensor to get pointers}}
+ // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
%0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -58,7 +58,7 @@ func.func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex
func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
%c = arith.constant 1 : index
- // expected-error@+1 {{expected a sparse tensor to get indices}}
+ // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
%0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -67,7 +67,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
%c = arith.constant 0 : index
- // expected-error@+1 {{expected a sparse tensor to get indices}}
+ // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
%0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -97,7 +97,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
// -----
func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
- // expected-error@+1 {{expected a sparse tensor to get values}}
+ // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
%0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
return %0 : memref<?xf32>
}
@@ -115,7 +115,7 @@ func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<
// -----
func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
- // expected-error@+1 {{expected a sparse tensor to materialize}}
+ // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
%0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
return %0 : tensor<16x32xf64>
}
@@ -123,7 +123,7 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
// -----
func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
- // expected-error@+1 {{expected a sparse tensor for insertion}}
+ // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
return
}
@@ -131,7 +131,7 @@ func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xind
// -----
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
- // expected-error@+1 {{expected a sparse tensor for expansion}}
+ // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%values, %filled, %added, %count = sparse_tensor.expand %arg0
: tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
return
@@ -142,7 +142,7 @@ func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
%arg2: memref<?xf64>, %arg3: memref<?xi1>,
%arg4: memref<?xindex>, %arg5: index) {
- // expected-error@+1 {{expected a sparse tensor for compression}}
+ // expected-error@+1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
: tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
}
@@ -178,7 +178,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
// -----
func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
- // expected-error@+1 {{expected a sparse tensor for output}}
+ // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
return
}
More information about the Mlir-commits mailing list