[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.extract_iteration_space` operation. (PR #88554)

Peiming Liu llvmlistbot at llvm.org
Tue Apr 16 10:30:34 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88554

>From 43e96384a5b5a51a9dc8704bb1b003ae35f16d17 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 12 Apr 2024 17:20:40 +0000
Subject: [PATCH 1/3] [mlir][sparse] introduce `sparse_tensor.extract_space`
 operation that extracts a sparse iteration space to iterate over.

---
 .../SparseTensor/IR/SparseTensorOps.td        |  51 +++++++++
 .../SparseTensor/IR/SparseTensorTypes.td      |  95 ++++++++++++++++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 102 ++++++++++++++++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   |  82 ++++++++++++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  25 +++++
 5 files changed, 355 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0cfc64f9988a0a..b8e4edc8c8537b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1430,6 +1430,57 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+  let results = (outs AnySparseIterSpace:$resultSpace);
+
+  let summary = "Extract an iteration space from a sparse tensor between certain levels";
+  let description = [{
+      Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
+      certian (consecutive) levels.
+
+      `tensor`: the input sparse tensor that defines the iteration space.
+      `parentIter`: the iterator for the previous level, at which the iteration space
+      at the current levels will be extracted.
+      `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
+      the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
+      iteration space.
+
+      Example:
+      ```mlir
+      // Extracts a 1-D iteration space from a COO tensor at level 1.
+      %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+      ```
+  }];
+
+
+  let extraClassDeclaration = [{
+    std::pair<Level, Level> getLvlRange() {
+      return std::make_pair(getLoLvl(), getHiLvl());
+    }
+    unsigned getSpaceDim() {
+      return getHiLvl() - getLoLvl();
+    }
+    ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+      return getResultSpace().getType().getLvlTypes();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..264a0a5b3bee6c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
     : Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+  let mnemonic = "iter_space";
+
+  let description = [{
+    A sparse iteration space that represents an abstract N-D (sparse) iteration space
+    extracted from a sparse tensor.
+
+    Examples:
+
+    ```mlir
+    // An iteration space extracted from a CSR tensor between levels [0, 2).
+    !iter_space<#CSR, lvls = 0 to 2>
+    ```
+  }];
+
+  let parameters = (ins
+     SparseTensorEncodingAttr : $encoding,
+     "Level" : $loLvl,
+     "Level" : $hiLvl
+  );
+
+  let extraClassDeclaration = [{
+     /// The dimension of the iteration space.
+     unsigned getSpaceDim() const {
+       return getHiLvl() - getLoLvl();
+     }
+
+     /// Get the level types for the iteration space.
+     ArrayRef<LevelType> getLvlTypes() const {
+       return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
+     }
+
+     /// Whether the iteration space is unique (i.e., no duplicated coordinate).
+     bool isUnique() {
+       return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
+     }
+
+     /// Get the corresponding iterator type.
+     ::mlir::sparse_tensor::IteratorType getIteratorType() const;
+  }];
+
+  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+  let mnemonic = "iterator";
+
+  let description = [{
+    An iterator that points to the current element in the corresponding iteration space.
+
+    Examples:
+
+    ```mlir
+    // An iterator that iterates over an iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
+    !iterator<#CSR, lvls = 0 to 2>
+    ```
+  }];
+
+  let parameters = (ins
+     SparseTensorEncodingAttr : $encoding,
+     "Level" : $loLvl,
+     "Level" : $hiLvl
+  );
+
+  let extraClassDeclaration = [{
+     /// Get the corresponding iteration space type.
+     ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
+
+     unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
+     ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
+     bool isUnique() { return getIterSpaceType().isUnique(); }
+  }];
+
+  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+          "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+          "::mlir::sparse_tensor::IteratorType">;
+
+
 #endif // SPARSETENSOR_TYPES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e9058394d33da5..a9d2d2b8826f37 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
 
+// Forward declarations: the custom print/parse methods below are referenced
+// by the generated code for SparseTensorTypes.td.
+static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
+                                         mlir::sparse_tensor::Level &,
+                                         mlir::sparse_tensor::Level &);
+static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
+                            mlir::sparse_tensor::Level);
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
 
@@ -1953,6 +1961,100 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+IterSpaceType IteratorType::getIterSpaceType() const {
+  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
+                            getHiLvl());
+}
+
+IteratorType IterSpaceType::getIteratorType() const {
+  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
+}
+
+static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
+                                   Level &lvlHi) {
+  if (parser.parseInteger(lvlLo))
+    return failure();
+  // A bare "$lo" is shorthand for the single-level range [$lo, $lo + 1).
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    if (parser.parseInteger(lvlHi))
+      return failure();
+  } else {
+    lvlHi = lvlLo + 1;
+  }
+  // An empty/reversed range is a hard parse failure, not just a diagnostic.
+  if (lvlHi <= lvlLo)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expect larger level upper bound than lower bound");
+  return success();
+}
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+  if (parseLevelRange(parser, lvlLo, lvlHi))
+    return failure();
+
+  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+  return success();
+}
+
+static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
+
+  if (lo + 1 == hi)
+    p << lo;
+  else
+    p << lo << " to " << hi;
+}
+
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+                            IntegerAttr lvlHi) {
+  unsigned lo = lvlLo.getValue().getZExtValue();
+  unsigned hi = lvlHi.getValue().getZExtValue();
+  printLevelRange(p, lo, hi);
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
+                                   adaptor.getHiLvl()));
+  return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator should be specified iff level lower bound equals 0");
+  }
+
+  if (pIter) {
+    IterSpaceType spaceTp = getResultSpace().getType();
+    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
+      return emitOpError(
+          "mismatch in parent iterator encoding and iteration space encoding.");
+
+    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
+      return emitOpError("parent iterator should be used to extract an "
+                         "iteration space from a consecutive level.");
+  }
+
+  return success();
+}
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7f5c05190fc9a2..579625e22f067f 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
   sparse_tensor.print %arg0 : tensor<10x10xf64>
   return
 }
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
+  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
+  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
+  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 12f69c1d37b9cd..6c0887ef4d826d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -738,3 +738,28 @@ func.func @sparse_has_runtime() -> i1 {
   %has_runtime = sparse_tensor.has_runtime_library
   return %has_runtime : i1
 }
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_extract_iter_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
+// CHECK:           return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
+// CHECK:         }
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
+  -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
+  // Extracting the iteration space for the first level needs no parent iterator.
+  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // Extracting the iteration space for the second level needs a parent iterator.
+  %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
+}

