[Mlir-commits] [mlir] [mlir][sparse] partially support lowering sparse coiteration loops to… (PR #105261)

Peiming Liu llvmlistbot at llvm.org
Tue Aug 20 12:12:59 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105261

>From beab434591ae9bd7400e43ecb149ad7fcbd8a247 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 14 Aug 2024 21:35:46 +0000
Subject: [PATCH] [mlir][sparse] partially support lowering sparse coiteration
 loops to scf.while/for.

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   6 +
 .../SparseTensor/IR/SparseTensorOps.td        |   4 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  10 +
 .../Transforms/SparseIterationToScf.cpp       | 290 +++++++++++++++++-
 .../Transforms/Utils/LoopEmitter.cpp          | 174 ++++++-----
 .../Transforms/Utils/LoopEmitter.h            |  11 +
 .../Transforms/Utils/SparseTensorIterator.h   |   2 +
 .../sparse_kernels_to_iterator.mlir           |  77 ++++-
 ...-sqsum.mlir => iterator-based-kernel.mlir} |  49 ++-
 9 files changed, 539 insertions(+), 84 deletions(-)
 rename mlir/test/Integration/Dialect/SparseTensor/CPU/{iterator-based-sqsum.mlir => iterator-based-kernel.mlir} (63%)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 388efd1c454b1e..53df01ec0a66e4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -96,6 +96,12 @@ class I64BitSet {
     return *this;
   }
 
+  bool isSubSetOf(const I64BitSet p) const {
+    I64BitSet merged = p; // Merging a subset into `p` adds no new bit.
+    merged |= *this;
+    return merged == p;
+  }
+
   // Needed by `llvm::const_set_bits_iterator_impl`.
   int find_first() const { return min(); }
   int find_next(unsigned prev) const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2803223354d5ee..20512f972e67cd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
           .take_back(getRegionDefinedSpace(regionIdx).count());
     }
     ValueRange getYieldedValues(unsigned regionIdx);
+
+    // Returns a vector of regions that are the `sub-cases` of the given case region.
+    // E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
+    SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a143189c301a43..16856b958d4f13 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
   return success();
 }
 
+SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
+  // Keep each case region whose defined space lies within the queried one.
+  const I64BitSet bound = getRegionDefinedSpace(regionIdx);
+  SmallVector<Region *> subCases;
+  for (Region &sub : getCaseRegions())
+    if (getRegionDefinedSpace(sub.getRegionNumber()).isSubSetOf(bound))
+      subCases.push_back(&sub);
+  return subCases;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Dialect Setups.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index b1451dee738ac3..f8f80ddc98ba6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -1,5 +1,6 @@
 
 #include "Utils/CodegenUtils.h"
