[Mlir-commits] [mlir] 3cf03f1 - [mlir][sparse] Adding IsSparseTensorPred and updating ops to use it

wren romano llvmlistbot at llvm.org
Fri Jun 3 17:15:38 PDT 2022


Author: wren romano
Date: 2022-06-03T17:15:31-07:00
New Revision: 3cf03f1c562f25e76309ab904ed830dfaebf74dc

URL: https://github.com/llvm/llvm-project/commit/3cf03f1c562f25e76309ab904ed830dfaebf74dc
DIFF: https://github.com/llvm/llvm-project/commit/3cf03f1c562f25e76309ab904ed830dfaebf74dc.diff

LOG: [mlir][sparse] Adding IsSparseTensorPred and updating ops to use it

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126994

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 2db552f5fe039..c0c3d4920a1fe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -93,4 +93,28 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   }];
 }
 
+def IsSparseTensorPred
+  : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
+
+// The following four follow the same idiom as `TensorOf`, `AnyTensor`,
+// `RankedTensorOf`, `AnyRankedTensor`.
+
+class SparseTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<
+      allowedTypes,
+      And<[IsTensorTypePred, IsSparseTensorPred]>,
+      "sparse tensor",
+      "::mlir::TensorType">;
+
+def AnySparseTensor : SparseTensorOf<[AnyType]>;
+
+class RankedSparseTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<
+      allowedTypes,
+      And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
+      "ranked sparse tensor",
+      "::mlir::TensorType">;
+
+def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
+
 #endif // SPARSETENSOR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4f31031b1fe84..bdc27b57fe10f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -27,7 +27,7 @@ class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
 
 def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     Arguments<(ins AnyType:$source)>,
-    Results<(outs TensorOf<[AnyType]>:$result)> {
+    Results<(outs AnySparseTensor:$result)> {
   string summary = "Materializes a new sparse tensor from given source";
   string description = [{
     Materializes a sparse tensor with contents taken from an opaque pointer
@@ -46,7 +46,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -92,7 +91,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
 }
 
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+    Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extracts pointers array at given dimension from a tensor";
   let description = [{
@@ -117,7 +116,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
 }
 
 def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+    Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extracts indices array at given dimension from a tensor";
   let description = [{
@@ -142,7 +141,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
 }
 
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor)>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extracts numerical values array from a tensor";
   let description = [{
@@ -173,7 +172,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
 //===----------------------------------------------------------------------===//
 
 def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
-    Arguments<(ins AnyTensor:$tensor,
+    Arguments<(ins AnySparseTensor:$tensor,
                StridedMemRefRankOf<[Index], [1]>:$indices,
                AnyType:$value)> {
   string summary = "Inserts a value into given sparse tensor in lexicographical index order";
@@ -196,11 +195,10 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
   }];
   let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
                        " type($tensor) `,` type($indices) `,` type($value)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
-    Arguments<(ins AnyTensor:$tensor)>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$values,
                   StridedMemRefRankOf<[I1],[1]>:$filled,
                   StridedMemRefRankOf<[Index],[1]>:$added,
@@ -238,11 +236,10 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
                        " `,` type($filled) `,` type($added) `,` type($count)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
-    Arguments<(ins AnyTensor:$tensor,
+    Arguments<(ins AnySparseTensor:$tensor,
                    StridedMemRefRankOf<[Index],[1]>:$indices,
                    AnyStridedMemRefOfRank<1>:$values,
                    StridedMemRefRankOf<[I1],[1]>:$filled,
@@ -273,11 +270,10 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
                         " $added `,` $count attr-dict `:` type($tensor) `,`"
 			" type($indices) `,` type($values) `,` type($filled) `,`"
 			" type($added) `,` type($count)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
-    Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
+    Arguments<(ins AnySparseTensor:$tensor, UnitAttr:$hasInserts)>,
     Results<(outs AnyTensor:$result)> {
   let summary =
     "Rematerializes tensor from underlying sparse storage format";
@@ -306,11 +302,10 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
     ```
   }];
   let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
-    Arguments<(ins AnyTensor:$tensor)> {
+    Arguments<(ins AnySparseTensor:$tensor)> {
   string summary = "Releases underlying sparse storage format of given tensor";
   string description = [{
     Releases the underlying sparse storage format for a tensor that
@@ -332,11 +327,10 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
-    Arguments<(ins AnyType:$tensor, AnyType:$dest)> {
+    Arguments<(ins AnySparseTensor:$tensor, AnyType:$dest)> {
   string summary = "Outputs a sparse tensor to the given destination";
   string description = [{
     Outputs the contents of a sparse tensor to the destination defined by an
@@ -353,7 +347,6 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
     ```
   }];
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b860f07528956..418e7fe3bd822 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -208,12 +208,6 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
   return failure();
 }
 
-LogicalResult NewOp::verify() {
-  if (!getSparseTensorEncoding(result().getType()))
-    return emitError("expected a sparse tensor result");
-  return success();
-}
-
 LogicalResult ConvertOp::verify() {
   if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
     if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
@@ -240,30 +234,24 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
 }
 
 LogicalResult ToPointersOp::verify() {
-  if (auto e = getSparseTensorEncoding(tensor().getType())) {
-    if (failed(isInBounds(dim(), tensor())))
-      return emitError("requested pointers dimension out of bounds");
-    if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
-      return emitError("unexpected type for pointers");
-    return success();
-  }
-  return emitError("expected a sparse tensor to get pointers");
+  auto e = getSparseTensorEncoding(tensor().getType());
+  if (failed(isInBounds(dim(), tensor())))
+    return emitError("requested pointers dimension out of bounds");
+  if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
+    return emitError("unexpected type for pointers");
+  return success();
 }
 
 LogicalResult ToIndicesOp::verify() {
-  if (auto e = getSparseTensorEncoding(tensor().getType())) {
-    if (failed(isInBounds(dim(), tensor())))
-      return emitError("requested indices dimension out of bounds");
-    if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
-      return emitError("unexpected type for indices");
-    return success();
-  }
-  return emitError("expected a sparse tensor to get indices");
+  auto e = getSparseTensorEncoding(tensor().getType());
+  if (failed(isInBounds(dim(), tensor())))
+    return emitError("requested indices dimension out of bounds");
+  if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
+    return emitError("unexpected type for indices");
+  return success();
 }
 
 LogicalResult ToValuesOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to get values");
   RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
   MemRefType mtp = result().getType().cast<MemRefType>();
   if (ttp.getElementType() != mtp.getElementType())
@@ -271,46 +259,6 @@ LogicalResult ToValuesOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TensorDialect Management Operations.
-//===----------------------------------------------------------------------===//
-
-LogicalResult LexInsertOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for insertion");
-  return success();
-}
-
-LogicalResult ExpandOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for expansion");
-  return success();
-}
-
-LogicalResult CompressOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for compression");
-  return success();
-}
-
-LogicalResult LoadOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to materialize");
-  return success();
-}
-
-LogicalResult ReleaseOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to release");
-  return success();
-}
-
-LogicalResult OutOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for output");
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // TensorDialect Linalg.Generic Operations.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 68f3cf3b7c5e6..8df924c0b0404 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
-  // expected-error@+1 {{expected a sparse tensor result}}
+  // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32>
   return %0 : tensor<32xf32>
 }
