[Mlir-commits] [mlir] [mlir][sparse] code cleanup, remove FIXMEs (PR #73575)
Peiming Liu
llvmlistbot at llvm.org
Mon Nov 27 14:38:49 PST 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/73575
>From c56970af6ebb9f7dcc242166ba3591c26e07f988 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 27 Nov 2023 22:22:47 +0000
Subject: [PATCH 1/4] [mlir][sparse] code cleanup, remove FIXMEs
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 16 +++++----
.../SparseTensor/IR/SparseTensorDialect.cpp | 36 ++++++-------------
.../SparseTensor/Transforms/CodegenUtils.cpp | 26 --------------
.../SparseTensor/Transforms/CodegenUtils.h | 7 ----
.../SparseTensor/Transforms/LoopEmitter.cpp | 6 ++--
.../Transforms/SparseTensorRewriting.cpp | 9 ++---
6 files changed, 25 insertions(+), 75 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf8c..f102f0270154264 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
// Reordering.
//
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience function to translate the given level to the corresponding
+/// dimension (a null `enc` is treated as the identity mapping).
+/// Requires: `enc`'s dim2lvl map is a permutation and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience function to translate the given dimension to the corresponding
+/// level (a null `enc` is treated as the identity mapping).
+/// Requires: `enc`'s dim2lvl map is a permutation and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 791aeebee5a328d..28e07e1669e7919 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
- // FIXME: `toOrigDim` is deprecated.
- return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+ return getStaticDimSliceOffset(toDim(*this, lvl));
}
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
- // FIXME: `toOrigDim` is deprecated.
- return getStaticDimSliceStride(toOrigDim(*this, lvl));
+ return getStaticDimSliceStride(toDim(*this, lvl));
}
SmallVector<int64_t>
@@ -399,9 +397,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
if (isPermutation()) {
for (unsigned r = 0; r < rank; r++) {
// FIXME: `toOrigDim` and `toStoredDim` are deprecated.
- unsigned trans = dir == CrdTransDirectionKind::dim2lvl
- ? toOrigDim(*this, r)
- : toStoredDim(*this, r);
+ unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+ : toLvl(*this, r);
ret.push_back(srcShape[trans]);
}
return ret;
@@ -925,31 +922,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
ordered);
}
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
- Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
if (enc) {
- if (const auto dimToLvl = enc.getDimToLvl()) {
- assert(enc.isPermutation());
+ assert(enc.isPermutation() && "Non permutation map");
+ if (const auto dimToLvl = enc.getDimToLvl())
return dimToLvl.getDimPosition(l);
- }
}
return l;
}
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
- Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
if (enc) {
- if (const auto dimToLvl = enc.getDimToLvl()) {
- assert(enc.isPermutation());
- auto maybePos =
- dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
- assert(maybePos.has_value());
- return *maybePos;
- }
+ assert(enc.isPermutation() && "");
+ if (const auto lvlToDim = enc.getLvlToDim())
+ return lvlToDim.getDimPosition(d);
}
return d;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90ff..33d449aac5a3550 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
}
}
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc,
- ValueRange dimSizes,
- Value valuesBuffer,
- Value lvlCoords) {
- // Reuse the `lvlCoords` buffer to store the level-sizes.
- const Level lvlRank = enc.getLvlRank();
- SmallVector<Value> lvlSizes;
- lvlSizes.reserve(lvlRank);
- for (Level l = 0; l < lvlRank; l++)
- // FIXME: `toOrigDim` is deprecated.
- lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
- storeAll(builder, loc, lvlCoords, lvlSizes);
- // The memref ReshapeOp requires the sizes buffer to have a static
- // shape.
- const auto iTp = builder.getIndexType();
- const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
- const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
- lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
- // Finally, create the ReshapeOp.
- const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
- const Type elemTp = getMemRefType(valuesBuffer).getElementType();
- const auto resTp = MemRefType::get(resShape, elemTp);
- return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
auto tTp = llvm::cast<TensorType>(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index cb0acdd2be9f7b0..0ce33427281f598 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
size_t offsetIdx = 0, Value offsetVal = Value());
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc, ValueRange dimSizes,
- Value valuesBuffer, Value lvlCoords);
-
// Generates code to cast a tensor to a memref.
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f8bcc0fe12a1093..413a835ff14d314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
- // FIXME: `toOrigDim` is deprecated
- return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+ return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}
static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
- // FIXME: `toOrigDim` is deprecated
- return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+ return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}
/// Converts a coordinate relative to the slice to the coordinate relative
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5374ab55c5c0d9f..103908b2cf5bd89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
SmallVector<Value> srcDcvs;
srcDcvs.reserve(srcRank);
for (Dimension d = 0; d < srcRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Level lvl = toStoredDim(encSrc, d);
+ Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
SmallVector<Value> srcDcvs;
srcDcvs.reserve(dimRank);
for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Level lvl = toStoredDim(encSrc, d);
+ Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
return failure();
if (stt.isPermutation()) {
- // FIXME: `toStoredDim` is deprecated
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
- toStoredDim(stt.getEncoding(), *dim));
+ toLvl(stt.getEncoding(), *dim));
return success();
}
>From 80ca0bb13c336c2ac76bd2d3a7236d005a05ffb5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 27 Nov 2023 22:33:43 +0000
Subject: [PATCH 2/4] remove unused SortMask::kIncludeUndef enumerator
---
.../SparseTensor/Transforms/IterationGraphSorter.h | 1 -
.../SparseTensor/Transforms/SparseReinterpretMap.cpp | 8 ++++----
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
index 613a8609ac0973a..52ee11702930096 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
@@ -26,7 +26,6 @@ enum class SortMask : unsigned {
// The individual mask bits.
kIncludeDenseOutput = 0x1, // b001
kIncludeDenseInput = 0x2, // b010
- kIncludeUndef = 0x4, // b100
// The subsets of mask bits.
kIncludeAll = 0x7, // b111
kIncludeDense = 0x3, // b011
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 268bd8fbe27387f..c94ef8b96287766 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
// computation. Must be ordered from more strict to less strict.
// Ideally (though might not be guaranteed), the earlier a constraint mask
// can be satisfied, the faster the generated kernel will be.
- const auto allMasks = {
- SortMask::kIncludeAll, SortMask::kIncludeDense,
- SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
- SortMask::kIncludeUndef, SortMask::kSparseOnly};
+ const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
+ SortMask::kIncludeDenseInput,
+ SortMask::kIncludeDenseOutput,
+ SortMask::kSparseOnly};
for (const SortMask mask : allMasks) {
order = scheduler.sort(mask);
if (order) {
>From c345ebd7dd5644a5d7c6cb5b95c9eb66b5001055 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 27 Nov 2023 22:37:33 +0000
Subject: [PATCH 3/4] add descriptive assertion messages to toDim/toLvl
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 28e07e1669e7919..df37dfab9a2ef5d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -924,7 +924,7 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
if (enc) {
- assert(enc.isPermutation() && "Non permutation map");
+ assert(enc.isPermutation() && "Non permutation map not supported");
if (const auto dimToLvl = enc.getDimToLvl())
return dimToLvl.getDimPosition(l);
}
@@ -933,7 +933,7 @@ Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
if (enc) {
- assert(enc.isPermutation() && "");
+ assert(enc.isPermutation() && "Non permutation map not supported");
if (const auto lvlToDim = enc.getLvlToDim())
return lvlToDim.getDimPosition(d);
}
>From efc542a30384ccd608d30306014cb12ad59f1c36 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 27 Nov 2023 22:38:35 +0000
Subject: [PATCH 4/4] drop stale FIXME about deprecated toOrigDim/toStoredDim
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index df37dfab9a2ef5d..fc897e7935510a6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -396,7 +396,6 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
if (isPermutation()) {
for (unsigned r = 0; r < rank; r++) {
- // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
: toLvl(*this, r);
ret.push_back(srcShape[trans]);
More information about the Mlir-commits
mailing list