[Mlir-commits] [mlir] [mlir][sparse] support sparsification to coiterate operations. (PR #102546)

Peiming Liu llvmlistbot at llvm.org
Thu Aug 8 16:14:55 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/102546

>From 9c1505faf1bbe46b6774aef1bf40710cb20d258d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 8 Aug 2024 17:03:36 +0000
Subject: [PATCH] [mlir][sparse] support sparsification to coiterate
 operations.

---
 .../SparseTensor/IR/SparseTensorOps.td        |   4 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  16 +++
 .../Transforms/SparseTensorRewriting.cpp      |   2 +-
 .../Transforms/Sparsification.cpp             | 124 ++++++++++++-----
 .../Transforms/Utils/CodegenEnv.h             |   4 +
 .../Transforms/Utils/LoopEmitter.cpp          | 127 +++++++++++++++---
 .../Transforms/Utils/LoopEmitter.h            |  12 +-
 .../sparse_kernels_to_iterator.mlir           |  60 ++++++++-
 8 files changed, 286 insertions(+), 63 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6e17f804993e2a..d6e27dc0f75a4b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1749,6 +1749,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
 
+  let builders = [
+    OpBuilder<(ins "ValueRange":$iterSpace, "ValueRange":$initArgs, "unsigned":$numCases)>,
+  ];
+
   let extraClassDeclaration = [{
     unsigned getSpaceDim() {
       return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1135ea32fe1abb..277f4c81564ecd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2594,6 +2594,22 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
+                        ValueRange iterSpaces, ValueRange initArgs,
+                        unsigned numCases) {
+  unsigned rank =
+      cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
+  // All ones.
+  I64BitSet set((1 << rank) - 1);
+  // Fake case bits. We need to preallocate all the regions, as regions cannot
+  // be dynamically added after the operation is created.
+  SmallVector<int64_t> caseBits(numCases, 0);
+  ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
+  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
+                            initArgs, set, cases,
+                            /*caseRegionsCount=*/numCases);
+}
+
 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
 
   SmallVector<Value> spaces;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5fb009e3eebe66..cc372ed1be6217 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1395,7 +1395,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
       loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
       // Note that reduc will be taken care of by loop emitter and get updated
       // in place.
-      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
+      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                     reduc);
     }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 08fc104fcbeead..b4575592a8ff2f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
 /// one sparse level in the list.
 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                  ArrayRef<TensorLevel> tidLvls,
-                                 bool tryParallel, bool needsUniv) {
+                                 unsigned numCases, bool tryParallel,
+                                 bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
-        builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
+        builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
+        needsUniv);
   });
   assert(loop);
   return loop;
@@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
-                          bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+                          unsigned numCases, bool needsUniv,
+                          ArrayRef<TensorLevel> tidLvls) {
   bool tryParallel = shouldTryParallize(env, curr, tidLvls);
-  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
+  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
+                        needsUniv);
 }
 
 /// Generates the induction structure for a while-loop.
@@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
   // basic block where scf::Yield should be inserted.
 }
 