>From 8f5f89145636b48ba4b62cbb2a007f6bcfe5d227 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 15 Apr 2024 16:58:06 +0000
Subject: [PATCH 2/3] update operation name.

---
 .../SparseTensor/IR/SparseTensorOps.td        |  2 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  8 ++++++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   | 20 +++++++++----------
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  8 ++++----
 4 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b8e4edc8c8537b..38722990b54450 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1434,7 +1434,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
 
-def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
     [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
 
   let arguments = (ins AnySparseTensor:$tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a9d2d2b8826f37..516b0943bdcfac 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1974,6 +1974,8 @@ IteratorType IterSpaceType::getIteratorType() const {
   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
 }
 
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
                                    Level &lvlHi) {
   if (parser.parseInteger(lvlLo))
@@ -1993,6 +1995,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
   return success();
 }
 
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                    IntegerAttr &lvlHiAttr) {
   Level lvlLo, lvlHi;
@@ -2004,6 +2008,8 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
   return success();
 }
 
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
 
   if (lo + 1 == hi)
@@ -2012,6 +2018,8 @@ static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
     p << lo << " to " << hi;
 }
 
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                             IntegerAttr lvlHi) {
   unsigned lo = lvlLo.getValue().getZExtValue();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 579625e22f067f..3fa696e1600a93 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1023,8 +1023,8 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
 }>
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
-  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
-  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+  // expected-error at +1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
+  %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
   return
 }
 
@@ -1038,8 +1038,8 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 }>
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
-  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
-  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
+  %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return
 }
 
@@ -1053,8 +1053,8 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 }>
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
-  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
-  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+  // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
   return
 }
 
@@ -1075,8 +1075,8 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
 }>
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
-  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
-  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+  // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
+  %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
   return
 }
 
@@ -1090,7 +1090,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 }>
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
-  // expected-error at +1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
-  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
+  %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return
 }
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 6c0887ef4d826d..d34071279e5129 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -751,15 +751,15 @@ func.func @sparse_has_runtime() -> i1 {
 // CHECK-LABEL:   func.func @sparse_extract_iter_space(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
 // CHECK-SAME:      %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
-// CHECK:           %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
-// CHECK:           %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
 // CHECK:           return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
 // CHECK:         }
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
   -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
   // Extracting the iteration space for the first level needs no parent iterator.
-  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
   // Extracting the iteration space for the second level needs a parent iterator.
-  %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
 }

>From 71aa24c337ee2937ac45a7643be0aa87e06b51a2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Apr 2024 17:30:19 +0000
Subject: [PATCH 3/3] address comments

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 13 +++++++++++--
 .../Dialect/SparseTensor/IR/SparseTensorTypes.td    |  4 +++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 38722990b54450..cf21032abe71a4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1443,10 +1443,13 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
 
   let results = (outs AnySparseIterSpace:$resultSpace);
 
-  let summary = "Extract an iteration space from a sparse tensor between certain levels";
+  let summary = "Extracts an iteration space from a sparse tensor between certain levels";
   let description = [{
       Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
-      certian (consecutive) levels.
+      certain (consecutive) levels. For sparse levels, it is usually done by
+      loading a position range from the underlying sparse tensor storage.
+      E.g., for a compressed level, the iteration space is extracted by
+      [pos[i], pos[i+1]) supposing the parent iterator points at `i`.
 
       `tensor`: the input sparse tensor that defines the iteration space.
       `parentIter`: the iterator for the previous level, at which the iteration space
@@ -1455,6 +1458,12 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
       iteration space.
 
+      The type of the returned value is automatically inferred to
+      `!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
+      The returned iteration space can then be iterated over by
+      `sparse_tensor.iterate` operations to visit every stored element
+      (usually nonzeros) in the input sparse tensor.
+
       Example:
       ```mlir
       // Extracts a 1-D iteration space from a COO tensor at level 1.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 264a0a5b3bee6c..79113d8778743c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -81,7 +81,9 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
 
   let description = [{
     A sparse iteration space that represents an abstract N-D (sparse) iteration space
-    extracted from a sparse tensor.
+    extracted from a sparse tensor, i.e., a set of (crd_0, crd_1, ..., crd_N) for
+    every stored element (usually nonzeros) in a sparse tensor between the specified
+    [$loLvl, $hiLvl) levels.
 
     Examples:
 



More information about the Mlir-commits mailing list