[Mlir-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
Peiming Liu
llvmlistbot at llvm.org
Thu Aug 22 16:30:32 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105566
>From 494be1353d44ea82cf522a45a81938534a8fadbe Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 14 Aug 2024 21:35:46 +0000
Subject: [PATCH 1/2] [mlir][sparse] partially support lowering sparse
coiteration loops to scf.while/for.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105565, branch: users/PeimingLiu/stack/1
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 22 +-
.../SparseTensor/IR/SparseTensorOps.td | 4 +
.../SparseTensor/IR/SparseTensorDialect.cpp | 10 +
.../Transforms/SparseIterationToScf.cpp | 291 +++++++++++++++++-
.../Transforms/Utils/LoopEmitter.cpp | 174 ++++++-----
.../Transforms/Utils/LoopEmitter.h | 11 +
.../Transforms/Utils/SparseTensorIterator.h | 2 +
.../sparse_kernels_to_iterator.mlir | 77 ++++-
...-sqsum.mlir => iterator-based-kernel.mlir} | 49 ++-
9 files changed, 549 insertions(+), 91 deletions(-)
rename mlir/test/Integration/Dialect/SparseTensor/CPU/{iterator-based-sqsum.mlir => iterator-based-kernel.mlir} (63%)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 388efd1c454b1e..fca2629d72efcf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -96,24 +96,32 @@ class I64BitSet {
return *this;
}
+ /// Returns true if every bit set in `*this` is also set in `p`. This is a
+ /// non-strict subset test: a set is always a subset of itself.
+ bool isSubSetOf(const I64BitSet p) const {
+ I64BitSet tmp = *this;
+ tmp |= p;
+ return tmp == p;
+ }
+
// Needed by `llvm::const_set_bits_iterator_impl`.
int find_first() const { return min(); }
int find_next(unsigned prev) const {
- if (prev >= max())
+ // Written as `prev + 1 >= max()` rather than `prev >= max() - 1` so that an
+ // empty set (max() == 0) cannot wrap around. Passing the guard also implies
+ // that the shift amount below is < 64.
+ if (prev + 1 >= max())
return -1;
- uint64_t b = storage >> (prev + 1);
- if (b == 0)
- return -1;
+ uint64_t b = storage >> (prev + 1);
+ // Bit `max() - 1` (> prev) is set, so some bit survives the shift.
+ assert(b != 0);
- return llvm::countr_zero(b) + prev + 1;
+ return llvm::countr_zero(b) + prev + 1;
}
bool operator[](unsigned i) const {
assert(i < 64);
- return (storage & (1 << i)) != 0;
+ // Shift an unsigned 64-bit one: `1 << i` overflows `int` for i >= 31, and a
+ // signed 64-bit one shifts into the sign bit for i == 63.
+ return (storage & (static_cast<uint64_t>(1) << i)) != 0;
+ }
+ /// Returns the index of the lowest set bit, or `-1` converted to `unsigned`
+ /// (i.e. UINT_MAX) when the set is empty.
+ unsigned min() const {
+ unsigned m = llvm::countr_zero(storage);
+ return m == 64 ? -1 : m;
}
- unsigned min() const { return llvm::countr_zero(storage); }
unsigned max() const { return 64 - llvm::countl_zero(storage); }
unsigned count() const { return llvm::popcount(storage); }
bool empty() const { return storage == 0; }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2803223354d5ee..20512f972e67cd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
.take_back(getRegionDefinedSpace(regionIdx).count());
}
ValueRange getYieldedValues(unsigned regionIdx);
+
+ // Returns the case regions that are `sub-cases` of the given case region,
+ // i.e., regions whose defined iteration spaces are a subset of those defined
+ // by `regionIdx`, in the order they appear among the case regions. The
+ // subset test is non-strict, so the given region itself is always included.
+ // E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
+ SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a143189c301a43..16856b958d4f13 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
return success();
}
+/// Collects every case region whose defined iteration spaces are contained in
+/// (or equal to) those of case `regionIdx`, so the region itself is always part
+/// of the result. Regions are returned in case-region order.
+SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
+ const I64BitSet target = getRegionDefinedSpace(regionIdx);
+ SmallVector<Region *> subCases;
+ for (Region &caseRegion : getCaseRegions()) {
+ I64BitSet spaces = getRegionDefinedSpace(caseRegion.getRegionNumber());
+ if (!spaces.isSubSetOf(target))
+ continue;
+ subCases.push_back(&caseRegion);
+ }
+ return subCases;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index b1451dee738ac3..d6c0da4a9e4573 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -1,5 +1,6 @@
#include "Utils/CodegenUtils.h"
+#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
return success();
}
+/// Generates a nest of `scf.if` operations, one per region in `subCases`. The
+/// first sub-case whose iterators all sit at `loopCrd` is executed; the
+/// remaining sub-cases are tried recursively in the `else` branch. Case
+/// regions are cloned (not moved) because the same region may be a sub-case
+/// of several coiteration loops.
+///
+/// Returns the values produced by the outermost `scf.if`, or `userReduc`
+/// unchanged when `subCases` is empty. On return, the insertion point is
+/// right after the generated `scf.if`.
+static ValueRange
+genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
+ Value loopCrd,
+ ArrayRef<std::unique_ptr<SparseIterator>> iters,
+ ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
+ if (subCases.empty())
+ return userReduc;
+
+ // The current branch that we are handling.
+ Region *b = subCases.front();
+ // The case applies iff every iterator it is defined on is at `loopCrd`.
+ Value casePred = constantI1(rewriter, loc, true);
+ I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+ for (unsigned i : caseBits.bits()) {
+ SparseIterator *it = iters[i].get();
+ Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ it->getCrd(), loopCrd);
+ casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
+ }
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+ loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
+
+ // Erase the empty block; it is replaced by the cloned case region below.
+ rewriter.eraseBlock(&ifOp.getThenRegion().front());
+ // Set up block arguments: user-provided values -> loop coord -> iterators.
+ SmallVector<Value> blockArgs(userReduc);
+ blockArgs.push_back(loopCrd);
+ for (unsigned idx : caseBits.bits())
+ llvm::append_range(blockArgs, iters[idx]->getCursor());
+
+ IRMapping mapping;
+ for (auto [from, to] :
+ llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+ mapping.map(from, to);
+ }
+
+ // Clone the region, we can not erase the region now because the same region
+ // might be a subcase for multiple lattice points.
+ rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+ ifOp.getThenRegion().begin(), mapping);
+
+ // Replace sparse_tensor::YieldOp -> scf::YieldOp. The yielded values are
+ // copied out first: `getResults()` is a view into the operands of `spY`,
+ // which would dangle once `spY` is erased.
+ auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
+ SmallVector<Value> yields(spY.getResults());
+ rewriter.eraseOp(spY);
+ rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+ rewriter.create<scf::YieldOp>(loc, yields);
+
+ // Generates remaining case recursively.
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
+ subCases.drop_front(), userReduc);
+ if (!res.empty())
+ rewriter.create<scf::YieldOp>(loc, res);
+
+ rewriter.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+}
+
+/// Generates a loop that iterates over `it` alone while carrying the
+/// user-provided values `reduc`. An `scf.for` is emitted when the iterator
+/// supports it. Otherwise an `scf.while` is emitted that carries both the
+/// iterator cursor and `reduc`: cursor first when `iterFirst` is true, user
+/// values first otherwise. `bodyBuilder` populates the loop body and returns
+/// the updated user values. Returns the loop results corresponding to `reduc`.
+static ValueRange genLoopWithIterator(
+ PatternRewriter &rewriter, Location loc, SparseIterator *it,
+ ValueRange reduc, bool iterFirst,
+ function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
+ Region &loopBody, SparseIterator *it,
+ ValueRange reduc)>
+ bodyBuilder) {
+ if (it->iteratableByFor()) {
+ auto [lo, hi] = it->genForCond(rewriter, loc);
+ Value step = constantIndex(rewriter, loc, 1);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Erase the implicit yield operation created by ForOp when there is no
+ // yielding values.
+ if (!forOp.getBody()->empty())
+ rewriter.eraseOp(&forOp.getBody()->front());
+ assert(forOp.getBody()->empty());
+
+ it->linkNewScope(forOp.getInductionVar());
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
+ it, forOp.getRegionIterArgs());
+
+ rewriter.setInsertionPointToEnd(forOp.getBody());
+ rewriter.create<scf::YieldOp>(loc, ret);
+ }
+ return forOp.getResults();
+ }
+ // The number of loop-carried values that make up the iterator cursor.
+ const size_t cursorSize = it->getCursor().size();
+
+ // Loop-carried values: the iterator cursor and the user-provided values, in
+ // the order selected by `iterFirst`.
+ // TODO: always put iterator SSA values at the end of argument list to be
+ // consistent with coiterate operation.
+ SmallVector<Value> ivs;
+ if (iterFirst) {
+ llvm::append_range(ivs, it->getCursor());
+ llvm::append_range(ivs, reduc);
+ } else {
+ llvm::append_range(ivs, reduc);
+ llvm::append_range(ivs, it->getCursor());
+ }
+
+ TypeRange types = ValueRange(ivs).getTypes();
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Generates loop conditions. Only the cursor values are handed to the
+ // iterator, wherever they sit in the argument list.
+ SmallVector<Location> l(types.size(), loc);
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+ rewriter.setInsertionPointToStart(before);
+ ValueRange bArgs = before->getArguments();
+ ValueRange bCursor = iterFirst ? bArgs.take_front(cursorSize)
+ : bArgs.drop_front(reduc.size());
+ auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bCursor);
+ rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+ // Delegates loop body generation.
+ Region &dstRegion = whileOp.getAfter();
+ Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
+ ValueRange aArgs = whileOp.getAfterArguments();
+ // Link the iterator to its cursor arguments, then keep only the user values.
+ if (iterFirst) {
+ aArgs = it->linkNewScope(aArgs);
+ } else {
+ it->linkNewScope(aArgs.drop_front(reduc.size()));
+ aArgs = aArgs.take_front(reduc.size());
+ }
+
+ rewriter.setInsertionPointToStart(after);
+ SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
+ rewriter.setInsertionPointToEnd(after);
+
+ // Forward loops; the yielded values follow the same order as `ivs`.
+ SmallVector<Value> yields;
+ ValueRange nx = it->forward(rewriter, loc);
+ if (iterFirst)
+ llvm::append_range(yields, nx);
+ llvm::append_range(yields, ret);
+ if (!iterFirst)
+ llvm::append_range(yields, nx);
+ rewriter.create<scf::YieldOp>(loc, yields);
+ }
+ // Return only the results that correspond to the user-provided values.
+ ValueRange results = whileOp.getResults();
+ return iterFirst ? results.drop_front(cursorSize)
+ : results.take_front(reduc.size());
+}
+
namespace {
/// Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
+ // TODO: put iterator at the end of argument list to be consistent with
+ // coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
}
};
+/// Lowers `sparse_tensor.coiterate` to a sequence of scf loops, one per case
+/// region.
+/// - A case defined on more than one iteration space becomes an `scf.while`
+///   co-iteration loop (see `genCoIteration`). Its body dispatches to all of
+///   its sub-cases through a nest of `scf.if`.
+/// - A case defined on a single space becomes a plain loop over that space.
+///
+/// Only 1-D coiteration without all-dense or complement cases is supported for
+/// now. Other forms are rejected with a match failure before any IR is
+/// modified.
+class SparseCoIterateOpConverter
+ : public OneToNOpConversionPattern<CoIterateOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ if (op.getSpaceDim() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "coiteration over multi-dimensional spaces is not implemented");
+ Location loc = op.getLoc();
+
+ I64BitSet denseBits(0);
+ for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
+ if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
+ denseBits.set(idx);
+
+ // A universal index is needed to forward the coiteration loop in two
+ // situations: a case that only contains dense spaces (its case bits are a
+ // subset of the dense bits), or a fully empty case (due to complements).
+ bool needUniv =
+ any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
+ // A case for complement.
+ if (caseBits.count() == 0)
+ return true;
+ // An all-dense case.
+ return caseBits.isSubSetOf(denseBits);
+ });
+ if (needUniv)
+ return rewriter.notifyMatchFailure(
+ op, "coiteration requiring a universal index is not implemented");
+
+ for (Region &region : op.getCaseRegions()) {
+ // Do a one-shot type conversion on all region blocks, since the same
+ // region might be used multiple times.
+ Block *block = &region.getBlocks().front();
+ OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+ blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert coiterate region argurment types");
+
+ rewriter.applySignatureConversion(block, blockTypeMapping);
+ }
+
+ SmallVector<SparseIterationSpace> spaces;
+ SmallVector<std::unique_ptr<SparseIterator>> iters;
+ for (auto [spaceTp, spaceVals] : llvm::zip_equal(
+ op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
+ // TODO: do we really need tid?
+ spaces.push_back(SparseIterationSpace::fromValues(
+ cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
+ // Extract the iterator.
+ iters.push_back(spaces.back().extractIterator(rewriter, loc));
+ }
+
+ auto getFilteredIters = [&iters](I64BitSet caseBits) {
+ // Retrieves a vector of pointers to the iterators used in the case.
+ SmallVector<SparseIterator *> validIters;
+ for (auto idx : caseBits.bits())
+ validIters.push_back(iters[idx].get());
+ return validIters;
+ };
+
+ // Get a flattened user-provided loop reduction values.
+ SmallVector<Value> userReduc;
+ for (ValueRange r : adaptor.getInitArgs())
+ llvm::append_range(userReduc, r);
+
+ // TODO: we need to sort the cases such that they appear in lexical order.
+ // Although sparsification always generates cases in that order, it might
+ // not be the case for human-written code.
+
+ // Generates a loop sequence, one loop per case.
+ for (auto [r, caseBits] :
+ llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+ assert(caseBits.count() > 0 && "Complement space not implemented");
+
+ // Retrieves a vector of pointers to the iterators used in the case.
+ SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
+
+ if (validIters.size() > 1) {
+ auto [loop, loopCrd] =
+ genCoIteration(rewriter, loc, validIters, userReduc,
+ /*uniIdx=*/nullptr, /*userReducFirst=*/true);
+
+ // 1st. Find all the cases whose spaces are a subset of the current
+ // case's spaces. This includes the current case itself, so the list is
+ // never empty. One branch per case is generated inside the loop.
+ // TODO: these cases should be sorted.
+ SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+ assert(!subCases.empty());
+
+ ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
+ iters, subCases, userReduc);
+
+ SmallVector<Value> nextIterYields(res);
+ // 2nd. Forward the loop.
+ for (SparseIterator *it : validIters) {
+ Value cmp = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
+ it->forwardIf(rewriter, loc, cmp);
+ llvm::append_range(nextIterYields, it->getCursor());
+ }
+ rewriter.create<scf::YieldOp>(loc, nextIterYields);
+
+ // Exit the loop, relink the iterator SSA value.
+ rewriter.setInsertionPointAfter(loop);
+ ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
+ for (SparseIterator *it : validIters)
+ iterVals = it->linkNewScope(iterVals);
+ assert(iterVals.empty());
+
+ ValueRange curResult = loop->getResults().take_front(userReduc.size());
+ userReduc.assign(curResult.begin(), curResult.end());
+ } else {
+ // This is a simple iteration loop.
+ assert(caseBits.count() == 1);
+
+ Block *block = &r.getBlocks().front();
+ ValueRange curResult = genLoopWithIterator(
+ rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+ /*bodyBuilder=*/
+ [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
+ SparseIterator *it,
+ ValueRange reduc) -> SmallVector<Value> {
+ SmallVector<Value> blockArgs(reduc);
+ blockArgs.push_back(it->deref(rewriter, loc));
+ llvm::append_range(blockArgs, it->getCursor());
+
+ Block *dstBlock = &dstRegion.getBlocks().front();
+ rewriter.inlineBlockBefore(
+ block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+ SmallVector<Value> result(yield.getResults());
+ rewriter.eraseOp(yield);
+ return result;
+ });
+
+ userReduc.assign(curResult.begin(), curResult.end());
+ }
+ }
+
+ rewriter.replaceOp(op, userReduc);
+ return success();
+ }
+};
+
} // namespace
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
- SparseIterateOpConverter>(converter, patterns.getContext());
+ SparseIterateOpConverter, SparseCoIterateOpConverter>(
+ converter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index efb3295fb2a4bf..cb5874ff45068e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -524,84 +524,8 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, bool needsUniv) {
- // NOTE: the slice driven tensor-related reduction variable must
- // appear before normal tensors.
-
- // The set of induction variables for the while loop.
- SmallVector<Value> ivs;
-
- // Construct the while-loop with a parameter for each coordinate.
- for (SparseIterator *it : spIters) {
- ValueRange itVals = it->getCursor();
- ivs.append(itVals.begin(), itVals.end());
- }
-
- // The position where user-supplied reduction variable starts.
- ivs.append(reduc.begin(), reduc.end());
- // Update universal index.
- if (needsUniv)
- ivs.push_back(loopSeqStack.back().first);
-
- // Ensures all operands are valid.
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
-
- SmallVector<Location> locs(types.size(), loc);
- Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
- Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
- // Generates loop conditions.
- builder.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- Value whileCond = nullptr; // bool values for loop condition.
-
- for (SparseIterator *it : spIters) {
- auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
- whileCond = !whileCond ? cond : ANDI(whileCond, cond);
- bArgs = remArgs;
- }
- // The remaining block arguments are user-provided reduction values and an
- // optional universal index. Make sure their sizes match.
- assert(bArgs.size() == reduc.size() + needsUniv);
- builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
- // Generates loop body.
- builder.setInsertionPointToStart(after);
- ValueRange aArgs = after->getArguments();
- // Since some LoopCondKind might need extra checks to filter out invalid
- // iterations, we maintains another array to hold the iteration arguments to
- // yield if the checks fails.
- SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
-
- for (SparseIterator *it : spIters) {
- aArgs = it->linkNewScope(aArgs);
- // Dereference the iterator to cache the coordinate.
- it->deref(builder, loc);
- }
-
- // In-place update on reduction variable.
- assert(aArgs.size() == reduc.size() + needsUniv);
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = aArgs[i];
-
- Value min;
- // Finds the minimum coordinate
- if (!needsUniv) {
- for (SparseIterator *it : spIters) {
- if (min) {
- Value cmp = CMPI(ult, it->getCrd(), min);
- min = SELECT(cmp, it->getCrd(), min);
- } else {
- min = it->getCrd();
- }
- }
- } else {
- // Otherwise, universal index is the minimal pos.
- min = whileOp.getAfterArguments().back();
- }
-
- return {whileOp, min};
+ // Delegate to the shared co-iteration helper. `userReducFirst` is left at
+ // its default (false), so the user-provided reduction values follow the
+ // iterator cursors. When `needsUniv` is set, the universal index of the
+ // current loop sequence is passed in and appended as the last loop value.
+ return genCoIteration(builder, loc, spIters, reduc,
+ needsUniv ? loopSeqStack.back().first : nullptr);
}
bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
@@ -972,6 +896,100 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
loopStack.pop_back();
}
+//===----------------------------------------------------------------------===//
+// Loop generation utils
+//===----------------------------------------------------------------------===//
+
+/// Generates an `scf.while` loop that co-iterates over `spIters` for as long
+/// as every iterator's loop condition holds. The loop carries, in order:
+/// - the user-provided values `reduc`, if `userReducFirst` is set;
+/// - every iterator cursor;
+/// - `reduc`, if `userReducFirst` is not set;
+/// - the optional universal index `uniIdx`.
+///
+/// On return:
+/// - the builder's insertion point is inside the loop body;
+/// - every iterator is linked to, and dereferenced at, its in-loop cursor;
+/// - `reduc` is updated in place to the matching body block arguments.
+/// The caller must terminate the body with an `scf.yield` that uses the same
+/// value order. Returns the while operation and the loop coordinate, which is
+/// the minimum iterator coordinate, or the universal index when provided.
+std::pair<Operation *, Value> sparse_tensor::genCoIteration(
+ OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
+ MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
+ // NOTE: the slice driven tensor-related reduction variable must
+ // appear before normal tensors.
+
+ // The set of induction variables for the while loop.
+ SmallVector<Value> ivs;
+
+ // TODO: remove the flag after full migration. Currently
+ // `sparse_tensor.coiterate` operation (must) put user provided reduction
+ // values at the front of the block list, while direct sparsification to scf
+ // loops put them at the end.
+ if (userReducFirst)
+ ivs.append(reduc.begin(), reduc.end());
+
+ // Construct the while-loop with a parameter for each coordinate.
+ for (SparseIterator *it : spIters) {
+ ValueRange itVals = it->getCursor();
+ ivs.append(itVals.begin(), itVals.end());
+ }
+
+ if (!userReducFirst)
+ ivs.append(reduc.begin(), reduc.end());
+
+ // Update universal index.
+ if (uniIdx)
+ ivs.push_back(uniIdx);
+
+ // Ensures all operands are valid.
+ assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+ TypeRange types = ValueRange(ivs).getTypes();
+ auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
+
+ SmallVector<Location> locs(types.size(), loc);
+ Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+ Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+
+ // The number of leading block arguments that hold user-provided values; the
+ // iterator cursors start right after them.
+ const size_t numLeadingReduc = userReducFirst ? reduc.size() : 0;
+
+ // Generates loop conditions.
+ builder.setInsertionPointToStart(before);
+ ValueRange bArgs = before->getArguments();
+ // Skip the leading user-provided values (if any) so that each iterator
+ // consumes exactly its own cursor values.
+ bArgs = bArgs.drop_front(numLeadingReduc);
+ Value whileCond = nullptr; // bool values for loop condition.
+
+ for (SparseIterator *it : spIters) {
+ auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
+ whileCond = !whileCond ? cond : ANDI(whileCond, cond);
+ bArgs = remArgs;
+ }
+ // The remaining block arguments are the trailing user-provided reduction
+ // values (when they were not placed first) and an optional universal index.
+ // Make sure their sizes match.
+ assert(bArgs.size() == reduc.size() - numLeadingReduc + (uniIdx ? 1 : 0));
+ builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+ // Generates loop body.
+ builder.setInsertionPointToStart(after);
+ ValueRange aArgs = after->getArguments();
+ // Peel off the leading user-provided values, if any.
+ ValueRange userArgs = aArgs.take_front(numLeadingReduc);
+ aArgs = aArgs.drop_front(numLeadingReduc);
+
+ for (SparseIterator *it : spIters) {
+ aArgs = it->linkNewScope(aArgs);
+ // Dereference the iterator to cache the coordinate.
+ it->deref(builder, loc);
+ }
+ // Otherwise, the user-provided values directly follow the cursors.
+ if (!userReducFirst)
+ userArgs = aArgs.take_front(reduc.size());
+
+ // In-place update on reduction variable.
+ assert(userArgs.size() == reduc.size());
+ for (unsigned i = 0, e = reduc.size(); i < e; i++)
+ reduc[i] = userArgs[i];
+
+ Value min;
+ // Finds the minimum coordinate
+ if (!uniIdx) {
+ for (SparseIterator *it : spIters) {
+ if (min) {
+ Value cmp = CMPI(ult, it->getCrd(), min);
+ min = SELECT(cmp, it->getCrd(), min);
+ } else {
+ min = it->getCrd();
+ }
+ }
+ } else {
+ // Otherwise, universal index is the minimal pos.
+ min = whileOp.getAfterArguments().back();
+ }
+
+ return {whileOp, min};
+}
+
#undef CMPI
#undef C_IDX
#undef YIELD
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index a9eb888c8b6bec..3e61b5f27fcc2a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -436,6 +436,17 @@ class LoopEmitter {
std::vector<std::vector<Value>> spIterVals;
};
+//
+// Utility functions to generate sparse loops.
+//
+
+// Generates an `scf.while` loop that co-iterates over `iters`.
+//
+// The loop carries three groups of values:
+// - the cursors of all iterators;
+// - the user-provided values `reduc`, placed before the cursors when
+//   `userReducFirst` is true and after them otherwise;
+// - when non-null, the universal index `uniIdx`, as the last value.
+//
+// On return, the builder's insertion point is inside the loop body, and
+// `reduc` has been updated in place with loop-body block arguments. Returns
+// the generated loop and the loop coordinate. The coordinate is the minimum
+// coordinate among `iters`, or the universal index when `uniIdx` is given.
+std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
+ ArrayRef<SparseIterator *> iters,
+ MutableArrayRef<Value> reduc,
+ Value uniIdx,
+ bool userReducFirst = false);
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 91f363db93f1df..642cb1afa156b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -95,6 +95,8 @@ enum class IterKind : uint8_t {
class SparseIterationSpace {
public:
SparseIterationSpace() = default;
+ // Move-only: copies are disallowed, moves (e.g. when held in a
+ // `SmallVector`) are allowed.
+ SparseIterationSpace(const SparseIterationSpace &) = delete;
+ SparseIterationSpace(SparseIterationSpace &&) = default;
// Constructs a N-D iteration space.
SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index 2487156a9a2e48..f819458e038582 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,7 +1,5 @@
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
-
-// TODO: temporarilly disabled since there is no lowering rules from `coiterate` to `scf`.
-// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion -cse --canonicalize | FileCheck %s
@@ -79,6 +77,79 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
// ITER: bufferization.to_tensor
// ITER: return
// ITER: }
+
+// CHECK-LABEL: func.func @add(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : memref<10xi32>
+// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
+// CHECK: %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
+// CHECK: %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
+// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
+// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
+// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
+// CHECK: %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
+// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
+// CHECK: scf.if %[[VAL_29]] {
+// CHECK: %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
+// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
+// CHECK: memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_27]] {
+// CHECK: %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_28]] {
+// CHECK: %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
+// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
+// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
+// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
+// CHECK: scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
+// CHECK: }
+// CHECK: %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
+// CHECK: }
+// CHECK: %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
+// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
+// CHECK: }
+// CHECK: %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
+// CHECK: return %[[VAL_53]] : tensor<10xi32>
+// CHECK: }
func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
%cst = arith.constant dense<0> : tensor<10xi32>
%0 = linalg.generic {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
similarity index 63%
rename from mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
rename to mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
index 6d03565f8f7b2a..6cca4fa86a162e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir
@@ -35,9 +35,13 @@
explicitVal = 1 : i32
}>
-// An example of vector reductions.
-module {
+#VEC = #sparse_tensor.encoding<{
+ map = (d0) -> (d0 : compressed)
+}>
+
+module {
+ // An example of vector reductions (lowered through sparse_tensor.iterate).
func.func @sqsum(%arg0: tensor<2x3x4x5xi32, #COO>) -> tensor<i32> {
%cst = arith.constant dense<0> : tensor<i32>
%0 = linalg.generic {
@@ -55,7 +59,30 @@ module {
return %0 : tensor<i32>
}
+ // An example of vector addition (lowered through sparse_tensor.coiterate).
+ func.func @vec_add(%arg0: tensor<4xi32, #VEC>, %arg1: tensor<4xi32, #VEC>) -> tensor<4xi32> {
+ %cst = arith.constant dense<0> : tensor<4xi32>
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>
+ ],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0, %arg1 : tensor<4xi32, #VEC>, tensor<4xi32, #VEC>)
+ outs(%cst : tensor<4xi32>) {
+ ^bb0(%in1: i32, %in2: i32, %out: i32):
+ %2 = arith.addi %in1, %in2 : i32
+ linalg.yield %2 : i32
+ } -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+ }
+
func.func @main() {
+ %c0 = arith.constant 0 : index
+ %i0 = arith.constant 0 : i32
+
%cst = arith.constant sparse<
[
[0, 1, 2, 3],
@@ -66,15 +93,33 @@ module {
[1, 1, 1, 1]
> : tensor<2x3x4x5xi32>
+ %l = arith.constant dense<
+ [0, 1, 2, 3]
+ > : tensor<4xi32>
+ %r = arith.constant dense<
+ [1, 0, 3, 0]
+ > : tensor<4xi32>
+
%input = sparse_tensor.convert %cst : tensor<2x3x4x5xi32> to tensor<2x3x4x5xi32, #COO>
%0 = call @sqsum(%input) : (tensor<2x3x4x5xi32, #COO>) -> tensor<i32>
%v = tensor.extract %0[] : tensor<i32>
+ %lhs = sparse_tensor.convert %l : tensor<4xi32> to tensor<4xi32, #VEC>
+ %rhs = sparse_tensor.convert %r : tensor<4xi32> to tensor<4xi32, #VEC>
+ %add = call @vec_add(%lhs, %rhs) : (tensor<4xi32, #VEC>, tensor<4xi32, #VEC>) -> tensor<4xi32>
+
// CHECK: 4
vector.print %v : i32
+ // CHECK-NEXT: ( 1, 1, 5, 3 )
+ %vec = vector.transfer_read %add[%c0], %i0 : tensor<4xi32>, vector<4xi32>
+ vector.print %vec : vector<4xi32>
bufferization.dealloc_tensor %input : tensor<2x3x4x5xi32, #COO>
bufferization.dealloc_tensor %0 : tensor<i32>
+
+ bufferization.dealloc_tensor %lhs : tensor<4xi32, #VEC>
+ bufferization.dealloc_tensor %rhs : tensor<4xi32, #VEC>
+ bufferization.dealloc_tensor %add : tensor<4xi32>
return
}
}
>From 5d73f2302ef11cf32848f72c79e0d0a611f29a9c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 15 Aug 2024 18:10:25 +0000
Subject: [PATCH 2/2] [mlir][sparse] refactoring sparse_tensor.iterate lowering
pattern implementation.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch: users/PeimingLiu/stack/2
---
.../Transforms/SparseIterationToScf.cpp | 118 ++++++------------
1 file changed, 36 insertions(+), 82 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);
- if (it->iteratableByFor()) {
- auto [lo, hi] = it->genForCond(rewriter, loc);
- Value step = constantIndex(rewriter, loc, 1);
- SmallVector<Value> ivs;
- for (ValueRange inits : adaptor.getInitArgs())
- llvm::append_range(ivs, inits);
- scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
-
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
- return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- rewriter.eraseBlock(forOp.getBody());
- Region &dstRegion = forOp.getRegion();
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
- auto yieldOp =
- llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(forOp.getBody());
- // replace sparse_tensor.yield with scf.yield.
- rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
- rewriter.eraseOp(yieldOp);
-
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, forOp.getResults(), resultMapping);
- } else {
- SmallVector<Value> ivs;
- // TODO: put iterator at the end of argument list to be consistent with
- // coiterate operation.
- llvm::append_range(ivs, it->getCursor());
- for (ValueRange inits : adaptor.getInitArgs())
- llvm::append_range(ivs, inits);
-
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
- SmallVector<Location> l(types.size(), op.getIterator().getLoc());
-
- // Generates loop conditions.
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
- rewriter.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
- assert(remArgs.size() == adaptor.getInitArgs().size());
- rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
- // Generates loop body.
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
- return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- Region &dstRegion = whileOp.getAfter();
- // TODO: handle uses of coordinate!
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
- ValueRange aArgs = whileOp.getAfterArguments();
- auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
- whileOp.getAfterBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+ SmallVector<Value> ivs;
+ for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+
+ // Type conversion on iterate op block.
+ OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(
+ op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert iterate region argurment types");
+ rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+ Block *block = op.getBody();
+ ValueRange ret = genLoopWithIterator(
+ rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+ [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+ SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
+ SmallVector<Value> blockArgs(it->getCursor());
+        // TODO: Also append coordinates if used.
+ // blockArgs.push_back(it->deref(rewriter, loc));
+ llvm::append_range(blockArgs, reduc);
+
+ Block *dstBlock = &loopBody.getBlocks().front();
+ rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
+ blockArgs);
+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+        // We cannot use ValueRange, as the operation holding the values will
+        // be destroyed.
+ SmallVector<Value> result(yield.getResults());
+ rewriter.eraseOp(yield);
+ return result;
+ });
- aArgs = it->linkNewScope(aArgs);
- ValueRange nx = it->forward(rewriter, loc);
- SmallVector<Value> yields;
- llvm::append_range(yields, nx);
- llvm::append_range(yields, yieldOp.getResults());
-
- // replace sparse_tensor.yield with scf.yield.
- rewriter.eraseOp(yieldOp);
- rewriter.create<scf::YieldOp>(loc, yields);
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(
- op, whileOp.getResults().drop_front(it->getCursor().size()),
- resultMapping);
- }
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+ rewriter.replaceOp(op, ret, resultMapping);
return success();
}
};
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
     Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
- blockTypeMapping)))
+ blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
+ }
rewriter.applySignatureConversion(block, blockTypeMapping);
}
More information about the Mlir-commits
mailing list