[Mlir-commits] [mlir] Reapply "[mlir][sparse] implement lowering rules for IterateOp." (PR #95836)
Peiming Liu
llvmlistbot at llvm.org
Mon Jun 17 13:10:46 PDT 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/95836
None
>From b94709bf864d239a22874733edf6e8d109b61a66 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 17 Jun 2024 20:04:28 +0000
Subject: [PATCH 1/2] Reapply "[mlir][sparse] implement lowering rules for
IterateOp." (#95826)
This reverts commit 996905d8152def16ca2fa1322367e493ac6eef5e.
---
.../Transforms/SparseIterationToScf.cpp | 121 +++++++++++++++++-
.../Transforms/Utils/SparseTensorIterator.cpp | 40 ++++++
.../Transforms/Utils/SparseTensorIterator.h | 26 +++-
.../SparseTensor/sparse_iteration_to_scf.mlir | 54 ++++++--
4 files changed, 224 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 62887c75c872b..f57be49f21b8c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
return success();
}
+/// Type-conversion callback for `!sparse_tensor.iterator`: appends one
+/// `index` field per cursor value. A non-unique iterator carries one extra
+/// (segment-high) value in addition to its position.
+static std::optional<LogicalResult>
+convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
+  // TODO: handle batch dimension.
+  assert(itTp.getEncoding().getBatchLvlRank() == 0);
+  Type indexTy = IndexType::get(itTp.getContext());
+  const unsigned numCursorVals = itTp.isUnique() ? 1 : 2;
+  fields.append(numCursorVals, indexTy);
+  return success();
+}
+
namespace {
/// Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
}
};
+/// Lowers `sparse_tensor.iterate` to SCF.
+///
+/// When the extracted iterator can be driven by a plain index range
+/// (`iteratableByFor()`), the loop becomes an `scf.for` over the bounds
+/// produced by `genForCond`. Otherwise it becomes an `scf.while` whose
+/// loop-carried values are the iterator cursor followed by the flattened,
+/// type-converted iteration arguments; the cursor values are dropped from
+/// the results that replace the original op.
+class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (!op.getCrdUsedLvls().empty())
+      return rewriter.notifyMatchFailure(
+          op, "non-empty coordinates list not implemented.");
+
+    // Convert the loop-body signature before creating any IR: a conversion
+    // pattern that returns failure must leave the IR unmodified, and both the
+    // iterator extraction and the loop construction below emit new ops.
+    Block *loopBody = op.getBody();
+    OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(
+            loopBody->getArgumentTypes(), bodyTypeMapping)))
+      return failure();
+
+    Location loc = op.getLoc();
+
+    auto iterSpace = SparseIterationSpace::fromValues(
+        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
+
+    std::unique_ptr<SparseIterator> it =
+        iterSpace.extractIterator(rewriter, loc);
+
+    if (it->iteratableByFor()) {
+      auto [lo, hi] = it->genForCond(rewriter, loc);
+      Value step = constantIndex(rewriter, loc, 1);
+      SmallVector<Value> ivs;
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
+
+      // The converted body replaces the default body of the new scf.for, so
+      // its arguments become the loop's (induction variable, iter_args...).
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      forOp.getBody()->erase();
+      Region &dstRegion = forOp.getRegion();
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+
+      auto yieldOp =
+          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      // replace sparse_tensor.yield with scf.yield.
+      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
+      yieldOp.erase();
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
+    } else {
+      // Loop-carried values: the iterator cursor first, then the flattened
+      // (possibly 1:N converted) iteration arguments.
+      SmallVector<Value> ivs;
+      llvm::append_range(ivs, it->getCursor());
+      const size_t numCursorVals = ivs.size();
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+
+      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+
+      TypeRange types = ValueRange(ivs).getTypes();
+      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
+
+      // Generates loop conditions.
+      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+      rewriter.setInsertionPointToStart(before);
+      ValueRange bArgs = before->getArguments();
+      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+      // Compare against the flattened count: an init arg converted 1:N
+      // contributes N block arguments, not one.
+      assert(remArgs.size() == ivs.size() - numCursorVals);
+      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+      // Generates loop body.
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      Region &dstRegion = whileOp.getAfter();
+      // TODO: handle uses of coordinate!
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
+          whileOp.getAfterBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+
+      // Link the iterator to the "after" block arguments and compute the next
+      // cursor with `forward`; the new cursor values lead the yielded
+      // operands, mirroring the layout of `ivs`.
+      it->linkNewScope(whileOp.getAfterArguments());
+      ValueRange nx = it->forward(rewriter, loc);
+      SmallVector<Value> yields;
+      llvm::append_range(yields, nx);
+      llvm::append_range(yields, yieldOp.getResults());
+
+      // replace sparse_tensor.yield with scf.yield.
+      yieldOp->erase();
+      rewriter.create<scf::YieldOp>(loc, yields);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(op, whileOp.getResults().drop_front(numCursorVals),
+                         resultMapping);
+    }
+    return success();
+  }
+};
+
} // namespace
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
addConversion([](Type type) { return type; });
+ addConversion(convertIteratorType);
addConversion(convertIterSpaceType);
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
void mlir::populateLowerSparseIterationToSCFPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(converter,
+                                                                     ctx);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index be8e15d6ae6f4..ef95fcc84bd90 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
TrivialIterator(const SparseTensorLevel &stl)
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+  /// Creates a trivial iterator bounded by the positions `posLo`/`posHi` and
+  /// seeks it to `posLo` (`b` and `l` are currently unused).
+  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                  Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1),
+        posLo(posLo), posHi(posHi) {
+    // Start at the lower position bound.
+    seek(this->posLo);
+  }
+
std::string getDebugInterfacePrefix() const override {
return std::string("trivial<") + stl.toString() + ">";
}
@@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
: ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
assert(!stl.isUnique());
}
+
+  /// Creates a deduplicating iterator over a non-unique level, bounded above
+  /// by `posHi`, and seeks it to `posLo` paired with the segment high that
+  /// `genSegmentHigh` computes for that position.
+  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2),
+        posHi(posHi) {
+    assert(!stl.isUnique());
+    Value segHi = genSegmentHigh(b, l, posLo);
+    seek({posLo, segHi});
+  }
+
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
return from->kind == IterKind::kDedup;
@@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
return space;
}
+/// Builds the default iterator for this space via `makeSimpleIterator`.
+std::unique_ptr<SparseIterator>
+SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
+  std::unique_ptr<SparseIterator> iter = makeSimpleIterator(b, l, *this);
+  return iter;
+}
+
//===----------------------------------------------------------------------===//
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//
@@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
return std::make_pair(std::move(stl), std::move(it));
}
+/// Creates an iterator over an iteration space. Only the last level of the
+/// space is walked, between the space's low and high position bounds; the
+/// result always uses the functional emit strategy.
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
+                                  const SparseIterationSpace &iterSpace) {
+  const auto &lastLvl = iterSpace.getLastLvl();
+  Value lo = iterSpace.getBoundLo();
+  Value hi = iterSpace.getBoundHi();
+
+  std::unique_ptr<SparseIterator> iter;
+  if (iterSpace.isUnique()) {
+    iter = std::make_unique<TrivialIterator>(b, l, lastLvl, lo, hi);
+  } else {
+    // Non-unique levels are always deduplicated for now; this could be
+    // optimized away when it is known to be unnecessary.
+    iter = std::make_unique<DedupIterator>(b, l, lastLvl, lo, hi);
+  }
+  iter->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
+  return iter;
+}
+
std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
SparseEmitStrategy strategy) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 17636af2b2f9d..91f363db93f1d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -132,6 +132,10 @@ class SparseIterationSpace {
Value getBoundLo() const { return bound.first; }
Value getBoundHi() const { return bound.second; }
+  // Extracts an iterator that traverses this sparse iteration space; any
+  // setup IR it needs is emitted through `b` at location `l`.
+  std::unique_ptr<SparseIterator>
+  extractIterator(OpBuilder &b, Location l) const;
+
private:
SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
std::pair<Value, Value> bound;
@@ -192,6 +196,13 @@ class SparseIterator {
crd = nullptr;
}
+  // Reconstructs an iterator directly from the provided ValueRange.
+  // NOTE(review): no out-of-line definition accompanies this declaration
+  // yet; any use would fail to link until one is added.
+  static std::unique_ptr<SparseIterator>
+  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
+
+  // The inverse operation of `fromValues`; not implemented yet.
+  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
+
//
// Iterator properties.
//
@@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
unsigned tid,
Level lvl);
-/// Helper function to create a TensorLevel object from given `tensor`.
+/// Helper function to create a SparseTensorLevel object from the given level
+/// type `lt`, size `sz`, and `buffers`, for tensor `tid` at level `l`.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
ValueRange buffers,
unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterates
-/// over the SparseTensorLevel.
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the entire iteration space. Non-unique spaces get a deduplicating
+/// iterator; any setup IR is emitted through `b` at location `l`.
+std::unique_ptr<SparseIterator>
+makeSimpleIterator(OpBuilder &b, Location l,
+                   const SparseIterationSpace &iterSpace);
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the sparse tensor level.
+/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) once
+/// it is feature complete.
std::unique_ptr<SparseIterator> makeSimpleIterator(
const SparseTensorLevel &stl,
SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index 5fcd661bb69b2..77a0e89dc7c81 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
@@ -7,17 +8,44 @@
)
}>
-// CHECK-LABEL: func.func @sparse_1D_space(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
-func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
- return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+// CHECK-LABEL: @sparse_iteration_to_scf
+// // deduplication
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// CHECK: }
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// // actual computation
+// CHECK: scf.for {{.*}} {
+// CHECK: arith.addi
+// CHECK: }
+// // deduplication
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: return
+
+// COLLAPSED-LABEL: @sparse_iteration_to_scf
+// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} {
+// COLLAPSED: %[[VAL:.*]] = arith.addi
+// COLLAPSED: scf.yield %[[VAL]] : index
+// COLLAPSED: }
+// COLLAPSED: return %[[RET]] : index
+// Two nested sparse_tensor.iterate loops (level 0 outer, level 1 inner) that
+// increment the carried index once per inner-loop iteration. Without space
+// collapsing, the non-unique outer level is expected to lower through
+// deduplicating scf.while loops; with --sparse-space-collapse both loops are
+// expected to fold into a single scf.for.
+func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
+ %i = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // Outer loop over the level-0 iteration space.
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ // Inner loop over the level-1 space rooted at the current outer iterator.
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+ %k = arith.addi %inner, %c1 : index
+ sparse_tensor.yield %k : index
+ }
+ sparse_tensor.yield %r2 : index
+ }
+ return %r1 : index
}
>From f1baa43c9df1962676401fb2f8853df243e4f87f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 17 Jun 2024 20:10:08 +0000
Subject: [PATCH 2/2] fix UAF
---
.../SparseTensor/Transforms/SparseIterationToScf.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index f57be49f21b8c..4224925147c84 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -104,7 +104,7 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
- forOp.getBody()->erase();
+ rewriter.eraseBlock(forOp.getBody());
Region &dstRegion = forOp.getRegion();
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
@@ -114,7 +114,7 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
rewriter.setInsertionPointToEnd(forOp.getBody());
// replace sparse_tensor.yield with scf.yield.
rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
- yieldOp.erase();
+ rewriter.eraseOp(yieldOp);
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
@@ -162,7 +162,7 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
llvm::append_range(yields, yieldOp.getResults());
// replace sparse_tensor.yield with scf.yield.
- yieldOp->erase();
+ rewriter.eraseOp(yieldOp);
rewriter.create<scf::YieldOp>(loc, yields);
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
More information about the Mlir-commits mailing list