+#include "Utils/LoopEmitter.h"
 #include "Utils/SparseTensorIterator.h"
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,143 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
+// Emits the scf.if nest that dispatches among `subCases` (tried in order, the
+// remaining ones nested in the else branch) and returns the values it yields
+// for the user-provided reductions (`userReduc` itself when no case is left).
+static ValueRange
+genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
+                       Value loopCrd,
+                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
+                       ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
+  if (subCases.empty())
+    return userReduc;
+
+  // The current branch that we are handling.
+  Region *b = subCases.front();
+  Value casePred = constantI1(rewriter, loc, true);
+  for (unsigned i : op.getRegionDefinedSpace(b->getRegionNumber()).bits()) {
+    SparseIterator *it = iters[i].get();
+    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                it->getCrd(), loopCrd);
+    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
+  }
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
+  // Erase the empty block; the cloned case region takes its place below.
+  rewriter.eraseBlock(&ifOp.getThenRegion().front());
+  // Set up block arguments: user-provided values -> loop coord -> iterators.
+  SmallVector<Value> blockArgs(userReduc);
+  blockArgs.push_back(loopCrd);
+  for (unsigned idx : op.getRegionDefinedSpace(b->getRegionNumber()).bits())
+    llvm::append_range(blockArgs, iters[idx]->getCursor());
+
+  IRMapping mapping;
+  for (auto [from, to] : llvm::zip_equal(b->front().getArguments(), blockArgs))
+    mapping.map(from, to);
+
+  // Clone the region, we can not erase the region now because the same region
+  // might be a subcase for multiple lattice points.
+  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+                             ifOp.getThenRegion().begin(), mapping);
+
+  // Replace sparse_tensor::YieldOp -> scf::YieldOp. The yielded values are
+  // copied out first: a range over the operands dangles once the op is erased.
+  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
+  SmallVector<Value> yields(spY.getResults());
+  rewriter.eraseOp(spY);
+  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+  rewriter.create<scf::YieldOp>(loc, yields);
+
+  // Generates remaining case recursively.
+  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
+                                          subCases.drop_front(), userReduc);
+  if (!res.empty())
+    rewriter.create<scf::YieldOp>(loc, res);
+
+  rewriter.setInsertionPointAfter(ifOp);
+  return ifOp.getResults();
+}
+
+// Emits a loop driven by `it`: an scf.for when the iterator supports it,
+// otherwise an scf.while carrying the iterator cursor next to `reduc`
+// (cursor first iff `iterFirst`). Returns the loop results for `reduc`.
+static ValueRange genLoopWithIterator(
+    PatternRewriter &rewriter, Location loc, SparseIterator *it,
+    ValueRange reduc, bool iterFirst,
+    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
+                                    Region &loopBody, SparseIterator *it,
+                                    ValueRange reduc)>
+        bodyBuilder) {
+  if (it->iteratableByFor()) {
+    auto [lo, hi] = it->genForCond(rewriter, loc);
+    Value step = constantIndex(rewriter, loc, 1);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      // Erase the implicit yield operation created by ForOp when there is no
+      // yielding values.
+      if (!forOp.getBody()->empty())
+        rewriter.eraseOp(&forOp.getBody()->front());
+      assert(forOp.getBody()->empty());
+
+      it->linkNewScope(forOp.getInductionVar());
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
+                                           it, forOp.getRegionIterArgs());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      rewriter.create<scf::YieldOp>(loc, ret);
+    }
+    return forOp.getResults();
+  }
+  // TODO: always put iterator SSA values at the end of argument list to be
+  // consistent with coiterate operation. `numLead` args precede the cursor.
+  const size_t numLead = iterFirst ? 0 : reduc.size();
+  ValueRange cursor = it->getCursor();
+  SmallVector<Value> ivs(reduc);
+  ivs.insert(iterFirst ? ivs.begin() : ivs.end(), cursor.begin(), cursor.end());
+
+  TypeRange types = ValueRange(ivs).getTypes();
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    // Generates loop conditions from the cursor part of the arguments.
+    SmallVector<Location> l(types.size(), loc);
+    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+    rewriter.setInsertionPointToStart(before);
+    ValueRange bArgs = before->getArguments();
+    auto [whileCond, remArgs] =
+        it->genWhileCond(rewriter, loc, bArgs.drop_front(numLead));
+    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+    // Delegates loop body generation.
+    Region &dstRegion = whileOp.getAfter();
+    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
+    ValueRange aArgs = whileOp.getAfterArguments();
+    // Link the cursor before narrowing `aArgs` down to the reduction values.
+    if (iterFirst) {
+      aArgs = it->linkNewScope(aArgs);
+    } else {
+      it->linkNewScope(aArgs.drop_front(numLead));
+      aArgs = aArgs.take_front(reduc.size());
+    }
+
+    rewriter.setInsertionPointToStart(after);
+    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
+    rewriter.setInsertionPointToEnd(after);
+
+    // Forward loops; the yield order mirrors the argument order chosen above.
+    ValueRange nx = it->forward(rewriter, loc);
+    ret.insert(iterFirst ? ret.begin() : ret.end(), nx.begin(), nx.end());
+    rewriter.create<scf::YieldOp>(loc, ret);
+  }
+  // Only the results that correspond to the user-provided values are returned.
+  ValueRange results = whileOp.getResults();
+  return iterFirst ? results.drop_front(it->getCursor().size())
+                   : results.take_front(reduc.size());
+}
+
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -136,6 +274,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       rewriter.replaceOp(op, forOp.getResults(), resultMapping);
     } else {
       SmallVector<Value> ivs;
+      // TODO: put iterator at the end of argument list to be consistent with
+      // coiterate operation.
       llvm::append_range(ivs, it->getCursor());
       for (ValueRange inits : adaptor.getInitArgs())
         llvm::append_range(ivs, inits);
@@ -189,6 +329,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
   }
 };
 