+/// Generate a case region in the coiterate operation.
+static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
+                               unsigned caseIdx, LatPointId allCase,
+                               LatPointId curCase,
+                               MutableArrayRef<Value> reduc) {
+  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
+  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
+  const BitVector &curCaseBits = env.merger().lat(curCase).simple;
+
+  /// Computes the subset of iterators that are valid in the current case being
+  /// generated.
+  I64BitSet caseBit(0);
+  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
+    if (curCaseBits.test(set))
+      caseBit.set(idx);
+
+  env.emitter().enterCurCoIterationCase(builder, env.op().getLoc(), caseBit,
+                                        caseIdx, reduc);
+}
+
 /// Generates a single if-statement within a while-loop.
 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                        LatPointId p) {
@@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
 /// Starts a single loop in current sequence.
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,
-                                              LatPointId li, bool needsUniv) {
+                                              LatPointId li, unsigned numCases,
+                                              bool needsUniv) {
+  // TODO: numCases is only used when generating iterator-based loops. Clean up
+  // after full migration.
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorLevel> tidLvls;
 
@@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
       translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
 
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
+  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
   Location loc = env.op().getLoc();
   for (auto [tidLvl, exp] : affineTidLvls) {
     env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
@@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
   // Start a loop sequence.
   bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
 
-  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  // We cannot change this to `for (const LatPointId li : env.set(lts))`
-  // because the loop body causes data-movement which invalidates
-  // the iterator.
+  // When using sparse-iterator-based loops, we only need one loop, as
+  // opposed to a loop sequence, to cover all the iterator spaces.
   const unsigned lsize = env.set(lts).size();
-  for (unsigned i = 0; i < lsize; i++) {
-    const LatPointId li = env.set(lts)[i];
-    // Start a loop.
-    auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
-
-    // Visit all lattices points with Li >= Lj to generate the
-    // loop-body, possibly with if statements for coiteration.
-    Value redInput = env.getReduc();
-    Value cntInput = env.getExpandCount();
-    Value insInput = env.getInsertionChain();
-    Value validIns = env.getValidLexInsert();
-    // We cannot change this to `for (const LatPointId lj : env.set(lts))`
-    // because the loop body causes data-movement which invalidates the
-    // iterator.
+  if (env.generatingSparseIterator()) {
+    // Get the largest lattice point and start a loop.
+    const LatPointId li = env.set(lts)[0];
+    auto [loop, isSingleCond] =
+        startLoop(env, rewriter, curr, li, lsize, needsUniv);
+    assert(isSingleCond == llvm::isa<IterateOp>(loop));
+    // We cannot change this to `for (const LatPointId li : env.set(lts))`
+    // because the loop body causes data-movement which invalidates
+    // the iterator.
     for (unsigned j = 0; j < lsize; j++) {
       const LatPointId lj = env.set(lts)[j];
       const ExprId ej = env.lat(lj).exp;
-      if (li == lj || env.merger().latGT(li, lj)) {
-        // Recurse into body of each branch.
-        if (!isSingleCond) {
-          scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
-          genStmt(env, rewriter, ej, curr + 1);
-          endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
-        } else {
+      // Recurse into body of each branch.
+      if (!isSingleCond) {
+        env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
+          genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
           genStmt(env, rewriter, ej, curr + 1);
-        }
+          // TODO: handle yield values.
+          assert(reduc.empty() && "Not Implemented");
+          rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
+          return std::nullopt;
+        });
+        // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
+      } else {
+        genStmt(env, rewriter, ej, curr + 1);
       }
     }
-
     // End a loop.
     needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
+  } else {
+    // Emit a loop for every lattice point L0 >= Li in this loop sequence.
+    for (unsigned i = 0; i < lsize; i++) {
+      const LatPointId li = env.set(lts)[i];
+      // Start a loop.
+      auto [loop, isSingleCond] =
+          startLoop(env, rewriter, curr, li, lsize, needsUniv);
+
+      // Visit all lattices points with Li >= Lj to generate the
+      // loop-body, possibly with if statements for coiteration.
+      Value redInput = env.getReduc();
+      Value cntInput = env.getExpandCount();
+      Value insInput = env.getInsertionChain();
+      Value validIns = env.getValidLexInsert();
+      // We cannot change this to `for (const LatPointId lj : env.set(lts))`
+      // because the loop body causes data-movement which invalidates the
+      // iterator.
+      for (unsigned j = 0; j < lsize; j++) {
+        const LatPointId lj = env.set(lts)[j];
+        const ExprId ej = env.lat(lj).exp;
+        if (li == lj || env.merger().latGT(li, lj)) {
+          // Recurse into body of each branch.
+          if (!isSingleCond) {
+            scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
+            genStmt(env, rewriter, ej, curr + 1);
+            endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
+          } else {
+            genStmt(env, rewriter, ej, curr + 1);
+          }
+        }
+      }
+
+      // End a loop.
+      needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
+    }
   }
 
   // End a loop sequence.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index d69ae53fb0f298..34b793ee11e4ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -49,6 +49,10 @@ class CodegenEnv {
 
   linalg::GenericOp op() const { return linalgOp; }
   const SparsificationOptions &options() const { return sparseOptions; }
+  bool generatingSparseIterator() const {
+    return sparseOptions.sparseEmitStrategy ==
+           SparseEmitStrategy::kSparseIterator;
+  }
   Merger &merger() { return latticeMerger; }
   LoopEmitter &emitter() { return loopEmitter; }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 2be0193f0de83e..0dce63fe593289 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -615,33 +615,104 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
   return true;
 }
 
+Region *LoopEmitter::enterCurCoIterationCase(OpBuilder &builder, Location loc,
+                                             I64BitSet caseBit,
+                                             unsigned caseIdx,
+                                             MutableArrayRef<Value> reduc) {
+  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
+  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
+  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);
+
+  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
+  Region &caseRegion = coIterOp.getRegion(caseIdx);
+  assert(caseRegion.getBlocks().empty() &&
+         "re-initialize the same coiteration case region.");
+
+  // Each block starts with a list of used coordinates of index type.
+  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
+                                builder.getIndexType());
+  // Followed by a list of user-provided iteration arguments.
+  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
+  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
+  // Ends with a set of iterators that defines the actual iteration space.
+  for (auto i : caseBit.bits()) {
+    blockArgTps.push_back(
+        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
+            .getIteratorType());
+  }
+  SmallVector<Location> locs(blockArgTps.size(), loc);
+  caseRegion.emplaceBlock().addArguments(blockArgTps, locs);
+
+  // Entering the new region scope, updating the SSA chain.
+  builder.setInsertionPointToStart(&caseRegion.front());
+  // Update the coordinates.
+  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
+  // Updates loop iteration arguments.
+  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
+  llvm::copy(iterArgs, reduc.begin());
+  // Updates sparse iterator values.
+  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
+  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
+  for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
+    if (caseBit[i]) {
+      spIterVals[tl.first][tl.second] = iters.front();
+      iters = iters.drop_front();
+    } else {
+      spIterVals[tl.first][tl.second] = nullptr;
+    }
+  }
+  // Must have consumed all iterator SSA values.
+  assert(iters.empty());
+  return &caseRegion;
+}
+
 Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