@@ -9,7 +9,7 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
 // -----
 
 func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
-  // expected-error@+1 {{expected a sparse tensor to release}}
+  // expected-error@+1 {{'sparse_tensor.release' op operand #0 must be sparse tensor of any type values, but got 'tensor<4xi32>'}}
   sparse_tensor.release %arg0 : tensor<4xi32>
   return
 }
@@ -18,7 +18,7 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
 
 func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get pointers}}
+  // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -27,7 +27,7 @@ func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
 
 func.func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get pointers}}
+  // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
   %0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -58,7 +58,7 @@ func.func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex
 
 func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
   %c = arith.constant 1 : index
-  // expected-error@+1 {{expected a sparse tensor to get indices}}
+  // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
   %0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -67,7 +67,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
 
 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get indices}}
+  // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
   %0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -97,7 +97,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
 // -----
 
 func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
-  // expected-error@+1 {{expected a sparse tensor to get values}}
+  // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
   %0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
   return %0 : memref<?xf32>
 }
@@ -115,7 +115,7 @@ func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<
 // -----
 
 func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
-  // expected-error@+1 {{expected a sparse tensor to materialize}}
+  // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
   %0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
@@ -123,7 +123,7 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
 // -----
 
 func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
-  // expected-error@+1 {{expected a sparse tensor for insertion}}
+  // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
   return
 }
@@ -131,7 +131,7 @@ func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xind
 // -----
 
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
-  // expected-error@+1 {{expected a sparse tensor for expansion}}
+  // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0
     : tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
@@ -142,7 +142,7 @@ func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
 func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
                                      %arg2: memref<?xf64>, %arg3: memref<?xi1>,
 				     %arg4: memref<?xindex>, %arg5: index) {
-  // expected-error@+1 {{expected a sparse tensor for compression}}
+  // expected-error@+1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
     : tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
 }
@@ -178,7 +178,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
 // -----
 
 func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
-  // expected-error@+1 {{expected a sparse tensor for output}}
+  // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
   sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
   return
 }


        


More information about the Mlir-commits mailing list