+/// Lowers sparse_tensor.coiterate to scf: one loop per case region, in which
+/// an scf.if nest picks the matching sub-case when iterators are coiterated.
+class SparseCoIterateOpConverter
+    : public OneToNOpConversionPattern<CoIterateOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (op.getSpaceDim() != 1)
+      return rewriter.notifyMatchFailure(op, "only 1-D spaces are supported");
+    Location loc = op.getLoc();
+
+    I64BitSet denseBits(0);
+    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
+      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
+        denseBits.set(idx);
+
+    // A universal index is needed to forward the coiteration loop if some case
+    // contains only dense spaces (its bits are a subset of the dense bits) or
+    // is completely empty (due to complements).
+    bool needUniv =
+        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
+          // A case for complement.
+          if (caseBits.count() == 0)
+            return true;
+          // An all-dense case.
+          return caseBits.isSubSetOf(denseBits);
+        });
+    if (needUniv)
+      return rewriter.notifyMatchFailure(op, "universal index not implemented");
+
+    for (Region &region : op.getCaseRegions()) {
+      // Do a one-shot type conversion on all region blocks, since the same
+      // region might be used multiple times.
+      Block *block = &region.getBlocks().front();
+      OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+                                                     blockTypeMapping)))
+        return rewriter.notifyMatchFailure(
+            op, "failed to convert coiterate region argument types");
+
+      rewriter.applySignatureConversion(block, blockTypeMapping);
+    }
+
+    SmallVector<SparseIterationSpace> spaces;
+    SmallVector<std::unique_ptr<SparseIterator>> iters;
+    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
+             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
+      // TODO: do we really need tid?
+      spaces.push_back(SparseIterationSpace::fromValues(
+          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
+      iters.push_back(spaces.back().extractIterator(rewriter, loc));
+    }
+
+    auto getFilteredIters = [&iters](I64BitSet caseBits) {
+      // Retrieves a vector of pointers to the iterators used in the case.
+      SmallVector<SparseIterator *> validIters;
+      for (auto idx : caseBits.bits())
+        validIters.push_back(iters[idx].get());
+      return validIters;
+    };
+
+    // Get a flattened user-provided loop reduction values.
+    SmallVector<Value> userReduc;
+    for (ValueRange r : adaptor.getInitArgs())
+      llvm::append_range(userReduc, r);
+
+    // TODO: sort the cases into lexical order; sparsification always generates
+    // them in that order, but human-written code might not.
+
+    // Generates a loop sequence, one loop per case.
+    for (auto [r, caseBits] :
+         llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+      assert(caseBits.count() > 0 && "Complement space not implemented");
+
+      // Retrieves a vector of pointers to the iterators used in the case.
+      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
+
+      if (validIters.size() > 1) {
+        auto [loop, loopCrd] =
+            genCoIteration(rewriter, loc, validIters, userReduc,
+                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);
+
+        // 1st. find all the cases that are a (non-strict) subset of the
+        // current case condition, for which we generate one branch per case
+        // inside the loop. The subcases are never empty: they contain at least
+        // the current region itself.
+        // TODO: these cases should be sorted.
+        SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+        assert(!subCases.empty());
+
+        ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
+                                                iters, subCases, userReduc);
+
+        SmallVector<Value> nextIterYields(res);
+        // 2nd. forward the loop.
+        for (SparseIterator *it : validIters) {
+          Value cmp = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
+          it->forwardIf(rewriter, loc, cmp);
+          llvm::append_range(nextIterYields, it->getCursor());
+        }
+        rewriter.create<scf::YieldOp>(loc, nextIterYields);
+
+        // Exit the loop, relink the iterator SSA value.
+        rewriter.setInsertionPointAfter(loop);
+        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
+        for (SparseIterator *it : validIters)
+          iterVals = it->linkNewScope(iterVals);
+        assert(iterVals.empty());
+
+        ValueRange curResult = loop->getResults().take_front(userReduc.size());
+        userReduc.assign(curResult.begin(), curResult.end());
+      } else {
+        // This is a simple iteration loop.
+        assert(caseBits.count() == 1);
+
+        Block *block = &r.getBlocks().front();
+        ValueRange curResult = genLoopWithIterator(
+            rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+            /*bodyBuilder=*/
+            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
+                    SparseIterator *it,
+                    ValueRange reduc) -> SmallVector<Value> {
+              SmallVector<Value> blockArgs(reduc);
+              blockArgs.push_back(it->deref(rewriter, loc));
+              llvm::append_range(blockArgs, it->getCursor());
+
+              Block *dstBlock = &dstRegion.getBlocks().front();
+              rewriter.inlineBlockBefore(
+                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
+              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+              SmallVector<Value> result(yield.getResults());
+              rewriter.eraseOp(yield);
+              return result;
+            });
+
+        userReduc.assign(curResult.begin(), curResult.end());
+      }
+    }
+
+    rewriter.replaceOp(op, userReduc);
+    return success();
+  }
+};
+
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +497,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
 
   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
