[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)

Peiming Liu via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 21 11:30:14 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105567

>From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 15 Aug 2024 21:10:37 +0000
Subject: [PATCH] [mlir][sparse] unify block arguments order between
 iterate/coiterate operations.

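Reorder the region arguments of `sparse_tensor.iterate` to: user-provided
iteration arguments, then used coordinates, then the iterator, so that the
iterator comes last as it does for `sparse_tensor.coiterate`.

stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: users/PeimingLiu/stack/3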
---
 .../SparseTensor/IR/SparseTensorOps.td        |  7 ++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 31 ++++++++--------
 .../Transforms/SparseIterationToScf.cpp       | 36 ++++++-------------
 3 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
       return getIterSpace().getType().getSpaceDim();
     }
     BlockArgument getIterator() {
-      return getRegion().getArguments().front();
+      return getRegion().getArguments().back();
     }
     std::optional<BlockArgument> getLvlCrd(Level lvl) {
       if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
       return std::nullopt;
     }
     Block::BlockArgListType getCrds() {
-      // The first block argument is iterator, the remaining arguments are
-      // referenced coordinates.
-      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+      // Region arguments are: iter_args..., used coordinates..., iterator.
+      return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
     }
     unsigned getNumRegionIterArgs() {
       return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
-  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+  SmallVector<OpAsmParser::Argument> coords;
+  if (failed(parseUsedCoordList(parser, state, coords)))
     return failure();
-  size_t numCrds = blockArgs.size();
+  size_t numCrds = coords.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
     if (parser.parseAssignmentList(blockArgs, initArgs))
       return failure();
 
+  blockArgs.append(coords);
+
   SmallVector<Type> iterSpaceTps;
   // parse ": sparse_tensor.iter_space -> ret"
   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
 
   if (hasIterArgs) {
-    // Strip off leading args that used for coordinates.
-    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+    // Strip off trailing args that are used for coordinates.
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
       return parser.emitError(
           parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
   odsState.addTypes(initArgs.getTypes());
   Block *bodyBlock = builder.createBlock(bodyRegion);
 
-  // First argument, sparse iterator
-  bodyBlock->addArgument(
-      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
-      odsState.location);
+  // Starts with a list of user-provided loop arguments.
+  for (Value v : initArgs)
+    bodyBlock->addArgument(v.getType(), v.getLoc());
 
   // Followed by a list of used coordinates.
   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
     bodyBlock->addArgument(builder.getIndexType(), odsState.location);
 
-  // Followed by a list of user-provided loop arguments.
-  for (Value v : initArgs)
-    bodyBlock->addArgument(v.getType(), v.getLoc());
+  // Ends with the sparse iterator.
+  bodyBlock->addArgument(
+      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+      odsState.location);
 }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
     return parser.emitError(parser.getNameLoc(),
                             "expected only one iterator/iteration space");
 
-  iters.append(iterArgs);
+  iterArgs.append(iters);
   Region *body = result.addRegion();
-  if (parser.parseRegion(*body, iters))
+  if (parser.parseRegion(*body, iterArgs))
     return failure();
 
   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2580,7 +2583,7 @@ MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
 }
 
 Block::BlockArgListType IterateOp::getRegionIterArgs() {
-  return getRegion().getArguments().take_back(getNumRegionIterArgs());
+  return getRegion().getArguments().take_front(getNumRegionIterArgs());
 }
 
 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index f7fcabb0220b50..71a229bea990c0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
 
 static ValueRange genLoopWithIterator(
     PatternRewriter &rewriter, Location loc, SparseIterator *it,
-    ValueRange reduc, bool iterFirst,
+    ValueRange reduc,
     function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
                                     Region &loopBody, SparseIterator *it,
                                     ValueRange reduc)>
@@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator(
     }
     return forOp.getResults();
   }
-  SmallVector<Value> ivs;
-  // TODO: always put iterator SSA values at the end of argument list to be
-  // consistent with coiterate operation.
-  if (!iterFirst)
-    llvm::append_range(ivs, it->getCursor());
-  // Appends the user-provided values.
-  llvm::append_range(ivs, reduc);
-  if (iterFirst)
-    llvm::append_range(ivs, it->getCursor());
+
+  SmallVector<Value> ivs(reduc);
+  llvm::append_range(ivs, it->getCursor());
 
   TypeRange types = ValueRange(ivs).getTypes();
   auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
@@ -164,12 +158,8 @@ static ValueRange genLoopWithIterator(
     Region &dstRegion = whileOp.getAfter();
     Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
     ValueRange aArgs = whileOp.getAfterArguments();
-    if (iterFirst) {
-      aArgs = it->linkNewScope(aArgs);
-    } else {
-      aArgs = aArgs.take_front(reduc.size());
-      it->linkNewScope(aArgs.drop_front(reduc.size()));
-    }
+    it->linkNewScope(aArgs.drop_front(reduc.size()));
+    aArgs = aArgs.take_front(reduc.size());
 
     rewriter.setInsertionPointToStart(after);
     SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
@@ -177,12 +167,8 @@ static ValueRange genLoopWithIterator(
 
     // Forward loops
     SmallVector<Value> yields;
-    ValueRange nx = it->forward(rewriter, loc);
-    if (iterFirst)
-      llvm::append_range(yields, nx);
     llvm::append_range(yields, ret);
-    if (!iterFirst)
-      llvm::append_range(yields, nx);
+    llvm::append_range(yields, it->forward(rewriter, loc));
     rewriter.create<scf::YieldOp>(loc, yields);
   }
-  return whileOp.getResults().drop_front(it->getCursor().size());
+  return whileOp.getResults().drop_back(it->getCursor().size());
@@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
 
     Block *block = op.getBody();
     ValueRange ret = genLoopWithIterator(
-        rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+        rewriter, loc, it.get(), ivs,
         [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
                 SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
-          SmallVector<Value> blockArgs(it->getCursor());
+          SmallVector<Value> blockArgs(reduc);
           // TODO: Also appends coordinates if used.
           // blockArgs.push_back(it->deref(rewriter, loc));
-          llvm::append_range(blockArgs, reduc);
+          llvm::append_range(blockArgs, it->getCursor());
 
           Block *dstBlock = &loopBody.getBlocks().front();
           rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
@@ -404,7 +390,7 @@ class SparseCoIterateOpConverter
 
         Block *block = &r.getBlocks().front();
         ValueRange curResult = genLoopWithIterator(
-            rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+            rewriter, loc, validIters.front(), userReduc,
             /*bodyBuilder=*/
             [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
                     SparseIterator *it,



More information about the llvm-branch-commits mailing list