[Mlir-commits] [mlir] [mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for. (PR #105565)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 21 11:30:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
Stacked PRs:
* #<!-- -->105567
* #<!-- -->105566
* __->__#<!-- -->105565
--- --- ---
### [mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for.
---
Patch is 35.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105565.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+15-7)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+4)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+10)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+290-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+96-78)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+11)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (+74-3)
- (renamed) mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-kernel.mlir (+47-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 388efd1c454b1e..915a0cd8d92973 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -96,24 +96,32 @@ class I64BitSet {
return *this;
}
+ bool isSubSetOf(const I64BitSet p) const {
+ I64BitSet tmp = *this;
+ tmp |= p;
+ return tmp == p;
+ }
+
// Needed by `llvm::const_set_bits_iterator_impl`.
int find_first() const { return min(); }
int find_next(unsigned prev) const {
- if (prev >= max())
+ if (prev >= max() - 1)
return -1;
- uint64_t b = storage >> (prev + 1);
- if (b == 0)
- return -1;
+ uint64_t b = storage >> (prev + 1ULL);
+ assert(b != 0);
- return llvm::countr_zero(b) + prev + 1;
+ return llvm::countr_zero(b) + prev + 1ULL;
}
bool operator[](unsigned i) const {
assert(i < 64);
- return (storage & (1 << i)) != 0;
+ return (storage & (static_cast<int64_t>(1) << i)) != 0;
+ }
+ unsigned min() const {
+ unsigned m = llvm::countr_zero(storage);
+ return m == 64 ? -1 : m;
}
- unsigned min() const { return llvm::countr_zero(storage); }
unsigned max() const { return 64 - llvm::countl_zero(storage); }
unsigned count() const { return llvm::popcount(storage); }
bool empty() const { return storage == 0; }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2803223354d5ee..20512f972e67cd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
.take_back(getRegionDefinedSpace(regionIdx).count());
}
ValueRange getYieldedValues(unsigned regionIdx);
+
+ // Returns a vector of regions that are the `sub-cases` of the given case region.
+ // E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
+ SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a143189c301a43..16856b958d4f13 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
return success();
}
+SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
+ SmallVector<Region *> ret;
+ I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
+ for (Region &r : getCaseRegions())
+ if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
+ ret.push_back(&r);
+
+ return ret;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index b1451dee738ac3..d6c0da4a9e4573 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -1,5 +1,6 @@
#include "Utils/CodegenUtils.h"
+#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
return success();
}
+static ValueRange
+genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
+ Value loopCrd,
+ ArrayRef<std::unique_ptr<SparseIterator>> iters,
+ ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
+ if (subCases.empty())
+ return userReduc;
+
+ // The current branch that we are handling.
+ Region *b = subCases.front();
+ Value casePred = constantI1(rewriter, loc, true);
+ I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+ for (unsigned i : caseBits.bits()) {
+ SparseIterator *it = iters[i].get();
+ Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ it->getCrd(), loopCrd);
+ casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
+ }
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+ loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+ // Erase the empty block.
+ rewriter.eraseBlock(&ifOp.getThenRegion().front());
+ // Set up block arguments: user-provided values -> loop coord -> iterators.
+ SmallVector<Value> blockArgs(userReduc);
+ blockArgs.push_back(loopCrd);
+ for (unsigned idx : caseBits.bits())
+ llvm::append_range(blockArgs, iters[idx]->getCursor());
+
+ IRMapping mapping;
+ for (auto [from, to] :
+ llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+ mapping.map(from, to);
+ }
+
+  // Clone the region; we cannot erase the region now because the same region
+  // might be a subcase for multiple lattice points.
+ rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+ ifOp.getThenRegion().begin(), mapping);
+
+ // replace sparse_tensor::YieldOp -> scf::YieldOp
+ auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
+ ValueRange yields = spY.getResults();
+ rewriter.eraseOp(spY);
+ rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+ rewriter.create<scf::YieldOp>(loc, yields);
+
+ // Generates remaining case recursively.
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
+ subCases.drop_front(), userReduc);
+ if (!res.empty())
+ rewriter.create<scf::YieldOp>(loc, res);
+
+ rewriter.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+}
+
+static ValueRange genLoopWithIterator(
+ PatternRewriter &rewriter, Location loc, SparseIterator *it,
+ ValueRange reduc, bool iterFirst,
+ function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
+ Region &loopBody, SparseIterator *it,
+ ValueRange reduc)>
+ bodyBuilder) {
+ if (it->iteratableByFor()) {
+ auto [lo, hi] = it->genForCond(rewriter, loc);
+ Value step = constantIndex(rewriter, loc, 1);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Erase the implicit yield operation created by ForOp when there is no
+ // yielding values.
+ if (!forOp.getBody()->empty())
+ rewriter.eraseOp(&forOp.getBody()->front());
+ assert(forOp.getBody()->empty());
+
+ it->linkNewScope(forOp.getInductionVar());
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
+ it, forOp.getRegionIterArgs());
+
+ rewriter.setInsertionPointToEnd(forOp.getBody());
+ rewriter.create<scf::YieldOp>(loc, ret);
+ }
+ return forOp.getResults();
+ }
+ SmallVector<Value> ivs;
+ // TODO: always put iterator SSA values at the end of argument list to be
+ // consistent with coiterate operation.
+ if (!iterFirst)
+ llvm::append_range(ivs, it->getCursor());
+ // Appends the user-provided values.
+ llvm::append_range(ivs, reduc);
+ if (iterFirst)
+ llvm::append_range(ivs, it->getCursor());
+
+ TypeRange types = ValueRange(ivs).getTypes();
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Generates loop conditions.
+ SmallVector<Location> l(types.size(), loc);
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+ rewriter.setInsertionPointToStart(before);
+ ValueRange bArgs = before->getArguments();
+ auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+ rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+ // Delegates loop body generation.
+ Region &dstRegion = whileOp.getAfter();
+ Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
+ ValueRange aArgs = whileOp.getAfterArguments();
+ if (iterFirst) {
+ aArgs = it->linkNewScope(aArgs);
+ } else {
+ aArgs = aArgs.take_front(reduc.size());
+ it->linkNewScope(aArgs.drop_front(reduc.size()));
+ }
+
+ rewriter.setInsertionPointToStart(after);
+ SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
+ rewriter.setInsertionPointToEnd(after);
+
+ // Forward loops
+ SmallVector<Value> yields;
+ ValueRange nx = it->forward(rewriter, loc);
+ if (iterFirst)
+ llvm::append_range(yields, nx);
+ llvm::append_range(yields, ret);
+ if (!iterFirst)
+ llvm::append_range(yields, nx);
+ rewriter.create<scf::YieldOp>(loc, yields);
+ }
+ return whileOp.getResults().drop_front(it->getCursor().size());
+}
+
namespace {
/// Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
+ // TODO: put iterator at the end of argument list to be consistent with
+ // coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
}
};
+class SparseCoIterateOpConverter
+ : public OneToNOpConversionPattern<CoIterateOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ assert(op.getSpaceDim() == 1 && "Not implemented");
+ Location loc = op.getLoc();
+
+ I64BitSet denseBits(0);
+ for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
+ if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
+ denseBits.set(idx);
+
+    // If there exists a case that only contains dense spaces, i.e., the case
+    // bits are a subset of the dense bits, or when there is a fully empty case
+    // (due to complements), we need a universal pointer to forward the
+    // coiteration loop.
+ // loop.
+ bool needUniv =
+ any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
+ // A case for complement.
+ if (caseBits.count() == 0)
+ return true;
+ // An all-dense case.
+ return caseBits.isSubSetOf(denseBits);
+ });
+ assert(!needUniv && "Not implemented");
+ (void)needUniv;
+
+ for (Region ®ion : op.getCaseRegions()) {
+      // Do a one-shot type conversion on all region blocks, since the same
+      // region might be used multiple times.
+ Block *block = ®ion.getBlocks().front();
+ OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+ blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+            op, "failed to convert coiterate region argument types");
+
+ rewriter.applySignatureConversion(block, blockTypeMapping);
+ }
+
+ SmallVector<SparseIterationSpace> spaces;
+ SmallVector<std::unique_ptr<SparseIterator>> iters;
+ for (auto [spaceTp, spaceVals] : llvm::zip_equal(
+ op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
+ // TODO: do we really need tid?
+ spaces.push_back(SparseIterationSpace::fromValues(
+ cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
+ // Extract the iterator.
+ iters.push_back(spaces.back().extractIterator(rewriter, loc));
+ }
+
+ auto getFilteredIters = [&iters](I64BitSet caseBits) {
+      // Retrieves a vector of pointers to the iterators used in the case.
+ SmallVector<SparseIterator *> validIters;
+ for (auto idx : caseBits.bits())
+ validIters.push_back(iters[idx].get());
+ return validIters;
+ };
+
+ // Get a flattened user-provided loop reduction values.
+ SmallVector<Value> userReduc;
+ for (ValueRange r : adaptor.getInitArgs())
+ llvm::append_range(userReduc, r);
+
+    // TODO: we need to sort the cases such that they appear in lexical order.
+ // Although sparsification always generates cases in that order, it might
+ // not be the case for human-written code.
+
+ // Generates a loop sequence, one loop per case.
+ for (auto [r, caseBits] :
+ llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+ assert(caseBits.count() > 0 && "Complement space not implemented");
+
+      // Retrieves a vector of pointers to the iterators used in the case.
+ SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
+
+ if (validIters.size() > 1) {
+ auto [loop, loopCrd] =
+ genCoIteration(rewriter, loc, validIters, userReduc,
+ /*uniIdx=*/nullptr, /*userReducFirst=*/true);
+
+        // 1st. find all the cases that are a strict subset of the current case
+        // condition, for which we generate one branch per case inside the loop.
+        // The subcases are never empty; they must contain at least the current
+        // region itself.
+ // TODO: these cases should be sorted.
+ SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+ assert(!subCases.empty());
+
+ ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
+ iters, subCases, userReduc);
+
+ SmallVector<Value> nextIterYields(res);
+        // 2nd. forward the loop.
+ for (SparseIterator *it : validIters) {
+ Value cmp = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
+ it->forwardIf(rewriter, loc, cmp);
+ llvm::append_range(nextIterYields, it->getCursor());
+ }
+ rewriter.create<scf::YieldOp>(loc, nextIterYields);
+
+ // Exit the loop, relink the iterator SSA value.
+ rewriter.setInsertionPointAfter(loop);
+ ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
+ for (SparseIterator *it : validIters)
+ iterVals = it->linkNewScope(iterVals);
+ assert(iterVals.empty());
+
+ ValueRange curResult = loop->getResults().take_front(userReduc.size());
+ userReduc.assign(curResult.begin(), curResult.end());
+ } else {
+ // This is a simple iteration loop.
+ assert(caseBits.count() == 1);
+
+ Block *block = &r.getBlocks().front();
+ ValueRange curResult = genLoopWithIterator(
+ rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+ /*bodyBuilder=*/
+ [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
+ SparseIterator *it,
+ ValueRange reduc) -> SmallVector<Value> {
+ SmallVector<Value> blockArgs(reduc);
+ blockArgs.push_back(it->deref(rewriter, loc));
+ llvm::append_range(blockArgs, it->getCursor());
+
+ Block *dstBlock = &dstRegion.getBlocks().front();
+ rewriter.inlineBlockBefore(
+ block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+ SmallVector<Value> result(yield.getResults());
+ rewriter.eraseOp(yield);
+ return result;
+ });
+
+ userReduc.assign(curResult.begin(), curResult.end());
+ }
+ }
+
+ rewriter.replaceOp(op, userReduc);
+ return success();
+ }
+};
+
} // namespace
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
- SparseIterateOpConverter>(converter, patterns.getContext());
+ SparseIterateOpConverter, SparseCoIterateOpConverter>(
+ converter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index efb3295fb2a4bf..cb5874ff45068e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -524,84 +524,8 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, bool needsUniv) {
- // NOTE: the slice driven tensor-related reduction variable must
- // appear before normal tensors.
-
- // The set of induction variables for the while loop.
- SmallVector<Value> ivs;
-
- // Construct the while-loop with a parameter for each coordinate.
- for (SparseIterator *it : spIters) {
- ValueRange itVals = it->getCursor();
- ivs.append(itVals.begin(), itVals.end());
- }
-
- // The position where user-supplied reduction variable starts.
- ivs.append(reduc.begin(), reduc.end());
- // Update universal index.
- if (needsUniv)
- ivs.push_back(loopSeqStack.back().first);
-
- // Ensures all operands are valid.
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
-
- SmallVector<Location> locs(types.size(), loc);
- Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
- Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
- // Generates loop conditions.
- builder.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- Value whileCond = nullptr; // bool values for loop condition.
-
- for (SparseIterator *it : spIters) {
- auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
- whileCond = !whileCond ? cond : ANDI(whileCond, cond);
- bArgs = remArgs;
- }
- // The remaining block arguments are user-provided reduction values and an
- // optional universal index. Make sure their sizes match.
- assert(bArgs.size() == reduc.size() + needsUniv);
- builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
- // Generates loop body.
- builder.setInsertionPointToStart(after);
- ValueRange aArgs = after->getArguments();
- // Since some LoopCondKind might need extra checks to filter out invalid
- // iterations, we maintains another array to hold the iteration arguments to
- // yield if the checks fails.
- SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
-
- for (SparseIterator *it : spIters) {
- aArgs = it->linkNewScope(aArgs);
- // Dereference the iterator to cache the coordinate.
- it->deref(builder, loc);
- }
-
- // In-place update on reduction variable.
- assert(aArgs.size() == reduc.size() + needsUniv);
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = aArgs[i];
-
- Value min;
- // Finds the minimum coordinate
- if (!needsUniv) {
- for (SparseIterator *it : spIters) {
- if (min) {
- Value cmp = CMPI(ult, it->getCrd(), min);
- min = SELECT(cmp, it->getCrd(), min);
- } else {
- min = it->getCrd();
- }
- }
- } else {
- // Otherwise, universal index is the minimal pos.
- min = whileOp.getAfterAr...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/105565
More information about the Mlir-commits
mailing list