-               SparseIterateOpConverter>(converter, patterns.getContext());
+               SparseIterateOpConverter, SparseCoIterateOpConverter>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index efb3295fb2a4bf..cb5874ff45068e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -524,84 +524,8 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
 std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
     OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
     MutableArrayRef<Value> reduc, bool needsUniv) {
-  // NOTE: the slice driven tensor-related reduction variable must
-  // appear before normal tensors.
-
-  // The set of induction variables for the while loop.
-  SmallVector<Value> ivs;
-
-  // Construct the while-loop with a parameter for each coordinate.
-  for (SparseIterator *it : spIters) {
-    ValueRange itVals = it->getCursor();
-    ivs.append(itVals.begin(), itVals.end());
-  }
-
-  // The position where user-supplied reduction variable starts.
-  ivs.append(reduc.begin(), reduc.end());
-  // Update universal index.
-  if (needsUniv)
-    ivs.push_back(loopSeqStack.back().first);
-
-  // Ensures all operands are valid.
-  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-  TypeRange types = ValueRange(ivs).getTypes();
-  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
-
-  SmallVector<Location> locs(types.size(), loc);
-  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
-  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
-  // Generates loop conditions.
-  builder.setInsertionPointToStart(before);
-  ValueRange bArgs = before->getArguments();
-  Value whileCond = nullptr; // bool values for loop condition.
-
-  for (SparseIterator *it : spIters) {
-    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
-    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
-    bArgs = remArgs;
-  }
-  // The remaining block arguments are user-provided reduction values and an
-  // optional universal index. Make sure their sizes match.
-  assert(bArgs.size() == reduc.size() + needsUniv);
-  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
-  // Generates loop body.
-  builder.setInsertionPointToStart(after);
-  ValueRange aArgs = after->getArguments();
-  // Since some LoopCondKind might need extra checks to filter out invalid
-  // iterations, we maintains another array to hold the iteration arguments to
-  // yield if the checks fails.
-  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
-
-  for (SparseIterator *it : spIters) {
-    aArgs = it->linkNewScope(aArgs);
-    // Dereference the iterator to cache the coordinate.
-    it->deref(builder, loc);
-  }
-
-  // In-place update on reduction variable.
-  assert(aArgs.size() == reduc.size() + needsUniv);
-  for (unsigned i = 0, e = reduc.size(); i < e; i++)
-    reduc[i] = aArgs[i];
-
-  Value min;
-  // Finds the minimum coordinate
-  if (!needsUniv) {
-    for (SparseIterator *it : spIters) {
-      if (min) {
-        Value cmp = CMPI(ult, it->getCrd(), min);
-        min = SELECT(cmp, it->getCrd(), min);
-      } else {
-        min = it->getCrd();
-      }
-    }
-  } else {
-    // Otherwise, universal index is the minimal pos.
-    min = whileOp.getAfterArguments().back();
-  }
-
-  return {whileOp, min};
+  return genCoIteration(builder, loc, spIters, reduc,
+                        needsUniv ? loopSeqStack.back().first : nullptr);
 }
 
 bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
@@ -972,6 +896,100 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   loopStack.pop_back();
 }
 
