[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (PR #101100)

Peiming Liu llvmlistbot at llvm.org
Wed Jul 31 11:04:46 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/101100

>From 37393b7a27acd15bf482117d0689461d9e37624c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 29 Jul 2024 22:08:17 +0000
Subject: [PATCH 1/2] [mlir][sparse] introduce `sparse_tensor.coiterate`
 operation.

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  55 +++-
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  11 +-
 .../SparseTensor/IR/SparseTensorOps.td        | 123 +++++++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 293 ++++++++++++++----
 .../Transforms/SparseSpaceCollapse.cpp        |   4 +-
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  35 ++-
 6 files changed, 430 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 68ca036121520..388efd1c454b1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -61,37 +61,62 @@ struct COOSegment {
 /// A simple wrapper to encode a bitset of (at most 64) levels, currently used
 /// by `sparse_tensor.iterate` operation for the set of levels on which the
 /// coordinates should be loaded.
-class LevelSet {
-  uint64_t bits = 0;
+class I64BitSet {
+  uint64_t storage = 0;
 
 public:
-  LevelSet() = default;
-  explicit LevelSet(uint64_t bits) : bits(bits) {}
-  operator uint64_t() const { return bits; }
+  using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
+  const_set_bits_iterator begin() const {
+    return const_set_bits_iterator(*this);
+  }
+  const_set_bits_iterator end() const {
+    return const_set_bits_iterator(*this, -1);
+  }
+  iterator_range<const_set_bits_iterator> bits() const {
+    return make_range(begin(), end());
+  }
+
+  I64BitSet() = default;
+  explicit I64BitSet(uint64_t bits) : storage(bits) {}
+  operator uint64_t() const { return storage; }
 
-  LevelSet &set(unsigned i) {
+  I64BitSet &set(unsigned i) {
     assert(i < 64);
-    bits |= static_cast<uint64_t>(0x01u) << i;
+    storage |= static_cast<uint64_t>(0x01u) << i;
     return *this;
   }
 
-  LevelSet &operator|=(LevelSet lhs) {
-    bits |= static_cast<uint64_t>(lhs);
+  I64BitSet &operator|=(I64BitSet lhs) {
+    storage |= static_cast<uint64_t>(lhs);
     return *this;
   }
 
-  LevelSet &lshift(unsigned offset) {
-    bits = bits << offset;
+  I64BitSet &lshift(unsigned offset) {
+    storage = storage << offset;
     return *this;
   }
 
+  // Needed by `llvm::const_set_bits_iterator_impl`.
+  int find_first() const { return min(); }
+  int find_next(unsigned prev) const {
+    if (prev >= max())
+      return -1;
+
+    uint64_t b = storage >> (prev + 1);
+    if (b == 0)
+      return -1;
+
+    return llvm::countr_zero(b) + prev + 1;
+  }
+
   bool operator[](unsigned i) const {
     assert(i < 64);
-    return (bits & (1 << i)) != 0;
+    return (storage & (1 << i)) != 0;
   }
-  unsigned max() const { return 64 - llvm::countl_zero(bits); }
-  unsigned count() const { return llvm::popcount(bits); }
-  bool empty() const { return bits == 0; }
+  unsigned min() const { return llvm::countr_zero(storage); }
+  unsigned max() const { return 64 - llvm::countl_zero(storage); }
+  unsigned count() const { return llvm::popcount(storage); }
+  bool empty() const { return storage == 0; }
 };
 
 } // namespace sparse_tensor
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 69b212cce4ceb..cb6c1b63e4e4b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -24,16 +24,17 @@ class SparseTensor_Attr<string name,
 // sparse tensor levels.
 //===----------------------------------------------------------------------===//
 
-def LevelSetAttr :
-    TypedAttrBase<
-      I64, "IntegerAttr",
+def I64BitSetAttr : TypedAttrBase<I64, "IntegerAttr",
       And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
            CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
       "LevelSet attribute"> {
-  let returnType = [{::mlir::sparse_tensor::LevelSet}];
-  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+  let returnType = [{::mlir::sparse_tensor::I64BitSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::I64BitSet($_self.getValue().getZExtValue())}];
 }
 
+def I64BitSetArrayAttr :
+    TypedArrayAttrBase<I64BitSetAttr, "I64BitSet array attribute">;
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index ff9858d5832ba..766b8627ea6cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp", "IterateOp"]>]> {
+                 "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1629,14 +1629,14 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs,
-                       LevelSetAttr:$crdUsedLvls);
+                       I64BitSetAttr:$crdUsedLvls);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
-    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet" :$crdUsedLvls)>
   ];
 
   let extraClassDeclaration = [{
@@ -1669,6 +1669,123 @@ def IterateOp : SparseTensor_Op<"iterate",
   let hasCustomAssemblyFormat = 1;
 }
 
+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+    [AttrSizedOperandSegments,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+     RecursiveMemoryEffects]> {
+  let summary = "CoIterates over a set of sparse iteration spaces";
+  let description = [{
+      The `sparse_tensor.coiterate` operation represents a loop (nest) over
+      the a set of iteration spaces.
+      The operation can have multiple regions, with each of them defining a
+      case to compute a result at the current iterations. The case condition
+      is defined solely based on the pattern of specified iterators.
+      For example:
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#COO, lvls = 0>)
+           -> index
+      case %it1, _ {
+        // %coord is specified in space %sp1 but *NOT* specified in space %sp2.
+      }
+      case %it1, %it2 {
+        // %coord is specified in *BOTH* spaces %sp1 and %sp2.
+      }
+      ```
+
+      `sparse_tensor.coiterate` can also operate on loop-carried variables.
+      It returns the final values after loop termination.
+      The initial values of the variables are passed as additional SSA operands
+      to the iterator SSA value and used coordinate SSA values.
+      Each operation region has variadic arguments for specified (used), one argument
+      for each loop-carried variable, representing the value of the variable
+      at the current iteration, followed by a list of arguments for iterators.
+      The body region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
+
+      The results of a `sparse_tensor.coiterate` hold the final values after
+      the last iteration. If the `sparse_tensor.coiterate` defines any values,
+      a yield must be explicitly present in every region defined in the operation.
+      The number and types of the `sparse_tensor.coiterate` results must match
+      the initial values in the iter_args binding and the yield operands.
+
+
+      A `sparse_tensor.coiterate` example that does elementwise addition between two
+      sparse vectors.
+
+
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#CSR, lvls = 0>)
+           -> tensor<?xindex, #CSR>
+      case %it1, _ {
+         // v = v1 + 0 = v1
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %yield = sparse_tensor.insert %v1 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case _, %it2 {
+         // v = v2 + 0 = v2
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %yield = sparse_tensor.insert %v2 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case %it1, %it2 {
+         // v = v1 + v2
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %v = arith.addi %v1, %v2 : index
+         %yield = sparse_tensor.insert %v into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      ```
+  }];
+
+  let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+                       Variadic<AnyType>:$initArgs,
+                       I64BitSetAttr:$crdUsedLvls,
+                       I64BitSetArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+                 getIterSpaces().front().getType())
+          .getSpaceDim();
+    }
+    I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+      return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+                           .getValue().getZExtValue());
+    }
+    // The block arguments start with the referenced coordinates, followed by
+    // the user-provided iteration arguments, and end with the iterators.
+    Block::BlockArgListType getCrds(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_front(getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs(unsigned regionIdx) {
+      return getInitArgs().size();
+    }
+    Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+    }
+    Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_back(getRegionDefinedSpace(regionIdx).count());
+    }
+  }];
+
+  // TODO:
+  // let hasVerifier = 1;
+  // let hasRegionVerifier = 1;
+  // let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0a276d87f3bca..a6dfe611c457d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2131,10 +2131,82 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
   printLevelRange(p, lo, hi);
 }
 
