[Mlir-commits] [mlir] [mlir][sparse] Improve sparse tensor type constraints (PR #112133)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 13 05:57:13 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Sparse tensors are always ranked tensors: encodings cannot be attached to unranked tensors. Change the `SparseTensorOf` and `SparseTensorSliceOf` type constraints to derive from `RankedTensorOf`, so that we generate `TypedValue<RankedTensorType>` instead of `TypedValue<TensorType>`. This removes the need for type casting in some cases.

Also fix verifiers that emitted an error but did not return failure (missing `return` statements), and switch a few other `AnyTensor`/`TensorOf` constraints to `AnyRankedTensor`/`RankedTensorOf`.


---
Full diff: https://github.com/llvm/llvm-project/pull/112133.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+11-11) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+48-43) 
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index cb6c1b63e4e4b0..adcf6fac752fe6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -586,10 +586,10 @@ def IsSparseTensorSlicePred
           "  ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
 
 class SparseTensorOf<list<Type> allowedTypes>
-  : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
+  : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
 
 class SparseTensorSliceOf<list<Type> allowedTypes>
-  : TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
+  : RankedTensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
 
 class ScalarLikeOf<list<Type> allowedTypes>
   : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 96a61419a541f7..2c281c9f6aa85d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -92,8 +92,8 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]> {
     ```
   }];
 
-  let arguments = (ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
-                   TensorOf<[AnyType]>:$values);
+  let arguments = (ins Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                       RankedTensorOf<[AnyType]>:$values);
   let results = (outs AnySparseTensor: $result);
   let assemblyFormat =
     "` ` `(` $levels       `)` `,` $values attr-dict `:`"
@@ -138,12 +138,12 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
   }];
 
   let arguments = (ins AnySparseTensor:$tensor,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
-                   TensorOf<[AnyType]>:$out_values);
-  let results = (outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  TensorOf<[AnyType]>:$ret_values,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
-                  AnyIndexingScalarLike:$val_len);
+                       Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                       RankedTensorOf<[AnyType]>:$out_values);
+  let results = (outs Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                      RankedTensorOf<[AnyType]>:$ret_values,
+                      Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                      AnyIndexingScalarLike:$val_len);
   let assemblyFormat =
     "$tensor attr-dict `:` type($tensor)"
     "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
@@ -196,8 +196,8 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
 
   }];
 
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor:$dest);
+  let arguments = (ins AnyRankedTensor:$source);
+  let results = (outs AnyRankedTensor:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 
   let extraClassDeclaration = [{
@@ -1447,7 +1447,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   ];
 
   let regions = (region SizedRegion<1>:$region);
-  let arguments = (ins AnyTensor:$tensor,
+  let arguments = (ins AnyRankedTensor:$tensor,
                        Variadic<AnyType>:$initArgs,
                        OptionalAttr<AffineMapAttr>:$order);
   let results = (outs Variadic<AnyType>:$results);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b21bc1a93036c4..7b1b1f383e6343 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     // The coordinates should be in shape of <? x rank>
     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
-      op->emitError("input/output trailing COO level-ranks don't match");
+      return op->emitError("input/output trailing COO level-ranks don't match");
     }
   }
 
@@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
 }
 
 LogicalResult AssembleOp::verify() {
-  const auto valuesTp = getRankedTensorType(getValues());
+  RankedTensorType valuesTp = getValues().getType();
   const auto lvlsTp = getLevels().getTypes();
   const auto resTp = getSparseTensorType(getResult());
   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
     if (ot.getType() != rt.getType())
       return emitError("output levels and return levels type mismatch");
 
-  const auto valuesTp = getRankedTensorType(getRetValues());
+  RankedTensorType valuesTp = getRetValues().getType();
   const auto lvlsTp = getRetLevels().getTypes();
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
 }
 
 LogicalResult ConvertOp::verify() {
-  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
-    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
-      if (tp1.getRank() != tp2.getRank())
-        return emitError("unexpected conversion mismatch in rank");
-      auto dstEnc =
-          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
-      if (dstEnc && dstEnc.isSlice())
-        return emitError("cannot convert to a sparse tensor slice");
-
-      auto shape1 = tp1.getShape();
-      auto shape2 = tp2.getShape();
-      // Accept size matches between the source and the destination type
-      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
-      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
-      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
-        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
-          return emitError("unexpected conversion mismatch in dimension ") << d;
-      return success();
-    }
-  }
-  return emitError("unexpected type in convert");
+  RankedTensorType tp1 = getSource().getType();
+  RankedTensorType tp2 = getDest().getType();
+  if (tp1.getRank() != tp2.getRank())
+    return emitError("unexpected conversion mismatch in rank");
+  auto dstEnc =
+      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
+  if (dstEnc && dstEnc.isSlice())
+    return emitError("cannot convert to a sparse tensor slice");
+
+  auto shape1 = tp1.getShape();
+  auto shape2 = tp2.getShape();
+  // Accept size matches between the source and the destination type
+  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
+  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
+    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+      return emitError("unexpected conversion mismatch in dimension ") << d;
+  return success();
 }
 
 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
     auto stt = getSparseTensorType(getSource());
     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
-      emitError("Level index exceeds the rank of the input sparse tensor");
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
   }
   return success();
 }
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
 }
 
 LogicalResult ToSliceOffsetOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
 }
 
 LogicalResult ToSliceStrideOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
   const auto iTp = IndexType::get(getContext());
   for (Dimension d = 0; d < dimRank; d++)
     if (args[d].getType() != iTp)
-      emitError(
+      return emitError(
           llvm::formatv("Expecting Index type for argument at index {0}", d));
 
   const auto elemTp = t.getElementType();
   const auto valueTp = args[dimRank].getType();
   if (elemTp != valueTp)
-    emitError(llvm::formatv("Unmatched element type between input tensor and "
-                            "block argument, expected:{0}, got: {1}",
-                            elemTp, valueTp));
+    return emitError(
+        llvm::formatv("Unmatched element type between input tensor and "
+                      "block argument, expected:{0}, got: {1}",
+                      elemTp, valueTp));
   return success();
 }
 
@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
   if (!srcStt.isCOOType() || !dstStt.isCOOType())
-    emitError("Expected COO sparse tensors only");
+    return emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
-    emitError("Unmatched dim2lvl map between input and result COO");
+    return emitError("Unmatched dim2lvl map between input and result COO");
 
   if (srcStt.getPosType() != dstStt.getPosType() ||
       srcStt.getCrdType() != dstStt.getCrdType() ||
       srcStt.getElementType() != dstStt.getElementType())
-    emitError("Unmatched storage format between input and result COO");
+    return emitError("Unmatched storage format between input and result COO");
 
   return success();
 }
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
   AffineMap xPerm = getPermMap();
   uint64_t nx = xPerm.getNumDims();
   if (nx < 1)
-    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+    return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
 
   if (!xPerm.isPermutation())
-    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+    return emitError(
+        llvm::formatv("Expected a permutation map, got {0}", xPerm));
 
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() {
     return success();
 
   // Verify dimensions.
-  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+  const auto checkDim = [&](Value v, Size minSize,
+                            const char *message) -> LogicalResult {
     const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
-      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+      return emitError(
+          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+    return success();
   };
   uint64_t n = cn.value();
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr())
     ny = nyAttr.getInt();
-  checkDim(getXy(), n * (nx + ny),
-           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
+  if (failed(checkDim(getXy(), n * (nx + ny),
+                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
+    return failure();
   for (Value opnd : getYs())
-    checkDim(opnd, n, "Expected dimension(y) >= n");
+    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+      return failure();
 
   return success();
 }
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
   }
 
   if (lvlHi <= lvlLo)
-    parser.emitError(parser.getNameLoc(),
-                     "expect larger level upper bound than lower bound");
+    return parser.emitError(parser.getNameLoc(),
+                            "expect larger level upper bound than lower bound");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 737b736ba795fe..908d2d8aa83f7c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -105,7 +105,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
 
 func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
 
@@ -141,7 +141,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
 
 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.coordinates %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
 
@@ -347,7 +347,7 @@ func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
 // -----
 
 func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
-  // expected-error@+1 {{unexpected type in convert}}
+  // expected-error@+1 {{invalid kind of type specified}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
   return %0 : tensor<10xf32>
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/112133


More information about the Mlir-commits mailing list