+//===----------------------------------------------------------------------===//
+// Loop generation utils
+//===----------------------------------------------------------------------===//
+
+// Emits an scf.while co-iterating over `spIters`; returns the loop and the
+// coordinate it visits. `reduc` is rebound in place to the loop body arguments.
+std::pair<Operation *, Value> sparse_tensor::genCoIteration(
+    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
+    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
+  // NOTE: the slice driven tensor-related reduction variable must
+  // appear before normal tensors.
+  // The set of induction variables for the while loop.
+  SmallVector<Value> ivs;
+
+  // TODO: remove the flag after full migration. Currently
+  // `sparse_tensor.coiterate` operation (must) put user provided reduction
+  // values at the front of the block list, while direct sparsification to scf
+  // loops put them at the end.
+  if (userReducFirst)
+    ivs.append(reduc.begin(), reduc.end());
+
+  // Construct the while-loop with a parameter for each coordinate.
+  for (SparseIterator *it : spIters) {
+    ValueRange itVals = it->getCursor();
+    ivs.append(itVals.begin(), itVals.end());
+  }
+
+  if (!userReducFirst)
+    ivs.append(reduc.begin(), reduc.end());
+
+  // Update universal index.
+  if (uniIdx)
+    ivs.push_back(uniIdx);
+
+  // Ensures all operands are valid.
+  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+  TypeRange types = ValueRange(ivs).getTypes();
+  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
+
+  SmallVector<Location> locs(types.size(), loc);
+  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+
+  // Generates loop conditions. Reduction values placed in front are skipped so
+  // that every iterator is linked to its own cursor values.
+  builder.setInsertionPointToStart(before);
+  const size_t numLead = userReducFirst ? reduc.size() : 0;
+  ValueRange bArgs = ValueRange(before->getArguments()).drop_front(numLead);
+  Value whileCond = nullptr; // bool values for loop condition.
+
+  for (SparseIterator *it : spIters) {
+    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
+    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
+    bArgs = remArgs;
+  }
+  // The remaining block arguments are the trailing user-provided reduction
+  // values (if any) and an optional universal index. Make sure sizes match.
+  assert(bArgs.size() == reduc.size() - numLead + (uniIdx ? 1 : 0));
+  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+  // Generates loop body (again skipping reduction values placed in front).
+  builder.setInsertionPointToStart(after);
+  ValueRange aArgs = ValueRange(after->getArguments()).drop_front(numLead);
+
+  for (SparseIterator *it : spIters) {
+    aArgs = it->linkNewScope(aArgs);
+    // Dereference the iterator to cache the coordinate.
+    it->deref(builder, loc);
+  }
+
+  // In-place update on reduction variable (leading, or right after cursors).
+  ValueRange rArgs = userReducFirst ? ValueRange(after->getArguments()) : aArgs;
+  for (unsigned i = 0, e = reduc.size(); i < e; i++)
+    reduc[i] = rArgs[i];
+
+  Value min;
+  // Finds the minimum coordinate
+  if (!uniIdx) {
+    for (SparseIterator *it : spIters) {
+      if (min) {
+        Value cmp = CMPI(ult, it->getCrd(), min);
+        min = SELECT(cmp, it->getCrd(), min);
+      } else {
+        min = it->getCrd();
+      }
+    }
+  } else {
+    // Otherwise, universal index is the minimal pos.
+    min = whileOp.getAfterArguments().back();
+  }
+
+  return {whileOp, min};
+}
+
 #undef CMPI
 #undef C_IDX
 #undef YIELD
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index a9eb888c8b6bec..3e61b5f27fcc2a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -436,6 +436,17 @@ class LoopEmitter {
   std::vector<std::vector<Value>> spIterVals;
 };
 
