[Mlir-commits] [mlir] [mlir][sparse] Improve sparse tensor type constraints (PR #112133)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Oct 13 05:57:13 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Sparse tensors are always ranked tensors. Encodings cannot be attached to unranked tensors. Change the type constraint to `RankedTensorOf`, so that we generate `TypedValue<RankedTensorType>` instead of `TypedValue<TensorType>`. This removes the need for type casting in some cases.
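Also fix the verifiers: several of them called `emitError` without `return`ing its result, so verification continued and could still report success. In addition, switch a few other `AnyTensor` constraints to `AnyRankedTensor`.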
---
Full diff: https://github.com/llvm/llvm-project/pull/112133.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-2)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+11-11)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+48-43)
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index cb6c1b63e4e4b0..adcf6fac752fe6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -586,10 +586,10 @@ def IsSparseTensorSlicePred
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
class SparseTensorOf<list<Type> allowedTypes>
- : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
+ : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
class SparseTensorSliceOf<list<Type> allowedTypes>
- : TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
+ : RankedTensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
class ScalarLikeOf<list<Type> allowedTypes>
: AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 96a61419a541f7..2c281c9f6aa85d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -92,8 +92,8 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]> {
```
}];
- let arguments = (ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
- TensorOf<[AnyType]>:$values);
+ let arguments = (ins Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+ RankedTensorOf<[AnyType]>:$values);
let results = (outs AnySparseTensor: $result);
let assemblyFormat =
"` ` `(` $levels `)` `,` $values attr-dict `:`"
@@ -138,12 +138,12 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
}];
let arguments = (ins AnySparseTensor:$tensor,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
- TensorOf<[AnyType]>:$out_values);
- let results = (outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- TensorOf<[AnyType]>:$ret_values,
- Variadic<AnyIndexingScalarLike>:$lvl_lens,
- AnyIndexingScalarLike:$val_len);
+ Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+ RankedTensorOf<[AnyType]>:$out_values);
+ let results = (outs Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+ RankedTensorOf<[AnyType]>:$ret_values,
+ Variadic<AnyIndexingScalarLike>:$lvl_lens,
+ AnyIndexingScalarLike:$val_len);
let assemblyFormat =
"$tensor attr-dict `:` type($tensor)"
"`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
@@ -196,8 +196,8 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}];
- let arguments = (ins AnyTensor:$source);
- let results = (outs AnyTensor:$dest);
+ let arguments = (ins AnyRankedTensor:$source);
+ let results = (outs AnyRankedTensor:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let extraClassDeclaration = [{
@@ -1447,7 +1447,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
];
let regions = (region SizedRegion<1>:$region);
- let arguments = (ins AnyTensor:$tensor,
+ let arguments = (ins AnyRankedTensor:$tensor,
Variadic<AnyType>:$initArgs,
OptionalAttr<AffineMapAttr>:$order);
let results = (outs Variadic<AnyType>:$results);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b21bc1a93036c4..7b1b1f383e6343 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
// The coordinates should be in shape of <? x rank>
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
- op->emitError("input/output trailing COO level-ranks don't match");
+ return op->emitError("input/output trailing COO level-ranks don't match");
}
}
@@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
}
LogicalResult AssembleOp::verify() {
- const auto valuesTp = getRankedTensorType(getValues());
+ RankedTensorType valuesTp = getValues().getType();
const auto lvlsTp = getLevels().getTypes();
const auto resTp = getSparseTensorType(getResult());
return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
if (ot.getType() != rt.getType())
return emitError("output levels and return levels type mismatch");
- const auto valuesTp = getRankedTensorType(getRetValues());
+ RankedTensorType valuesTp = getRetValues().getType();
const auto lvlsTp = getRetLevels().getTypes();
const auto srcTp = getSparseTensorType(getTensor());
return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}
LogicalResult ConvertOp::verify() {
- if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
- if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
- if (tp1.getRank() != tp2.getRank())
- return emitError("unexpected conversion mismatch in rank");
- auto dstEnc =
- llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
- if (dstEnc && dstEnc.isSlice())
- return emitError("cannot convert to a sparse tensor slice");
-
- auto shape1 = tp1.getShape();
- auto shape2 = tp2.getShape();
- // Accept size matches between the source and the destination type
- // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
- // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
- for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
- if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
- return emitError("unexpected conversion mismatch in dimension ") << d;
- return success();
- }
- }
- return emitError("unexpected type in convert");
+ RankedTensorType tp1 = getSource().getType();
+ RankedTensorType tp2 = getDest().getType();
+ if (tp1.getRank() != tp2.getRank())
+ return emitError("unexpected conversion mismatch in rank");
+ auto dstEnc =
+ llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
+ if (dstEnc && dstEnc.isSlice())
+ return emitError("cannot convert to a sparse tensor slice");
+
+ auto shape1 = tp1.getShape();
+ auto shape2 = tp2.getShape();
+ // Accept size matches between the source and the destination type
+ // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+ // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
+ for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
+ if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+ return emitError("unexpected conversion mismatch in dimension ") << d;
+ return success();
}
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
auto stt = getSparseTensorType(getSource());
if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
- emitError("Level index exceeds the rank of the input sparse tensor");
+ return emitError(
+ "Level index exceeds the rank of the input sparse tensor");
}
return success();
}
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
}
LogicalResult ToSliceOffsetOp::verify() {
- auto rank = getRankedTensorType(getSlice()).getRank();
+ auto rank = getSlice().getType().getRank();
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
return emitError("requested dimension out of bound");
return success();
}
LogicalResult ToSliceStrideOp::verify() {
- auto rank = getRankedTensorType(getSlice()).getRank();
+ auto rank = getSlice().getType().getRank();
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
return emitError("requested dimension out of bound");
return success();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
const auto iTp = IndexType::get(getContext());
for (Dimension d = 0; d < dimRank; d++)
if (args[d].getType() != iTp)
- emitError(
+ return emitError(
llvm::formatv("Expecting Index type for argument at index {0}", d));
const auto elemTp = t.getElementType();
const auto valueTp = args[dimRank].getType();
if (elemTp != valueTp)
- emitError(llvm::formatv("Unmatched element type between input tensor and "
- "block argument, expected:{0}, got: {1}",
- elemTp, valueTp));
+ return emitError(
+ llvm::formatv("Unmatched element type between input tensor and "
+ "block argument, expected:{0}, got: {1}",
+ elemTp, valueTp));
return success();
}
@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
SparseTensorType dstStt = getSparseTensorType(getResultCoo());
if (!srcStt.isCOOType() || !dstStt.isCOOType())
- emitError("Expected COO sparse tensors only");
+ return emitError("Expected COO sparse tensors only");
if (!srcStt.hasSameDimToLvl(dstStt))
- emitError("Unmatched dim2lvl map between input and result COO");
+ return emitError("Unmatched dim2lvl map between input and result COO");
if (srcStt.getPosType() != dstStt.getPosType() ||
srcStt.getCrdType() != dstStt.getCrdType() ||
srcStt.getElementType() != dstStt.getElementType())
- emitError("Unmatched storage format between input and result COO");
+ return emitError("Unmatched storage format between input and result COO");
return success();
}
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
AffineMap xPerm = getPermMap();
uint64_t nx = xPerm.getNumDims();
if (nx < 1)
- emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+ return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
if (!xPerm.isPermutation())
- emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+ return emitError(
+ llvm::formatv("Expected a permutation map, got {0}", xPerm));
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
@@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() {
return success();
// Verify dimensions.
- const auto checkDim = [&](Value v, Size minSize, const char *message) {
+ const auto checkDim = [&](Value v, Size minSize,
+ const char *message) -> LogicalResult {
const Size sh = getMemRefType(v).getShape()[0];
if (!ShapedType::isDynamic(sh) && sh < minSize)
- emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+ return emitError(
+ llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+ return success();
};
uint64_t n = cn.value();
uint64_t ny = 0;
if (auto nyAttr = getNyAttr())
ny = nyAttr.getInt();
- checkDim(getXy(), n * (nx + ny),
- "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
+ if (failed(checkDim(getXy(), n * (nx + ny),
+ "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
+ return failure();
for (Value opnd : getYs())
- checkDim(opnd, n, "Expected dimension(y) >= n");
+ if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+ return failure();
return success();
}
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
}
if (lvlHi <= lvlLo)
- parser.emitError(parser.getNameLoc(),
- "expect larger level upper bound than lower bound");
+ return parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
return success();
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 737b736ba795fe..908d2d8aa83f7c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -105,7 +105,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
// expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
- %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+ %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
return %0 : memref<?xindex>
}
@@ -141,7 +141,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
// expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
- %0 = sparse_tensor.coordinates %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+ %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
return %0 : memref<?xindex>
}
@@ -347,7 +347,7 @@ func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
// -----
func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
- // expected-error@+1 {{unexpected type in convert}}
+ // expected-error@+1 {{invalid kind of type specified}}
%0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
return %0 : tensor<10xf32>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/112133
More information about the Mlir-commits
mailing list