+/// Parses a list of optionally defined values in the form of
+/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
+/// corresponding value is not defined (e.g., to represent an undefined
+/// coordinate in the sparse iteration space).
+static ParseResult parseOptionalDefinedList(
+    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
+    SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
+    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
+    OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
+  unsigned cnt = 0;
+  ParseResult crdList =
+      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
+        if (parser.parseOptionalKeyword("_")) {
+          if (parser.parseArgument(definedArgs.emplace_back()))
+            return failure();
+          definedSet.set(cnt);
+        }
+        cnt += 1;
+        return success();
+      });
+
+  if (cnt > maxCnt)
+    return parser.emitError(parser.getNameLoc(),
+                            "parsed more value than expected.");
+
+  if (failed(crdList)) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expecting SSA value or \"_\" for level coordinates");
+  }
+  assert(definedArgs.size() == definedSet.count());
+  return success();
+}
+
+static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
+                                     Block::BlockArgListType blocksArgs,
+                                     I64BitSet definedSet) {
+  if (definedSet.empty())
+    return;
+
+  for (unsigned i = 0; i < size; i++) {
+    if (definedSet[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != size - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+}
+
+static ParseResult
+parseUsedCoordList(OpAsmParser &parser, OperationState &state,
+                   SmallVectorImpl<OpAsmParser::Argument> &coords) {
+  // Parse "at(%crd0, _, ...)"
+  I64BitSet crdUsedLvlSet;
+  if (succeeded(parser.parseOptionalKeyword("at")) &&
+      failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
+    return failure();
+
+  // Always use IndexType for the coordinate.
+  for (auto &coord : coords)
+    coord.type = parser.getBuilder().getIndexType();
+
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+  return success();
+}
+
 static ParseResult
-parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
-                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
-                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
+                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
 
@@ -2148,37 +2220,14 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
-  // Parse "at(%crd0, _, ...)"
-  LevelSet crdUsedLvlSet;
-  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
-  unsigned lvlCrdCnt = 0;
-  if (hasUsedCrds) {
-    ParseResult crdList = parser.parseCommaSeparatedList(
-        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
-          if (parser.parseOptionalKeyword("_")) {
-            if (parser.parseArgument(iterArgs.emplace_back()))
-              return failure();
-            // Always use IndexType for the coordinate.
-            crdUsedLvlSet.set(lvlCrdCnt);
-            iterArgs.back().type = parser.getBuilder().getIndexType();
-          }
-          lvlCrdCnt += 1;
-          return success();
-        });
-    if (failed(crdList)) {
-      return parser.emitError(
-          parser.getNameLoc(),
-          "expecting SSA value or \"_\" for level coordinates");
-    }
-  }
-  // Set the CrdUsedLvl bitset.
-  state.addAttribute("crdUsedLvls",
-                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+    return failure();
+  size_t numCrds = blockArgs.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
   if (hasIterArgs)
-    if (parser.parseAssignmentList(iterArgs, initArgs))
+    if (parser.parseAssignmentList(blockArgs, initArgs))
       return failure();
 
   SmallVector<Type> iterSpaceTps;
@@ -2196,10 +2245,6 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
       return parser.emitError(parser.getNameLoc(),
                               "expected sparse_tensor.iter_space type for "
                               "iteration space operands");
-    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
-      return parser.emitError(parser.getNameLoc(),
-                              "mismatch in number of iteration space dimension "
-                              "and specified coordinates");
     it.type = spaceTp.getIteratorType();
   }
 
@@ -2213,9 +2258,68 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
     return failure();
 
   if (hasIterArgs) {
-    unsigned numCrds = crdUsedLvlSet.count();
     // Strip off leading args that used for coordinates.
-    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "mismatch in number of iteration arguments and return values");
+    }
+
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
+}
+
+static ParseResult
+parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
+                         SmallVectorImpl<Value> &spacesVals,
+                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
+
+  // Parse "(%spaces, ...)"
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+    return failure();
+  size_t numCrds = blockArgs.size();
+
+  // Parse "iter_args(%arg = %init, ...)"
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(blockArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // parse ": (sparse_tensor.iter_space, ...) -> ret"
+  if (parser.parseColon() || parser.parseLParen() ||
+      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
+    return failure();
+
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input sparse iteration spaces.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             spacesVals))
+    return failure();
+  state.operands.append(spacesVals);
+
+  if (hasIterArgs) {
+    // Strip off leading args that are used for coordinates.
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
       return parser.emitError(
           parser.getNameLoc(),
@@ -2285,7 +2389,7 @@ struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
 
   LogicalResult matchAndRewrite(IterateOp iterateOp,
                                 PatternRewriter &rewriter) const override {
-    LevelSet newUsedLvls(0);
+    I64BitSet newUsedLvls(0);
     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
       if (auto crd = iterateOp.getLvlCrd(i)) {
@@ -2317,13 +2421,13 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                       Value iterSpace, ValueRange initArgs) {
   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
   // All ones.
-  LevelSet set((1 << rank) - 1);
+  I64BitSet set((1 << rank) - 1);
   return build(builder, odsState, iterSpace, initArgs, set);
 }
 
 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                       Value iterSpace, ValueRange initArgs,
-                      LevelSet crdUsedLvls) {
+                      I64BitSet crdUsedLvls) {
   OpBuilder::InsertionGuard guard(builder);
 
   odsState.addOperands(iterSpace);
@@ -2353,7 +2457,7 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::UnresolvedOperand iterSpace;
 
   SmallVector<OpAsmParser::Argument> iters, iterArgs;
-  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
     return failure();
   if (iters.size() != 1)
     return parser.emitError(parser.getNameLoc(),
@@ -2393,36 +2497,20 @@ static void printInitializationList(OpAsmPrinter &p,
   p << ")";
 }
 
-static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
-                              Block::BlockArgListType blocksArgs,
-                              LevelSet crdUsedLvls) {
-  if (crdUsedLvls.empty())
-    return;
-
-  p << " at(";
-  for (unsigned i = 0; i < spaceDim; i++) {
-    if (crdUsedLvls[i]) {
-      p << blocksArgs.front();
-      blocksArgs = blocksArgs.drop_front();
-    } else {
-      p << "_";
-    }
-    if (i != spaceDim - 1)
-      p << ", ";
-  }
-  assert(blocksArgs.empty());
-  p << ")";
-}
-
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
-  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  if (!getCrdUsedLvls().empty()) {
+    p << " at(";
+    printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+    p << ")";
+  }
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
   if (!getInitArgs().empty())
-    p << "-> (" << getInitArgs().getTypes() << ") ";
+    p.printArrowTypeList(getInitArgs().getTypes());
 
+  p << " ";
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/!getInitArgs().empty());
 }
@@ -2495,13 +2583,88 @@ OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 
 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
-  // Both the operation itself and the region may be branching into the body or
-  // back into the operation itself.
+  // Both the operation itself and the region may be branching into the body
+  // or back into the operation itself.
   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
   // It is possible for loop not to enter the body.
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
+
+  SmallVector<Value> spaces;
+  // The block argument list of each region, arranged in the order of
+  // ([used coordinate list], [loop iteration args], [sparse iterator list]).
+  SmallVector<OpAsmParser::Argument> blockArgs;
+  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
+    return failure();
+
+  result.addAttribute("operandSegmentSizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(
+                          {static_cast<int32_t>(spaces.size()),
+                           static_cast<int32_t>(result.types.size())}));
+
+  SmallVector<Attribute> cases;
+  while (succeeded(parser.parseOptionalKeyword("case"))) {
+    // Parse one region per case.
+    I64BitSet definedItSet;
+    SmallVector<OpAsmParser::Argument> definedIts;
+    if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
+                                 spaces.size(), OpAsmParser::Delimiter::None))
+      return failure();
+
+    cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
+
+    for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
+      // Resolve the iterator type based on the iteration space type.
+      auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
+      definedIts[i].type = spaceTp.getIteratorType();
+    }
+    definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
+    Region *body = result.addRegion();
+    if (parser.parseRegion(*body, definedIts))
+      return failure();
+
+    CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+  }
+
+  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+void CoIterateOp::print(OpAsmPrinter &p) {
+  p << " (";
+  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
+  p << ")";
+
+  if (!getCrdUsedLvls().empty()) {
+    p << " at(";
+    printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
+    p << ")";
+  }
+
+  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
+
+  p << " : (" << getIterSpaces().getTypes() << ")";
+  if (!getInitArgs().empty())
+    p.printArrowTypeList(getInitArgs().getTypes());
+
+  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
+    p.printNewline();
+    p << "case ";
+    printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
+                             getRegionDefinedSpace(idx));
+    p << " ";
+    p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/!getInitArgs().empty());
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Dialect Setups.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 924046fcd9961..f85c4761a8d52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -141,10 +141,10 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
   auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
   builder.setInsertionPointToStart(cloned.getBody());
 
