[Mlir-commits] [mlir] 785a24f - [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (#101100)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 31 15:14:30 PDT 2024


Author: Peiming Liu
Date: 2024-07-31T15:14:27-07:00
New Revision: 785a24f1561c610ecbce7cdfbff053e0a3a7caec

URL: https://github.com/llvm/llvm-project/commit/785a24f1561c610ecbce7cdfbff053e0a3a7caec
DIFF: https://github.com/llvm/llvm-project/commit/785a24f1561c610ecbce7cdfbff053e0a3a7caec.diff

LOG: [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (#101100)

This PR introduces the `sparse_tensor.coiterate` operation, which represents
a loop that traverses multiple sparse iteration spaces.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 68ca036121520..388efd1c454b1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -61,37 +61,62 @@ struct COOSegment {
 /// A simple wrapper to encode a bitset of (at most 64) levels, currently used
 /// by `sparse_tensor.iterate` operation for the set of levels on which the
 /// coordinates should be loaded.
-class LevelSet {
-  uint64_t bits = 0;
+class I64BitSet {
+  uint64_t storage = 0;
 
 public:
-  LevelSet() = default;
-  explicit LevelSet(uint64_t bits) : bits(bits) {}
-  operator uint64_t() const { return bits; }
+  using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
+  const_set_bits_iterator begin() const {
+    return const_set_bits_iterator(*this);
+  }
+  const_set_bits_iterator end() const {
+    return const_set_bits_iterator(*this, -1);
+  }
+  iterator_range<const_set_bits_iterator> bits() const {
+    return make_range(begin(), end());
+  }
+
+  I64BitSet() = default;
+  explicit I64BitSet(uint64_t bits) : storage(bits) {}
+  operator uint64_t() const { return storage; }
 
-  LevelSet &set(unsigned i) {
+  I64BitSet &set(unsigned i) {
     assert(i < 64);
-    bits |= static_cast<uint64_t>(0x01u) << i;
+    storage |= static_cast<uint64_t>(0x01u) << i;
     return *this;
   }
 
-  LevelSet &operator|=(LevelSet lhs) {
-    bits |= static_cast<uint64_t>(lhs);
+  I64BitSet &operator|=(I64BitSet lhs) {
+    storage |= static_cast<uint64_t>(lhs);
     return *this;
   }
 
-  LevelSet &lshift(unsigned offset) {
-    bits = bits << offset;
+  I64BitSet &lshift(unsigned offset) {
+    storage = storage << offset;
     return *this;
   }
 
+  // Needed by `llvm::const_set_bits_iterator_impl`.
+  int find_first() const { return min(); }
+  int find_next(unsigned prev) const {
+    if (prev >= max())
+      return -1;
+
+    uint64_t b = storage >> (prev + 1);
+    if (b == 0)
+      return -1;
+
+    return llvm::countr_zero(b) + prev + 1;
+  }
+
   bool operator[](unsigned i) const {
     assert(i < 64);
-    return (bits & (1 << i)) != 0;
+    return (storage & (1 << i)) != 0;
   }
-  unsigned max() const { return 64 - llvm::countl_zero(bits); }
-  unsigned count() const { return llvm::popcount(bits); }
-  bool empty() const { return bits == 0; }
+  unsigned min() const { return llvm::countr_zero(storage); }
+  unsigned max() const { return 64 - llvm::countl_zero(storage); }
+  unsigned count() const { return llvm::popcount(storage); }
+  bool empty() const { return storage == 0; }
 };
 
 } // namespace sparse_tensor

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 69b212cce4ceb..cb6c1b63e4e4b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -24,16 +24,17 @@ class SparseTensor_Attr<string name,
 // sparse tensor levels.
 //===----------------------------------------------------------------------===//
 
-def LevelSetAttr :
-    TypedAttrBase<
-      I64, "IntegerAttr",
+def I64BitSetAttr : TypedAttrBase<I64, "IntegerAttr",
       And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
            CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
       "LevelSet attribute"> {
-  let returnType = [{::mlir::sparse_tensor::LevelSet}];
-  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+  let returnType = [{::mlir::sparse_tensor::I64BitSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::I64BitSet($_self.getValue().getZExtValue())}];
 }
 
+def I64BitSetArrayAttr :
+    TypedArrayAttrBase<I64BitSetAttr, "I64BitSet array attribute">;
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index ff9858d5832ba..6e17f804993e2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp", "IterateOp"]>]> {
+                 "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1629,14 +1629,14 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs,
-                       LevelSetAttr:$crdUsedLvls);
+                       I64BitSetAttr:$crdUsedLvls);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
-    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet" :$crdUsedLvls)>
   ];
 
   let extraClassDeclaration = [{
@@ -1669,6 +1669,127 @@ def IterateOp : SparseTensor_Op<"iterate",
   let hasCustomAssemblyFormat = 1;
 }
 
+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+    [AttrSizedOperandSegments,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+     RecursiveMemoryEffects]> {
+  let summary = "Co-iterates over a set of sparse iteration spaces";
+  let description = [{
+      The `sparse_tensor.coiterate` operation represents a loop (nest) over
+      a set of iteration spaces. The operation can have multiple regions,
+      with each of them defining a case to compute a result at the current iterations.
+      The case condition is defined solely based on the pattern of specified iterators.
+      For example:
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#COO, lvls = 0>)
+           -> index
+      case %it1, _ {
+        // %coord is specifed in space %sp1 but *NOT* specified in space %sp2.
+      }
+      case %it1, %it2 {
+        // %coord is specifed in *BOTH* spaces %sp1 and %sp2.
+      }
+      ```
+
+      `sparse_tensor.coiterate` can also operate on loop-carried variables.
+      It returns the final value for each loop-carried variable after loop termination.
+      The initial values of the variables are passed as additional SSA operands
+      to the iterator SSA value and used coordinate SSA values.
+      Each operation region has variadic arguments for specified (used), one argument
+      for each loop-carried variable, representing the value of the variable
+      at the current iteration, followed by a list of arguments for iterators.
+      The body region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
+
+      The results of an `sparse_tensor.coiterate` hold the final values after
+      the last iteration. If the `sparse_tensor.coiterate` defines any values,
+      a yield must be explicitly present in every region defined in the operation.
+      The number and types of the `sparse_tensor.coiterate` results must match
+      the initial values in the iter_args binding and the yield operands.
+
+
+      A `sparse_tensor.coiterate` example that does elementwise addition between two
+      sparse vectors.
+
+
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#CSR, lvls = 0>)
+           -> tensor<?xindex, #CSR>
+      case %it1, _ {
+         // v = v1 + 0 = v1
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %yield = sparse_tensor.insert %v1 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case _, %it2 {
+         // v = v2 + 0 = v2
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %yield = sparse_tensor.insert %v1 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case %it1, %it2 {
+         // v = v1 + v2
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %v = arith.addi %v1, %v2 : index
+         %yield = sparse_tensor.insert %v into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      ```
+  }];
+
+  let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+                       Variadic<AnyType>:$initArgs,
+                       I64BitSetAttr:$crdUsedLvls,
+                       I64BitSetArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+                 getIterSpaces().front().getType())
+          .getSpaceDim();
+    }
+    I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+      return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+                           .getValue().getZExtValue());
+    }
+    auto getRegionDefinedSpaces() {
+      return llvm::map_range(getCases().getValue(), [](Attribute attr) {
+        return I64BitSet(llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
+      });
+    }
+
+    // The block arguments starts with referenced coordinates, follows by
+    // user-provided iteration arguments and ends with iterators.
+    Block::BlockArgListType getCrds(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_front(getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs(unsigned regionIdx) {
+      return getInitArgs().size();
+    }
+    Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+    }
+    Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_back(getRegionDefinedSpace(regionIdx).count());
+    }
+    ValueRange getYieldedValues(unsigned regionIdx);
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0a276d87f3bca..1135ea32fe1ab 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2131,10 +2131,82 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
   printLevelRange(p, lo, hi);
 }
 
+/// Parses a list of `optional` defined list in the form of
+/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
+/// corresponding value is not defined (e.g., to represent an undefined
+/// coordinate in the sparse iteration space).
+static ParseResult parseOptionalDefinedList(
+    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
+    SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
+    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
+    OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
+  unsigned cnt = 0;
+  ParseResult crdList =
+      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
+        if (parser.parseOptionalKeyword("_")) {
+          if (parser.parseArgument(definedArgs.emplace_back()))
+            return failure();
+          definedSet.set(cnt);
+        }
+        cnt += 1;
+        return success();
+      });
+
+  if (cnt > maxCnt)
+    return parser.emitError(parser.getNameLoc(),
+                            "parsed more value than expected.");
+
+  if (failed(crdList)) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expecting SSA value or \"_\" for level coordinates");
+  }
+  assert(definedArgs.size() == definedSet.count());
+  return success();
+}
+
+static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
+                                     Block::BlockArgListType blocksArgs,
+                                     I64BitSet definedSet) {
+  if (definedSet.empty())
+    return;
+
+  for (unsigned i = 0; i < size; i++) {
+    if (definedSet[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != size - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+}
+
 static ParseResult
-parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
-                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
-                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+parseUsedCoordList(OpAsmParser &parser, OperationState &state,
+                   SmallVectorImpl<OpAsmParser::Argument> &coords) {
+  // Parse "at(%crd0, _, ...)"
+  I64BitSet crdUsedLvlSet;
+  if (succeeded(parser.parseOptionalKeyword("at")) &&
+      failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
+    return failure();
+
+  // Always use IndexType for the coordinate.
+  for (auto &coord : coords)
+    coord.type = parser.getBuilder().getIndexType();
+
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+  return success();
+}
+
+static ParseResult
+parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
+                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
 
@@ -2148,37 +2220,14 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
-  // Parse "at(%crd0, _, ...)"
-  LevelSet crdUsedLvlSet;
-  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
-  unsigned lvlCrdCnt = 0;
-  if (hasUsedCrds) {
-    ParseResult crdList = parser.parseCommaSeparatedList(
-        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
-          if (parser.parseOptionalKeyword("_")) {
-            if (parser.parseArgument(iterArgs.emplace_back()))
-              return failure();
-            // Always use IndexType for the coordinate.
-            crdUsedLvlSet.set(lvlCrdCnt);
-            iterArgs.back().type = parser.getBuilder().getIndexType();
-          }
-          lvlCrdCnt += 1;
-          return success();
-        });
-    if (failed(crdList)) {
-      return parser.emitError(
-          parser.getNameLoc(),
-          "expecting SSA value or \"_\" for level coordinates");
-    }
-  }
-  // Set the CrdUsedLvl bitset.
-  state.addAttribute("crdUsedLvls",
-                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+    return failure();
+  size_t numCrds = blockArgs.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
   if (hasIterArgs)
-    if (parser.parseAssignmentList(iterArgs, initArgs))
+    if (parser.parseAssignmentList(blockArgs, initArgs))
       return failure();
 
   SmallVector<Type> iterSpaceTps;
@@ -2196,10 +2245,6 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
       return parser.emitError(parser.getNameLoc(),
                               "expected sparse_tensor.iter_space type for "
                               "iteration space operands");
-    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
-      return parser.emitError(parser.getNameLoc(),
-                              "mismatch in number of iteration space dimension "
-                              "and specified coordinates");
     it.type = spaceTp.getIteratorType();
   }
 
@@ -2213,9 +2258,68 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
     return failure();
 
   if (hasIterArgs) {
-    unsigned numCrds = crdUsedLvlSet.count();
     // Strip off leading args that used for coordinates.
-    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "mismatch in number of iteration arguments and return values");
+    }
+
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
+}
+
+static ParseResult
+parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
+                         SmallVectorImpl<Value> &spacesVals,
+                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
+
+  // Parse "(%spaces, ...)"
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+    return failure();
+  size_t numCrds = blockArgs.size();
+
+  // Parse "iter_args(%arg = %init, ...)"
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(blockArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // parse ": (sparse_tensor.iter_space, ...) -> ret"
+  if (parser.parseColon() || parser.parseLParen() ||
+      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
+    return failure();
+
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input sparse iteration spaces.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             spacesVals))
+    return failure();
+  state.operands.append(spacesVals);
+
+  if (hasIterArgs) {
+    // Strip off leading args that used for coordinates.
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
       return parser.emitError(
           parser.getNameLoc(),
@@ -2285,7 +2389,7 @@ struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
 
   LogicalResult matchAndRewrite(IterateOp iterateOp,
                                 PatternRewriter &rewriter) const override {
-    LevelSet newUsedLvls(0);
+    I64BitSet newUsedLvls(0);
     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
       if (auto crd = iterateOp.getLvlCrd(i)) {
@@ -2317,13 +2421,13 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                       Value iterSpace, ValueRange initArgs) {
   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
   // All ones.
-  LevelSet set((1 << rank) - 1);
+  I64BitSet set((1 << rank) - 1);
   return build(builder, odsState, iterSpace, initArgs, set);
 }
 
 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                       Value iterSpace, ValueRange initArgs,
-                      LevelSet crdUsedLvls) {
+                      I64BitSet crdUsedLvls) {
   OpBuilder::InsertionGuard guard(builder);
 
   odsState.addOperands(iterSpace);
@@ -2353,7 +2457,7 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::UnresolvedOperand iterSpace;
 
   SmallVector<OpAsmParser::Argument> iters, iterArgs;
-  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
     return failure();
   if (iters.size() != 1)
     return parser.emitError(parser.getNameLoc(),
@@ -2393,51 +2497,39 @@ static void printInitializationList(OpAsmPrinter &p,
   p << ")";
 }
 
-static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
-                              Block::BlockArgListType blocksArgs,
-                              LevelSet crdUsedLvls) {
-  if (crdUsedLvls.empty())
-    return;
-
-  p << " at(";
-  for (unsigned i = 0; i < spaceDim; i++) {
-    if (crdUsedLvls[i]) {
-      p << blocksArgs.front();
-      blocksArgs = blocksArgs.drop_front();
-    } else {
-      p << "_";
-    }
-    if (i != spaceDim - 1)
-      p << ", ";
+template <typename SparseLoopOp>
+static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
+  if (op.getInitArgs().size() != op.getNumResults()) {
+    return op.emitOpError(
+        "mismatch in number of loop-carried values and defined values");
   }
-  assert(blocksArgs.empty());
-  p << ")";
+  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
+    return op.emitOpError("required out-of-bound coordinates");
+
+  return success();
 }
 
+LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
+LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
+
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
-  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  if (!getCrdUsedLvls().empty()) {
+    p << " at(";
+    printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+    p << ")";
+  }
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
   if (!getInitArgs().empty())
-    p << "-> (" << getInitArgs().getTypes() << ") ";
+    p.printArrowTypeList(getInitArgs().getTypes());
 
+  p << " ";
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/!getInitArgs().empty());
 }
 
-LogicalResult IterateOp::verify() {
-  if (getInitArgs().size() != getNumResults()) {
-    return emitOpError(
-        "mismatch in number of loop-carried values and defined values");
-  }
-  if (getCrdUsedLvls().max() > getSpaceDim())
-    return emitOpError("required out-of-bound coordinates");
-
-  return success();
-}
-
 LogicalResult IterateOp::verifyRegions() {
   if (getIterator().getType() != getIterSpace().getType().getIteratorType())
     return emitOpError("mismatch in iterator and iteration space type");
@@ -2495,13 +2587,136 @@ OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 
 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
-  // Both the operation itself and the region may be branching into the body or
-  // back into the operation itself.
+  // Both the operation itself and the region may be branching into the body
+  // or back into the operation itself.
   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
   // It is possible for loop not to enter the body.
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
+
+  SmallVector<Value> spaces;
+  // The block argument list of each regions, it is arranged in the order of
+  // ([used coordinate list], [loop iterations args], [sparse iterator list]).
+  SmallVector<OpAsmParser::Argument> blockArgs;
+  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
+    return failure();
+
+  result.addAttribute("operandSegmentSizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(
+                          {static_cast<int32_t>(spaces.size()),
+                           static_cast<int32_t>(result.types.size())}));
+
+  SmallVector<Attribute> cases;
+  while (succeeded(parser.parseOptionalKeyword("case"))) {
+    // Parse one region per case.
+    I64BitSet definedItSet;
+    SmallVector<OpAsmParser::Argument> definedIts;
+    if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
+                                 spaces.size(), OpAsmParser::Delimiter::None))
+      return failure();
+
+    cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
+
+    for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
+      // Resolve the iterator type based on the iteration space type.
+      auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
+      definedIts[i].type = spaceTp.getIteratorType();
+    }
+    definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
+    Region *body = result.addRegion();
+    if (parser.parseRegion(*body, definedIts))
+      return failure();
+
+    CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+  }
+
+  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+void CoIterateOp::print(OpAsmPrinter &p) {
+  p << " (";
+  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
+  p << ")";
+
+  if (!getCrdUsedLvls().empty()) {
+    p << " at(";
+    printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
+    p << ")";
+  }
+
+  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
+
+  p << " : (" << getIterSpaces().getTypes() << ")";
+  if (!getInitArgs().empty())
+    p.printArrowTypeList(getInitArgs().getTypes());
+
+  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
+    p.printNewline();
+    p << "case ";
+    printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
+                             getRegionDefinedSpace(idx));
+    p << " ";
+    p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/!getInitArgs().empty());
+  }
+}
+
+ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion(regionIdx).getBlocks().front().getTerminator())
+      .getResults();
+}
+
+LogicalResult CoIterateOp::verifyRegions() {
+  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
+    if (getNumRegionIterArgs(r) != getNumResults())
+      return emitOpError(
+          "mismatch in number of basic block args and defined values");
+
+    auto initArgs = getInitArgs();
+    auto iterArgs = getRegionIterArgs(r);
+    auto yieldVals = getYieldedValues(r);
+    auto opResults = getResults();
+    if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                          opResults.size()})) {
+      return emitOpError()
+             << "number mismatch between iter args and results on " << r
+             << "th region";
+    }
+
+    for (auto [i, init, iter, yield, ret] :
+         llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+      if (init.getType() != ret.getType())
+        return emitOpError()
+               << "types mismatch between " << i
+               << "th iter operand and defined value on " << r << "th region";
+      if (iter.getType() != ret.getType())
+        return emitOpError() << "types mismatch between " << i
+                             << "th iter region arg and defined value on " << r
+                             << "th region";
+      if (yield.getType() != ret.getType())
+        return emitOpError()
+               << "types mismatch between " << i
+               << "th yield value and defined value on " << r << "th region";
+    }
+  }
+
+  auto cases = getRegionDefinedSpaces();
+  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
+  if (set.size() != getNumRegions())
+    return emitOpError("contains duplicated cases.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Dialect Setups.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 924046fcd9961..f85c4761a8d52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -141,10 +141,10 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
   auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
   builder.setInsertionPointToStart(cloned.getBody());
 
-  LevelSet crdUsedLvls;
+  I64BitSet crdUsedLvls;
   unsigned shift = 0, argIdx = 1;
   for (auto info : toCollapse.drop_back()) {
-    LevelSet set = info.loop.getCrdUsedLvls();
+    I64BitSet set = info.loop.getCrdUsedLvls();
     crdUsedLvls |= set.lshift(shift);
     shift += info.loop.getSpaceDim();
     for (BlockArgument crd : info.loop.getCrds()) {

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 61cc9be88685c..737b736ba795f 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1191,3 +1191,78 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
   }
   return %r1 : index
 }
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+                              %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+  %init = arith.constant 0 : index
+  // expected-error @+1 {{'sparse_tensor.coiterate' op contains duplicated cases.}}
+  %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+       : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+       -> index
+  case %it1, _ {
+    sparse_tensor.yield %arg : index
+  }
+  case %it1, _ {
+    sparse_tensor.yield %arg : index
+  }
+  return %ret : index
+}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+                              %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+  %init = arith.constant 0 : index
+  // expected-error @+1 {{'sparse_tensor.coiterate' op types mismatch between 0th yield value and defined value on 0th region}}
+  %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+       : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+       -> index
+  case %it1, _ {
+    %i = arith.constant 1 : i32
+    sparse_tensor.yield %i : i32
+  }
+  return %ret : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+                              %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+  %init = arith.constant 0 : index
+  // expected-error @+1 {{'sparse_tensor.coiterate' op required out-of-bound coordinates}}
+  %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord1, %coord2) iter_args(%arg = %init)
+       : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+       -> index
+  case %it1, _ {
+    %i = arith.constant 1 : i32
+    sparse_tensor.yield %i : i32
+  }
+  return %ret : index
+}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 055709ee69eb7..ab861a2019dfa 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -801,7 +801,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 // CHECK-SAME:      %[[VAL_1:.*]]: index,
 // CHECK-SAME:      %[[VAL_2:.*]]: index) -> index {
 // CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> index {
 // CHECK:             sparse_tensor.yield %[[VAL_7]] : index
 // CHECK:           }
 // CHECK:           return %[[VAL_4]] : index
@@ -813,3 +813,36 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
   }
   return %r1 : index
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+
+// CHECK-LABEL:   func.func @sparse_coiteration(
+// CHECK-SAME:      %[[SP1:.*]]: !sparse_tensor.iter_space<#sparse, lvls = 0>,
+// CHECK-SAME:      %[[SP2:.*]]: !sparse_tensor.iter_space<#sparse, lvls = 1>) -> index {
+// CHECK:           %[[INIT:.*]] = arith.constant 0 : index
+// CHECK:           %[[RET:.*]] = sparse_tensor.coiterate (%[[SP1]], %[[SP2]]) at(%[[COORD:.*]]) iter_args(%[[ARG:.*]] = %[[INIT]])
+// CHECK:           case %[[VAL_6:.*]], _ {
+// CHECK:             sparse_tensor.yield %[[ARG]] : index
+// CHECK:           }
+// CHECK:           return %[[RET]] : index
+// CHECK:         }
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+                              %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+  %init = arith.constant 0 : index
+  %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+       : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+       -> index
+  case %it1, _ {
+    sparse_tensor.yield %arg : index
+  }
+  return %ret : index
+}


        


More information about the Mlir-commits mailing list