[Mlir-commits] [mlir] 7186704 - [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (#105566)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 23 11:21:48 PDT 2024


Author: Peiming Liu
Date: 2024-08-23T11:21:44-07:00
New Revision: 71867042041ebb02c2865ed7c9b908e691b31a91

URL: https://github.com/llvm/llvm-project/commit/71867042041ebb02c2865ed7c9b908e691b31a91
DIFF: https://github.com/llvm/llvm-project/commit/71867042041ebb02c2865ed7c9b908e691b31a91.diff

LOG: [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (#105566)

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
     std::unique_ptr<SparseIterator> it =
         iterSpace.extractIterator(rewriter, loc);
 
-    if (it->iteratableByFor()) {
-      auto [lo, hi] = it->genForCond(rewriter, loc);
-      Value step = constantIndex(rewriter, loc, 1);
-      SmallVector<Value> ivs;
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
-
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      rewriter.eraseBlock(forOp.getBody());
-      Region &dstRegion = forOp.getRegion();
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-      auto yieldOp =
-          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(forOp.getBody());
-      // replace sparse_tensor.yield with scf.yield.
-      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
-      rewriter.eraseOp(yieldOp);
-
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-    } else {
-      SmallVector<Value> ivs;
-      // TODO: put iterator at the end of argument list to be consistent with
-      // coiterate operation.
-      llvm::append_range(ivs, it->getCursor());
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-
-      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-      TypeRange types = ValueRange(ivs).getTypes();
-      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
-      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
-
-      // Generates loop conditions.
-      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-      rewriter.setInsertionPointToStart(before);
-      ValueRange bArgs = before->getArguments();
-      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-      assert(remArgs.size() == adaptor.getInitArgs().size());
-      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
-      // Generates loop body.
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      Region &dstRegion = whileOp.getAfter();
-      // TODO: handle uses of coordinate!
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-      ValueRange aArgs = whileOp.getAfterArguments();
-      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
-          whileOp.getAfterBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+    SmallVector<Value> ivs;
+    for (ValueRange inits : adaptor.getInitArgs())
+      llvm::append_range(ivs, inits);
+
+    // Type conversion on iterate op block.
+    OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(
+            op.getBody()->getArgumentTypes(), blockTypeMapping)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to convert iterate region argurment types");
+    rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+    Block *block = op.getBody();
+    ValueRange ret = genLoopWithIterator(
+        rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+        [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+                SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
+          SmallVector<Value> blockArgs(it->getCursor());
+          // TODO: Also appends coordinates if used.
+          // blockArgs.push_back(it->deref(rewriter, loc));
+          llvm::append_range(blockArgs, reduc);
+
+          Block *dstBlock = &loopBody.getBlocks().front();
+          rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
+                                     blockArgs);
+          auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+          // We can not use ValueRange as the operation holding the values will
+          // be destoryed.
+          SmallVector<Value> result(yield.getResults());
+          rewriter.eraseOp(yield);
+          return result;
+        });
 
-      aArgs = it->linkNewScope(aArgs);
-      ValueRange nx = it->forward(rewriter, loc);
-      SmallVector<Value> yields;
-      llvm::append_range(yields, nx);
-      llvm::append_range(yields, yieldOp.getResults());
-
-      // replace sparse_tensor.yield with scf.yield.
-      rewriter.eraseOp(yieldOp);
-      rewriter.create<scf::YieldOp>(loc, yields);
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(
-          op, whileOp.getResults().drop_front(it->getCursor().size()),
-          resultMapping);
-    }
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+    rewriter.replaceOp(op, ret, resultMapping);
     return success();
   }
 };
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
       Block *block = &region.getBlocks().front();
       OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
       if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
-                                                     blockTypeMapping)))
+                                                     blockTypeMapping))) {
         return rewriter.notifyMatchFailure(
             op, "failed to convert coiterate region argurment types");
+      }
 
       rewriter.applySignatureConversion(block, blockTypeMapping);
     }


        


More information about the Mlir-commits mailing list