-  LevelSet crdUsedLvls;
+  I64BitSet crdUsedLvls;
   unsigned shift = 0, argIdx = 1;
   for (auto info : toCollapse.drop_back()) {
-    LevelSet set = info.loop.getCrdUsedLvls();
+    I64BitSet set = info.loop.getCrdUsedLvls();
     crdUsedLvls |= set.lshift(shift);
     shift += info.loop.getSpaceDim();
     for (BlockArgument crd : info.loop.getCrds()) {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 055709ee69eb7..ab861a2019dfa 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -801,7 +801,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 // CHECK-SAME:      %[[VAL_1:.*]]: index,
 // CHECK-SAME:      %[[VAL_2:.*]]: index) -> index {
 // CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> index {
 // CHECK:             sparse_tensor.yield %[[VAL_7]] : index
 // CHECK:           }
 // CHECK:           return %[[VAL_4]] : index
@@ -813,3 +813,36 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
   }
   return %r1 : index
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+
+// CHECK-LABEL:   func.func @sparse_coiteration(
+// CHECK-SAME:      %[[SP1:.*]]: !sparse_tensor.iter_space<#sparse, lvls = 0>,
+// CHECK-SAME:      %[[SP2:.*]]: !sparse_tensor.iter_space<#sparse, lvls = 1>) -> index {
+// CHECK:           %[[INIT:.*]] = arith.constant 0 : index
+// CHECK:           %[[RET:.*]] = sparse_tensor.coiterate (%[[SP1]], %[[SP2]]) at(%[[COORD:.*]]) iter_args(%[[ARG:.*]] = %[[INIT]])
+// CHECK:           case %[[VAL_6:.*]], _ {
+// CHECK:             sparse_tensor.yield %[[ARG]] : index
+// CHECK:           }
+// CHECK:           return %[[RET]] : index
+// CHECK:         }
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+                              %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+  %init = arith.constant 0 : index
+  %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+       : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+       -> index
+  case %it1, _ {
+    sparse_tensor.yield %arg : index
+  }
+  return %ret : index
+}