-    MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
-
+    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
+    bool needsUniv) {
+  // TODO: Argument `numCases` only used when generating iterator-based sparse
+  // loops. Simplify the code upon feature completion.
   // TODO: handle coiteration with sparse iterator.
   if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
-    assert(tidLvls.size() == 1);
-    auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
-    Value t = tensors[tid];
-
-    // Extract and iterate over the iteration space.
-    ExtractIterSpaceOp extractSpaceOp =
-        lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
-                 : builder.create<ExtractIterSpaceOp>(
-                       loc, t, spIterVals[tid][lvl - 1], lvl);
-
-    IterateOp iterOp = builder.create<IterateOp>(
-        loc, extractSpaceOp.getExtractedSpace(), reduc);
-    spIterVals[tid][lvl] = iterOp.getIterator();
+    if (tidLvls.size() == 1) {
+      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+      Value t = tensors[tid];
+
+      // Extract and iterate over the iteration space.
+      ExtractIterSpaceOp extractSpaceOp =
+          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+                   : builder.create<ExtractIterSpaceOp>(
+                         loc, t, spIterVals[tid][lvl - 1], lvl);
+
+      IterateOp iterOp = builder.create<IterateOp>(
+          loc, extractSpaceOp.getExtractedSpace(), reduc);
+      spIterVals[tid][lvl] = iterOp.getIterator();
+
+      // Update the reduction variables.
+      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
+      // Set the insertion point to loop body.
+      builder.setInsertionPointToStart(iterOp.getBody());
+      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
+                             iterOp.getCrds().front(), loopTag);
+      return iterOp;
+    }
 
-    // Update the reduction varaibles.
-    llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
-    // Set the insertion point to loop body.
-    builder.setInsertionPointToStart(iterOp.getBody());
-    loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
-                           iterOp.getIterator(), loopTag);
-    return iterOp;
+    // CoIteration Loops.
+    SmallVector<Value> spaces;
+    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+      Value t = tensors[tid];
+      ExtractIterSpaceOp extractSpaceOp =
+          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+                   : builder.create<ExtractIterSpaceOp>(
+                         loc, t, spIterVals[tid][lvl - 1], lvl);
+      spaces.push_back(extractSpaceOp.getExtractedSpace());
+    }
+    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
+    // The CoIterationOp does not have insertion block nor induction variable.
+    // TODO: the `struct LoopInfo` should be simplified after full migration.
+    loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
+                           /*induction variable*/ nullptr, loopTag);
+    return coIterOp;
   }
 
   // TODO: support multiple return on parallel for?
@@ -866,6 +937,18 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   // Clean up the values, it would help use to discover potential bug at a
   // earlier stage (instead of silently using a wrong value).
   const LoopInfo &loopInfo = loopStack.back();
