[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
Peiming Liu via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 21 11:30:15 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105566
>From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 15 Aug 2024 18:10:25 +0000
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
pattern implementation.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch: users/PeimingLiu/stack/2
---
.../Transforms/SparseIterationToScf.cpp | 118 ++++++------------
1 file changed, 36 insertions(+), 82 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);
- if (it->iteratableByFor()) {
- auto [lo, hi] = it->genForCond(rewriter, loc);
- Value step = constantIndex(rewriter, loc, 1);
- SmallVector<Value> ivs;
- for (ValueRange inits : adaptor.getInitArgs())
- llvm::append_range(ivs, inits);
- scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
-
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
- return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- rewriter.eraseBlock(forOp.getBody());
- Region &dstRegion = forOp.getRegion();
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
- auto yieldOp =
- llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(forOp.getBody());
- // replace sparse_tensor.yield with scf.yield.
- rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
- rewriter.eraseOp(yieldOp);
-
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, forOp.getResults(), resultMapping);
- } else {
- SmallVector<Value> ivs;
- // TODO: put iterator at the end of argument list to be consistent with
- // coiterate operation.
- llvm::append_range(ivs, it->getCursor());
- for (ValueRange inits : adaptor.getInitArgs())
- llvm::append_range(ivs, inits);
-
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
- SmallVector<Location> l(types.size(), op.getIterator().getLoc());
-
- // Generates loop conditions.
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
- rewriter.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
- assert(remArgs.size() == adaptor.getInitArgs().size());
- rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
- // Generates loop body.
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
- return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- Region &dstRegion = whileOp.getAfter();
- // TODO: handle uses of coordinate!
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
- ValueRange aArgs = whileOp.getAfterArguments();
- auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
- whileOp.getAfterBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+ SmallVector<Value> ivs;
+ for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+
+ // Type conversion on iterate op block.
+ OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(
+ op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert iterate region argurment types");
+ rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+ Block *block = op.getBody();
+ ValueRange ret = genLoopWithIterator(
+ rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+ [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+ SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
+ SmallVector<Value> blockArgs(it->getCursor());
+ // TODO: Also append coordinates if used.
+ // blockArgs.push_back(it->deref(rewriter, loc));
+ llvm::append_range(blockArgs, reduc);
+
+ Block *dstBlock = &loopBody.getBlocks().front();
+ rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
+ blockArgs);
+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+ // We can not use ValueRange as the operation holding the values will
+ // be destroyed.
+ SmallVector<Value> result(yield.getResults());
+ rewriter.eraseOp(yield);
+ return result;
+ });
- aArgs = it->linkNewScope(aArgs);
- ValueRange nx = it->forward(rewriter, loc);
- SmallVector<Value> yields;
- llvm::append_range(yields, nx);
- llvm::append_range(yields, yieldOp.getResults());
-
- // replace sparse_tensor.yield with scf.yield.
- rewriter.eraseOp(yieldOp);
- rewriter.create<scf::YieldOp>(loc, yields);
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(
- op, whileOp.getResults().drop_front(it->getCursor().size()),
- resultMapping);
- }
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+ rewriter.replaceOp(op, ret, resultMapping);
return success();
}
};
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
- blockTypeMapping)))
+ blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
+ }
rewriter.applySignatureConversion(block, blockTypeMapping);
}
More information about the llvm-branch-commits
mailing list