[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.extract_space` operation. (PR #88554)
llvmlistbot at llvm.org
Fri Apr 12 11:42:15 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
A `sparse_tensor.extract_space %tensor at %iterator` operation extracts a *sparse* iteration space defined by `%tensor`. The operation that traverses the iteration space will be introduced in follow-up PRs.
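For context, here is a minimal sketch of how the new op reads, mirroring the roundtrip test in the diff below (the function name `@extract_spaces` is illustrative; the op that would consume and traverse these spaces is not part of this PR):

```mlir
#COO = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : compressed(nonunique),
    j : singleton(soa)
  )
}>

func.func @extract_spaces(%sp : tensor<4x8xf32, #COO>,
                          %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
    -> (!sparse_tensor.iter_space<#COO, lvls = 0>,
        !sparse_tensor.iter_space<#COO, lvls = 1>) {
  // Level 0 is the root level: no parent iterator is needed (or allowed).
  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0
      : tensor<4x8xf32, #COO>
  // Level 1 is extracted at the position of a parent iterator over level 0.
  %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
      : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
  return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>,
                    !sparse_tensor.iter_space<#COO, lvls = 1>
}
```

Per the verifier added in this PR, a parent iterator must be present exactly when the lower level bound is nonzero, its encoding must match the tensor's, and it must end at the level where the extracted space begins.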
---
Full diff: https://github.com/llvm/llvm-project/pull/88554.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+51)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td (+95)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+102)
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+82)
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+25)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0cfc64f9988a0a..b8e4edc8c8537b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1430,6 +1430,57 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+ [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+ let arguments = (ins AnySparseTensor:$tensor,
+ Optional<AnySparseIterator>:$parentIter,
+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+ let results = (outs AnySparseIterSpace:$resultSpace);
+
+ let summary = "Extract an iteration space from a sparse tensor between certain levels";
+ let description = [{
+ Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
+ certain (consecutive) levels.
+
+ `tensor`: the input sparse tensor that defines the iteration space.
+ `parentIter`: the iterator for the previous level, at whose current position the
+ iteration space for the requested levels will be extracted.
+ `loLvl`, `hiLvl`: the level range [loLvl, hiLvl) in the input tensor that the
+ returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
+ iteration space.
+
+ Example:
+ ```mlir
+ // Extracts a 1-D iteration space from a COO tensor at level 1.
+ %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ ```
+ }];
+
+
+ let extraClassDeclaration = [{
+ std::pair<Level, Level> getLvlRange() {
+ return std::make_pair(getLoLvl(), getHiLvl());
+ }
+ unsigned getSpaceDim() {
+ return getHiLvl() - getLoLvl();
+ }
+ ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+ return getResultSpace().getType().getLvlTypes();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..264a0a5b3bee6c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
"::mlir::sparse_tensor::StorageSpecifierType">;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+ let mnemonic = "iter_space";
+
+ let description = [{
+ A sparse iteration space that represents an abstract N-D (sparse) iteration space
+ extracted from a sparse tensor.
+
+ Examples:
+
+ ```mlir
+ // An iteration space extracted from a CSR tensor between levels [0, 2).
+ !iter_space<#CSR, lvls = 0 to 2>
+ ```
+ }];
+
+ let parameters = (ins
+ SparseTensorEncodingAttr : $encoding,
+ "Level" : $loLvl,
+ "Level" : $hiLvl
+ );
+
+ let extraClassDeclaration = [{
+ /// The dimension of the iteration space.
+ unsigned getSpaceDim() const {
+ return getHiLvl() - getLoLvl();
+ }
+
+ /// Get the level types for the iteration space.
+ ArrayRef<LevelType> getLvlTypes() const {
+ return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
+ }
+
+ /// Whether the iteration space is unique (i.e., no duplicated coordinate).
+ bool isUnique() {
+ return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
+ }
+
+ /// Get the corresponding iterator type.
+ ::mlir::sparse_tensor::IteratorType getIteratorType() const;
+ }];
+
+ let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+ let mnemonic = "iterator";
+
+ let description = [{
+ An iterator that points to the current element in the corresponding iteration space.
+
+ Examples:
+
+ ```mlir
+ // An iterator that iterates over an iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
+ !iterator<#CSR, lvls = 0 to 2>
+ ```
+ }];
+
+ let parameters = (ins
+ SparseTensorEncodingAttr : $encoding,
+ "Level" : $loLvl,
+ "Level" : $hiLvl
+ );
+
+ let extraClassDeclaration = [{
+ /// Get the corresponding iteration space type.
+ ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
+
+ unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
+ ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
+ bool isUnique() { return getIterSpaceType().isUnique(); }
+ }];
+
+ let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+ : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+ "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+ : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+ "::mlir::sparse_tensor::IteratorType">;
+
+
#endif // SPARSETENSOR_TYPES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e9058394d33da5..a9d2d2b8826f37 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
+// Forward declarations; the following custom print/parsing methods are
+// referenced by the generated code for SparseTensorTypes.td.
+static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
+ mlir::sparse_tensor::Level &,
+ mlir::sparse_tensor::Level &);
+static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
+ mlir::sparse_tensor::Level);
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
@@ -1953,6 +1961,100 @@ LogicalResult SortOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+IterSpaceType IteratorType::getIterSpaceType() const {
+ return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
+ getHiLvl());
+}
+
+IteratorType IterSpaceType::getIteratorType() const {
+ return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
+}
+
+static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
+ Level &lvlHi) {
+ if (parser.parseInteger(lvlLo))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("to"))) {
+ if (parser.parseInteger(lvlHi))
+ return failure();
+ } else {
+ lvlHi = lvlLo + 1;
+ }
+
+ if (lvlHi <= lvlLo)
+ parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
+
+ return success();
+}
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+ IntegerAttr &lvlHiAttr) {
+ Level lvlLo, lvlHi;
+ if (parseLevelRange(parser, lvlLo, lvlHi))
+ return failure();
+
+ lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+ lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+ return success();
+}
+
+static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
+
+ if (lo + 1 == hi)
+ p << lo;
+ else
+ p << lo << " to " << hi;
+}
+
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+ IntegerAttr lvlHi) {
+ unsigned lo = lvlLo.getValue().getZExtValue();
+ unsigned hi = lvlHi.getValue().getZExtValue();
+ printLevelRange(p, lo, hi);
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+ MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+ DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+
+ ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+ SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+ ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
+ adaptor.getHiLvl()));
+ return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+ if (getLoLvl() >= getHiLvl())
+ return emitOpError("expected smaller level low than level high");
+
+ TypedValue<IteratorType> pIter = getParentIter();
+ if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+ return emitOpError(
+ "parent iterator should be specified iff level lower bound equals 0");
+ }
+
+ if (pIter) {
+ IterSpaceType spaceTp = getResultSpace().getType();
+ if (pIter.getType().getEncoding() != spaceTp.getEncoding())
+ return emitOpError(
+ "mismatch in parent iterator encoding and iteration space encoding.");
+
+ if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
+ return emitOpError("parent iterator should be used to extract an "
+ "iteration space from a consecutive level.");
+ }
+
+ return success();
+}
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7f5c05190fc9a2..579625e22f067f 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
sparse_tensor.print %arg0 : tensor<10x10xf64>
return
}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+ %l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : dense,
+ j : compressed
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 12f69c1d37b9cd..6c0887ef4d826d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -738,3 +738,28 @@ func.func @sparse_has_runtime() -> i1 {
%has_runtime = sparse_tensor.has_runtime_library
return %has_runtime : i1
}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_extract_iter_space(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
+// CHECK: %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
+// CHECK: return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
+// CHECK: }
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
+ -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
+ // Extracting the iteration space for the first level needs no parent iterator.
+ %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ // Extracting the iteration space for the second level needs a parent iterator.
+ %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/88554