[Mlir-commits] [mlir] [mlir][sparse_tensor] Migrate `SparseIterationToScf.cpp` to dialect conversion (PR #121054)
Matthias Springer
llvmlistbot at llvm.org
Tue Dec 24 06:07:59 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/121054
>From 8547cba1d26e2699874f4688cbf5ef340a09d71b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 20 Dec 2024 18:13:00 +0100
Subject: [PATCH] sparse iteration
---
.../Transforms/SparseIterationToScf.cpp | 123 +++++++++++-------
.../Transforms/SparseTensorPasses.cpp | 11 +-
2 files changed, 81 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index e8a40b1e033dd5b..9e9fea76416b9ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -7,11 +7,17 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+ assert(values.size() == 1 && "expected single value");
+ return values.front();
+}
+
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
Value loopCrd,
ArrayRef<std::unique_ptr<SparseIterator>> iters,
- ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
- if (subCases.empty())
+ ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
+ ArrayRef<Value> userReduc) {
+ if (newBlocks.empty())
return userReduc;
// The current branch that we are handling.
- Region *b = subCases.front();
+ Block *newBlock = newBlocks.front();
+ Block *oldBlock = oldBlocks.front();
Value casePred = constantI1(rewriter, loc, true);
- I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+ I64BitSet caseBits =
+ op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
for (unsigned i : caseBits.bits()) {
SparseIterator *it = iters[i].get();
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
for (unsigned idx : caseBits.bits())
llvm::append_range(blockArgs, iters[idx]->getCursor());
+ // Map the old block arguments, because the dialect conversion driver does
+ // not immediately perform SSA value replacements. This function is still
+ // seeing the old uses.
IRMapping mapping;
- for (auto [from, to] :
- llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+ for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
mapping.map(from, to);
}
// Clone the region, we can not erase the region now because the same region
// might be a subcase for multiple lattice point.
- rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+ rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
ifOp.getThenRegion().begin(), mapping);
+ // Remove the block arguments, they were already replaced via `mapping`.
+ ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
// replace sparse_tensor::YieldOp -> scf::YieldOp
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
// Generates remaining case recursively.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
- subCases.drop_front(), userReduc);
+ newBlocks.drop_front(),
+ oldBlocks.drop_front(), userReduc);
if (!res.empty())
rewriter.create<scf::YieldOp>(loc, res);
@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
- scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(
+ loc, lo, hi, step, reduc,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+ // Empty builder function to ensure that no terminator is created.
+ });
{
OpBuilder::InsertionGuard guard(rewriter);
- // Erase the implicit yield operation created by ForOp when there is no
- // yielding values.
- if (!forOp.getBody()->empty())
- rewriter.eraseOp(&forOp.getBody()->front());
- assert(forOp.getBody()->empty());
-
it->linkNewScope(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
@@ -178,46 +190,47 @@ namespace {
/// Sparse codegen rule for number of entries operator.
class ExtractIterSpaceConverter
- : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+ : public OpConversionPattern<ExtractIterSpaceOp> {
public:
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
// Construct the iteration space.
- SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+ SparseIterationSpace space(loc, rewriter,
+ getSingleValue(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());
SmallVector<Value> result = space.toValues();
- rewriter.replaceOp(op, result, resultMapping);
+ rewriter.replaceOpWithMultiple(op, {result});
return success();
}
};
/// Sparse codegen rule for number of entries operator.
-class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
- Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+ Value valBuf =
+ rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
};
-class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
public:
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(IterateOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (!op.getCrdUsedLvls().empty())
return rewriter.notifyMatchFailure(
op, "non-empty coordinates list not implemented.");
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
llvm::append_range(ivs, inits);
// Type conversion on iterate op block.
- OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+ unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
+ TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
if (failed(typeConverter->convertSignatureArgs(
- op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ op.getBody()->getArgumentTypes(), signatureConversion)))
return rewriter.notifyMatchFailure(
op, "failed to convert iterate region argurment types");
- rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
- Block *block = op.getBody();
+ Block *block = rewriter.applySignatureConversion(
+ op.getBody(), signatureConversion, getTypeConverter());
ValueRange ret = genLoopWithIterator(
rewriter, loc, it.get(), ivs,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
return result;
});
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, ret, resultMapping);
+ rewriter.replaceOp(op, ret);
return success();
}
};
-class SparseCoIterateOpConverter
- : public OneToNOpConversionPattern<CoIterateOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
assert(op.getSpaceDim() == 1 && "Not implemented");
Location loc = op.getLoc();
@@ -299,18 +311,23 @@ class SparseCoIterateOpConverter
assert(!needUniv && "Not implemented");
(void)needUniv;
+ SmallVector<Block *> newBlocks;
+ DenseMap<Block *, Block *> newToOldBlockMap;
for (Region &region : op.getCaseRegions()) {
// Do a one-shot type conversion on all region blocks, since the same
// region might be used multiple time.
Block *block = &region.getBlocks().front();
- OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+ TypeConverter::SignatureConversion blockTypeMapping(
+ block->getArgumentTypes().size());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
}
- rewriter.applySignatureConversion(block, blockTypeMapping);
+ newBlocks.push_back(rewriter.applySignatureConversion(
+ block, blockTypeMapping, getTypeConverter()));
+ newToOldBlockMap[newBlocks.back()] = block;
}
SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +360,7 @@ class SparseCoIterateOpConverter
// Generates a loop sequence, one loop per case.
for (auto [r, caseBits] :
- llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+ llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
assert(caseBits.count() > 0 && "Complement space not implemented");
// Retrives a vector of pointers to the iterators used in the case.
@@ -359,11 +376,17 @@ class SparseCoIterateOpConverter
// The subcases are never empty, it must contains at least the current
// region itself.
// TODO: these cases should be sorted.
- SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+ SmallVector<Region *> subCases =
+ op.getSubCasesOf(r->getParent()->getRegionNumber());
+ SmallVector<Block *> newBlocks, oldBlocks;
+ for (Region *r : subCases) {
+ newBlocks.push_back(&r->front());
+ oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
+ }
assert(!subCases.empty());
- ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
- iters, subCases, userReduc);
+ ValueRange res = genCoIterateBranchNest(
+ rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
SmallVector<Value> nextIterYields(res);
// 2nd. foward the loop.
@@ -388,7 +411,7 @@ class SparseCoIterateOpConverter
// This is a simple iteration loop.
assert(caseBits.count() == 1);
- Block *block = &r.getBlocks().front();
+ Block *block = r;
ValueRange curResult = genLoopWithIterator(
rewriter, loc, validIters.front(), userReduc,
/*bodyBuilder=*/
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 1cac949b68c79dc..153b9b170e5d340 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass
ConversionTarget target(*ctx);
// The actual conversion.
- target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+ target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, scf::SCFDialect,
+ sparse_tensor::SparseTensorDialect>();
+ target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
+ IterateOp>();
+ target.addLegalOp<UnrealizedConversionCastOp>();
populateLowerSparseIterationToSCFPatterns(converter, patterns);
- if (failed(applyPartialOneToNConversion(getOperation(), converter,
- std::move(patterns))))
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
More information about the Mlir-commits
mailing list