+// Utility functions to generate sparse loops.
+//
+// Generates a while loop that co-iterates over a set of iterators. Returns the
+// loop and the coordinate it visits (`uniIdx` when set, otherwise the minimum
+// iterator coordinate). `reduc` is updated in place to loop body arguments.
+std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
+                                             ArrayRef<SparseIterator *> iters,
+                                             MutableArrayRef<Value> reduc,
+                                             Value uniIdx,
+                                             bool userReducFirst = false);
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 91f363db93f1df..642cb1afa156b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -95,6 +95,8 @@ enum class IterKind : uint8_t {
 class SparseIterationSpace {
 public:
   SparseIterationSpace() = default;
+  SparseIterationSpace(SparseIterationSpace &) = delete;
+  SparseIterationSpace(SparseIterationSpace &&) = default;
 
   // Constructs a N-D iteration space.
   SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index 2487156a9a2e48..f819458e038582 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,7 +1,5 @@
 // RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
-
-// TODO: temporarilly disabled since there is no lowering rules from `coiterate` to `scf`.
-// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion -cse --canonicalize | FileCheck %s
 
 
 
@@ -79,6 +77,79 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
 // ITER:           bufferization.to_tensor
 // ITER:           return
 // ITER:         }
+
+// CHECK-LABEL:   func.func @add(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK:           %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : memref<10xi32>
+// CHECK:           linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
+// CHECK:             %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
+// CHECK:             %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
+// CHECK:             scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
+// CHECK:             %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
+// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
+// CHECK:             scf.if %[[VAL_29]] {
+// CHECK:               %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
+// CHECK:               %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
+// CHECK:               %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
+// CHECK:               memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_27]] {
+// CHECK:                 %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK:                 %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
+// CHECK:                 memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK:               } else {
+// CHECK:                 scf.if %[[VAL_28]] {
+// CHECK:                   %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK:                   %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
+// CHECK:                   memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
+// CHECK:             %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
+// CHECK:             scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
+// CHECK:           }
+// CHECK:           %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK:           scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
+// CHECK:             %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK:             %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
+// CHECK:             memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK:           scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK:             %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK:             %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
+// CHECK:             memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
+// CHECK:           return %[[VAL_53]] : tensor<10xi32>
+// CHECK:         }
 func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
   %cst = arith.constant dense<0> : tensor<10xi32>
   %0 = linalg.generic {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
similarity index 63%
rename from mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
rename to mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
index 6d03565f8f7b2a..6cca4fa86a162e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
@@ -35,9 +35,13 @@
   explicitVal = 1 : i32
 }>
 
-// An example of vector reductions.
-module {
+#VEC = #sparse_tensor.encoding<{
+  map = (d0) -> (d0 : compressed)
+}>
 
+
+module {
+  // An example of vector reductions (lowered through sparse_tensor.iterate).
   func.func @sqsum(%arg0: tensor<2x3x4x5xi32, #COO>) -> tensor<i32> {
     %cst = arith.constant dense<0> : tensor<i32>
     %0 = linalg.generic {
@@ -55,7 +59,30 @@ module {
     return %0 : tensor<i32>
   }
 
+  // An example of vector addition (lowered through sparse_tensor.coiterate).
+  func.func @vec_add(%arg0: tensor<4xi32, #VEC>, %arg1: tensor<4xi32, #VEC>) -> tensor<4xi32> {
+    %cst = arith.constant dense<0> : tensor<4xi32>
+    %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0) -> (d0)>,
+        affine_map<(d0) -> (d0)>,
+        affine_map<(d0) -> (d0)>
+      ],
+      iterator_types = ["parallel"]
+    }
+    ins(%arg0, %arg1 : tensor<4xi32, #VEC>, tensor<4xi32, #VEC>)
+    outs(%cst : tensor<4xi32>) {
+      ^bb0(%in1: i32, %in2: i32, %out: i32):
+        %2 = arith.addi %in1, %in2 : i32
+        linalg.yield %2 : i32
+    } -> tensor<4xi32>
+    return %0 : tensor<4xi32>
+  }
+
   func.func @main() {
+    %c0 = arith.constant 0 : index
+    %i0 = arith.constant 0 : i32
+
     %cst = arith.constant sparse<
      [
        [0, 1, 2, 3],
@@ -66,15 +93,33 @@ module {
      [1, 1, 1, 1]
     > : tensor<2x3x4x5xi32>
 
+    %l = arith.constant dense<
+       [0, 1, 2, 3]
+    > : tensor<4xi32>
+    %r = arith.constant dense<
+       [1, 0, 3, 0]
+    > : tensor<4xi32>
+
     %input = sparse_tensor.convert %cst : tensor<2x3x4x5xi32> to tensor<2x3x4x5xi32, #COO>
     %0 = call @sqsum(%input) : (tensor<2x3x4x5xi32, #COO>) -> tensor<i32>
     %v = tensor.extract %0[] : tensor<i32>
 
+    %lhs = sparse_tensor.convert %l : tensor<4xi32> to tensor<4xi32, #VEC>
+    %rhs = sparse_tensor.convert %r : tensor<4xi32> to tensor<4xi32, #VEC>
+    %add = call @vec_add(%lhs, %rhs) : (tensor<4xi32, #VEC>, tensor<4xi32, #VEC>) -> tensor<4xi32>
+
     // CHECK: 4
     vector.print %v : i32
+    // CHECK-NEXT: ( 1, 1, 5, 3 )
+    %vec = vector.transfer_read %add[%c0], %i0 : tensor<4xi32>, vector<4xi32>
+    vector.print %vec : vector<4xi32>
 
     bufferization.dealloc_tensor %input : tensor<2x3x4x5xi32, #COO>
     bufferization.dealloc_tensor %0 : tensor<i32>
+
+    bufferization.dealloc_tensor %lhs : tensor<4xi32, #VEC>
+    bufferization.dealloc_tensor %rhs : tensor<4xi32, #VEC>
+    bufferization.dealloc_tensor %add : tensor<4xi32>
     return
   }
 }



More information about the Mlir-commits mailing list