>From 3343713186b184ae86af94a2547a8001115519d8 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 31 Jul 2024 18:04:29 +0000
Subject: [PATCH 2/2] address comments.

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td   | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 766b8627ea6cf..f44dd7f8393a4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1673,13 +1673,12 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
     [AttrSizedOperandSegments,
      SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
      RecursiveMemoryEffects]> {
-  let summary = "CoIterates over a set of sparse iteration spaces";
+  let summary = "Co-iterates over a set of sparse iteration spaces";
   let description = [{
       The `sparse_tensor.coiterate` operation represents a loop (nest) over
-      the a set of iteration spaces.
-      The operation can have multiple regions, with each of them defining a
-      case to compute a result at the current iterations. The case condition
-      is defined solely based on the pattern of specified iterators.
+      a set of iteration spaces. The operation can have multiple regions,
+      with each of them defining a case to compute a result at the current iteration.
+      The case condition is defined solely based on the pattern of specified iterators.
       For example:
       ```mlir
       %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
@@ -1695,7 +1694,7 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
       ```
 
       `sparse_tensor.coiterate` can also operate on loop-carried variables.
-      It returns the final values after loop termination.
+      It returns the final value for each loop-carried variable after loop termination.
       The initial values of the variables are passed as additional SSA operands
       to the iterator SSA value and used coordinate SSA values.
       Each operation region has variadic arguments for specified (used), one argument



More information about the Mlir-commits mailing list