+  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
+    Operation *p = loopInfo.loop;
+    if (isa<IterateOp>(p))
+      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
+
+    // Exit the loop.
+    rewriter.setInsertionPointAfter(p);
+    // In-place update reduction variables.
+    llvm::copy(p->getResults(), reduc.begin());
+    loopStack.pop_back();
+    return;
+  }
 
   // Sets the insertion point to the right position.
   rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index f3e73e4692c1fd..4dc594118ad21c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -145,8 +145,12 @@ class LoopEmitter {
   /// return the reduction variable used inside the generated loop.
   Operation *enterCoIterationOverTensorsAtLvls(
       OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
-      MutableArrayRef<Value> reduc = {}, bool isParallel = false,
-      bool needsUniv = false);
+      unsigned numCases, MutableArrayRef<Value> reduc = {},
+      bool isParallel = false, bool needsUniv = false);
+
+  Region *enterCurCoIterationCase(OpBuilder &builder, Location loc,
+                                  I64BitSet caseBit, unsigned caseIdx,
+                                  MutableArrayRef<Value> reduc);
 
   /// Generates code to exit the current loop (e.g., generates yields, forwards
   /// loop induction variables, etc).
@@ -260,9 +264,9 @@ class LoopEmitter {
     // required for levels with non-tivial index expressions, which is
     // maintained by the sliceDrivenInfo array below.
     const llvm::SmallVector<TensorLevel> tidLvls;
-    const Operation *loop;      // the loop operation
+    Operation *loop;            // the loop operation
     Block *const userCodeBlock; // the block holding users' generated code.
-    const Value iv;             // the induction variable for the loop
+    Value iv;                   // the induction variable for the loop
   };
 
   void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index 268b3940418b71..2aa91c38f9349c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,4 +1,8 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
+
+// TODO: temporarily disabled since there are no lowering rules from `coiterate` to `scf`.
+// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+
 
 
 #COO = #sparse_tensor.encoding<{
@@ -10,13 +14,18 @@
   )
 }>
 
+#VEC = #sparse_tensor.encoding<{
+  map = (d0) -> (d0 : compressed)
+}>
+
+
 // CHECK-LABEL:   func.func @sqsum(
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xindex>
 // CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK:           %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse> to memref<?xi32>
+// CHECK:           %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xi32>
 // CHECK:           %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
 // CHECK:             %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
 // CHECK:             %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
@@ -27,6 +36,12 @@
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor
 // CHECK:           return %[[RET]] : tensor<i32>
 // CHECK:         }
+
+// ITER-LABEL:   func.func @sqsum(
+// ITER:           sparse_tensor.iterate
+// ITER:             sparse_tensor.iterate
+// ITER:               sparse_tensor.iterate
+// ITER:         }
 func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
   %cst = arith.constant dense<0> : tensor<i32>
   %0 = linalg.generic {
@@ -43,3 +58,42 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
   } -> tensor<i32>
   return %0 : tensor<i32>
 }
+
+
+// ITER-LABEL:   func.func @add(
+// ITER:           sparse_tensor.coiterate
+// ITER:           case %[[IT_1:.*]], %[[IT_2:.*]] {
+// ITER:             %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
+// ITER:             %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
+// ITER:             %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32
+// ITER:             memref.store %[[SUM]]
+// ITER:           }
+// ITER:           case %[[IT_1:.*]], _ {
+// ITER:             %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
+// ITER:             memref.store %[[LHS]]
+// ITER:           }
+// ITER:           case _, %[[IT_2:.*]] {
+// ITER:             %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
+// ITER:             memref.store %[[RHS]]
+// ITER:           }
+// ITER:           bufferization.to_tensor
+// ITER:           return
+// ITER:         }
+func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
+  %cst = arith.constant dense<0> : tensor<10xi32>
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0) -> (d0)>,
+      affine_map<(d0) -> (d0)>,
+      affine_map<(d0) -> (d0)>
+    ],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0, %arg1 : tensor<10xi32, #VEC>, tensor<10xi32, #VEC>)
+  outs(%cst : tensor<10xi32>) {
+    ^bb0(%in1: i32, %in2: i32, %out: i32):
+      %2 = arith.addi %in1, %in2 : i32
+      linalg.yield %2 : i32
+  } -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}



More information about the Mlir-commits mailing list