[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.extract_iteration_space` operation. (PR #88554)
Peiming Liu
llvmlistbot at llvm.org
Tue Apr 16 10:30:34 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88554
>From 43e96384a5b5a51a9dc8704bb1b003ae35f16d17 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 12 Apr 2024 17:20:40 +0000
Subject: [PATCH 1/3] [mlir][sparse] introduce `sparse_tensor.extract_space`
operation that extracts a sparse iteration space to iterate over.
---
.../SparseTensor/IR/SparseTensorOps.td | 51 +++++++++
.../SparseTensor/IR/SparseTensorTypes.td | 95 ++++++++++++++++
.../SparseTensor/IR/SparseTensorDialect.cpp | 102 ++++++++++++++++++
mlir/test/Dialect/SparseTensor/invalid.mlir | 82 ++++++++++++++
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 25 +++++
5 files changed, 355 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0cfc64f9988a0a..b8e4edc8c8537b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1430,6 +1430,57 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+ [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+ let arguments = (ins AnySparseTensor:$tensor,
+ Optional<AnySparseIterator>:$parentIter,
+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+ let results = (outs AnySparseIterSpace:$resultSpace);
+
+ let summary = "Extract an iteration space from a sparse tensor between certain levels";
+ let description = [{
+ Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
+ certian (consecutive) levels.
+
+ `tensor`: the input sparse tensor that defines the iteration space.
+ `parentIter`: the iterator for the previous level, at which the iteration space
+ at the current levels will be extracted.
+ `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
+ the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
+ iteration space.
+
+ Example:
+ ```mlir
+ // Extracts a 1-D iteration space from a COO tensor at level 1.
+ %space = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ ```
+ }];
+
+
+ let extraClassDeclaration = [{
+ std::pair<Level, Level> getLvlRange() {
+ return std::make_pair(getLoLvl(), getHiLvl());
+ }
+ unsigned getSpaceDim() {
+ return getHiLvl() - getLoLvl();
+ }
+ ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+ return getResultSpace().getType().getLvlTypes();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..264a0a5b3bee6c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
"::mlir::sparse_tensor::StorageSpecifierType">;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+ let mnemonic = "iter_space";
+
+ let description = [{
+ A sparse iteration space that represents an abstract N-D (sparse) iteration space
+ extracted from a sparse tensor.
+
+ Examples:
+
+ ```mlir
+ // An iteration space extracted from a CSR tensor between levels [0, 2).
+ !iter_space<#CSR, lvls = 0 to 2>
+ ```
+ }];
+
+ let parameters = (ins
+ SparseTensorEncodingAttr : $encoding,
+ "Level" : $loLvl,
+ "Level" : $hiLvl
+ );
+
+ let extraClassDeclaration = [{
+ /// The dimension of the iteration space.
+ unsigned getSpaceDim() const {
+ return getHiLvl() - getLoLvl();
+ }
+
+ /// Get the level types for the iteration space.
+ ArrayRef<LevelType> getLvlTypes() const {
+ return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
+ }
+
+ /// Whether the iteration space is unique (i.e., has no duplicate coordinates).
+ bool isUnique() {
+ return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
+ }
+
+ /// Get the corresponding iterator type.
+ ::mlir::sparse_tensor::IteratorType getIteratorType() const;
+ }];
+
+ let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+ let mnemonic = "iterator";
+
+ let description = [{
+ An iterator that points to the current element in the corresponding iteration space.
+
+ Examples:
+
+ ```mlir
+ // An iterator that iterates over an iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
+ !iterator<#CSR, lvls = 0 to 2>
+ ```
+ }];
+
+ let parameters = (ins
+ SparseTensorEncodingAttr : $encoding,
+ "Level" : $loLvl,
+ "Level" : $hiLvl
+ );
+
+ let extraClassDeclaration = [{
+ /// Get the corresponding iteration space type.
+ ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
+
+ unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
+ ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
+ bool isUnique() { return getIterSpaceType().isUnique(); }
+ }];
+
+ let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+ : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+ "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+ : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+ "::mlir::sparse_tensor::IteratorType">;
+
+
#endif // SPARSETENSOR_TYPES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e9058394d33da5..a9d2d2b8826f37 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
+// Forward declarations: the following custom print/parse methods are referenced
+// by the generated code for SparseTensorTypes.td.
+static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
+ mlir::sparse_tensor::Level &,
+ mlir::sparse_tensor::Level &);
+static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
+ mlir::sparse_tensor::Level);
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
@@ -1953,6 +1961,100 @@ LogicalResult SortOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+IterSpaceType IteratorType::getIterSpaceType() const {
+ return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
+ getHiLvl());
+}
+
+IteratorType IterSpaceType::getIteratorType() const {
+ return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
+}
+
+static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
+ Level &lvlHi) {
+ if (parser.parseInteger(lvlLo))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("to"))) {
+ if (parser.parseInteger(lvlHi))
+ return failure();
+ } else {
+ lvlHi = lvlLo + 1;
+ }
+
+ if (lvlHi <= lvlLo)
+ parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
+
+ return success();
+}
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+ IntegerAttr &lvlHiAttr) {
+ Level lvlLo, lvlHi;
+ if (parseLevelRange(parser, lvlLo, lvlHi))
+ return failure();
+
+ lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+ lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+ return success();
+}
+
+static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
+
+ if (lo + 1 == hi)
+ p << lo;
+ else
+ p << lo << " to " << hi;
+}
+
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+ IntegerAttr lvlHi) {
+ unsigned lo = lvlLo.getValue().getZExtValue();
+ unsigned hi = lvlHi.getValue().getZExtValue();
+ printLevelRange(p, lo, hi);
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+ MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+ DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+
+ ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+ SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+ ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
+ adaptor.getHiLvl()));
+ return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+ if (getLoLvl() >= getHiLvl())
+ return emitOpError("expected smaller level low than level high");
+
+ TypedValue<IteratorType> pIter = getParentIter();
+ if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+ return emitOpError(
+ "parent iterator should be specified iff level lower bound equals 0");
+ }
+
+ if (pIter) {
+ IterSpaceType spaceTp = getResultSpace().getType();
+ if (pIter.getType().getEncoding() != spaceTp.getEncoding())
+ return emitOpError(
+ "mismatch in parent iterator encoding and iteration space encoding.");
+
+ if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
+ return emitOpError("parent iterator should be used to extract an "
+ "iteration space from a consecutive level.");
+ }
+
+ return success();
+}
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7f5c05190fc9a2..579625e22f067f 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
sparse_tensor.print %arg0 : tensor<10x10xf64>
return
}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+ %l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : dense,
+ j : compressed
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+ return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+ // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
+ %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 12f69c1d37b9cd..6c0887ef4d826d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -738,3 +738,28 @@ func.func @sparse_has_runtime() -> i1 {
%has_runtime = sparse_tensor.has_runtime_library
return %has_runtime : i1
}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_extract_iter_space(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
+// CHECK: %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
+// CHECK: return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
+// CHECK: }
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
+ -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
+ // Extracting the iteration space for the first level needs no parent iterator.
+ %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ // Extracting the iteration space for the second level needs a parent iterator.
+ %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
+}
>From 8f5f89145636b48ba4b62cbb2a007f6bcfe5d227 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 15 Apr 2024 16:58:06 +0000
Subject: [PATCH 2/3] update operation name.
---
.../SparseTensor/IR/SparseTensorOps.td | 2 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 8 ++++++++
mlir/test/Dialect/SparseTensor/invalid.mlir | 20 +++++++++----------
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 8 ++++----
4 files changed, 23 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b8e4edc8c8537b..38722990b54450 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1434,7 +1434,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//
-def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins AnySparseTensor:$tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a9d2d2b8826f37..516b0943bdcfac 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1974,6 +1974,8 @@ IteratorType IterSpaceType::getIteratorType() const {
return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
Level &lvlHi) {
if (parser.parseInteger(lvlLo))
@@ -1993,6 +1995,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
return success();
}
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
IntegerAttr &lvlHiAttr) {
Level lvlLo, lvlHi;
@@ -2004,6 +2008,8 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
return success();
}
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
if (lo + 1 == hi)
@@ -2012,6 +2018,8 @@ static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
p << lo << " to " << hi;
}
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
IntegerAttr lvlHi) {
unsigned lo = lvlLo.getValue().getZExtValue();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 579625e22f067f..3fa696e1600a93 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1023,8 +1023,8 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
}>
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
- // expected-error at +1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
- %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+ // expected-error at +1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
+ %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
return
}
@@ -1038,8 +1038,8 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
}>
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
- // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
- %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
+ %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return
}
@@ -1053,8 +1053,8 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
}>
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
- // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
- %l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+ // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
return
}
@@ -1075,8 +1075,8 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
}>
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
- // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
- %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+ // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
+ %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
return
}
@@ -1090,7 +1090,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
}>
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
- // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
- %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
+ %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return
}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 6c0887ef4d826d..d34071279e5129 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -751,15 +751,15 @@ func.func @sparse_has_runtime() -> i1 {
// CHECK-LABEL: func.func @sparse_extract_iter_space(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
-// CHECK: %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
-// CHECK: %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
+// CHECK: %[[VAL_2:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
// CHECK: return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
// CHECK: }
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
-> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
// Extracting the iteration space for the first level needs no parent iterator.
- %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
// Extracting the iteration space for the second level needs a parent iterator.
- %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
}
>From 71aa24c337ee2937ac45a7643be0aa87e06b51a2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Apr 2024 17:30:19 +0000
Subject: [PATCH 3/3] address comments
---
.../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 13 +++++++++++--
.../Dialect/SparseTensor/IR/SparseTensorTypes.td | 4 +++-
2 files changed, 14 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 38722990b54450..cf21032abe71a4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1443,10 +1443,13 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
let results = (outs AnySparseIterSpace:$resultSpace);
- let summary = "Extract an iteration space from a sparse tensor between certain levels";
+ let summary = "Extracts an iteration space from a sparse tensor between certain levels";
let description = [{
Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
- certian (consecutive) levels.
+ certain (consecutive) levels. For sparse levels, it is usually done by
+ loading a position range from the underlying sparse tensor storage.
+ E.g., for a compressed level, the iteration space is extracted by
+ [pos[i], pos[i+1]) supposing the parent iterator points at `i`.
`tensor`: the input sparse tensor that defines the iteration space.
`parentIter`: the iterator for the previous level, at which the iteration space
@@ -1455,6 +1458,12 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
iteration space.
+ The type of the returned value is automatically inferred to be
+ `!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
+ The returned iteration space can then be iterated over by
+ `sparse_tensor.iterate` operations to visit every stored element
+ (usually nonzeros) in the input sparse tensor.
+
Example:
```mlir
// Extracts a 1-D iteration space from a COO tensor at level 1.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 264a0a5b3bee6c..79113d8778743c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -81,7 +81,9 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
let description = [{
A sparse iteration space that represents an abstract N-D (sparse) iteration space
- extracted from a sparse tensor.
+ extracted from a sparse tensor, i.e., a set of (crd_0, crd_1, ..., crd_{N-1}) for
+ every stored element (usually nonzeros) in a sparse tensor between the specified
+ [$loLvl, $hiLvl) levels.
Examples:
More information about the Mlir-commits
mailing list