[Mlir-commits] [mlir] [mlir][sparse] code cleanup, remove FIXMEs (PR #73575)
llvmlistbot at llvm.org
Mon Nov 27 14:25:53 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/73575.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+9-7)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+11-25)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (-26)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h (-7)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+2-4)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+3-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf8c..f102f0270154264 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
// Reordering.
//
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension; returns `l` unchanged when `enc` is null or has no dimToLvl map.
+/// Requires: `enc` is null or a permutation, and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level; returns `d` unchanged when `enc` is null or has no lvlToDim map.
+/// Requires: `enc` is null or a permutation, and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 791aeebee5a328d..28e07e1669e7919 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
- // FIXME: `toOrigDim` is deprecated.
- return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+ return getStaticDimSliceOffset(toDim(*this, lvl));
}
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
- // FIXME: `toOrigDim` is deprecated.
- return getStaticDimSliceStride(toOrigDim(*this, lvl));
+ return getStaticDimSliceStride(toDim(*this, lvl));
}
SmallVector<int64_t>
@@ -399,9 +397,7 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
if (isPermutation()) {
for (unsigned r = 0; r < rank; r++) {
- // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
- unsigned trans = dir == CrdTransDirectionKind::dim2lvl
- ? toOrigDim(*this, r)
- : toStoredDim(*this, r);
+ unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+ : toLvl(*this, r);
ret.push_back(srcShape[trans]);
}
return ret;
@@ -925,31 +921,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
ordered);
}
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
- Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
if (enc) {
- if (const auto dimToLvl = enc.getDimToLvl()) {
- assert(enc.isPermutation());
+ assert(enc.isPermutation() && "Non permutation map");
+ if (const auto dimToLvl = enc.getDimToLvl())
return dimToLvl.getDimPosition(l);
- }
}
return l;
}
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
- Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
if (enc) {
- if (const auto dimToLvl = enc.getDimToLvl()) {
- assert(enc.isPermutation());
- auto maybePos =
- dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
- assert(maybePos.has_value());
- return *maybePos;
- }
+ assert(enc.isPermutation() && "Non permutation map");
+ if (const auto lvlToDim = enc.getLvlToDim())
+ return lvlToDim.getDimPosition(d);
}
return d;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90ff..33d449aac5a3550 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
}
}
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc,
- ValueRange dimSizes,
- Value valuesBuffer,
- Value lvlCoords) {
- // Reuse the `lvlCoords` buffer to store the level-sizes.
- const Level lvlRank = enc.getLvlRank();
- SmallVector<Value> lvlSizes;
- lvlSizes.reserve(lvlRank);
- for (Level l = 0; l < lvlRank; l++)
- // FIXME: `toOrigDim` is deprecated.
- lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
- storeAll(builder, loc, lvlCoords, lvlSizes);
- // The memref ReshapeOp requires the sizes buffer to have a static
- // shape.
- const auto iTp = builder.getIndexType();
- const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
- const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
- lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
- // Finally, create the ReshapeOp.
- const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
- const Type elemTp = getMemRefType(valuesBuffer).getElementType();
- const auto resTp = MemRefType::get(resShape, elemTp);
- return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
auto tTp = llvm::cast<TensorType>(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index cb0acdd2be9f7b0..0ce33427281f598 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
size_t offsetIdx = 0, Value offsetVal = Value());
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc, ValueRange dimSizes,
- Value valuesBuffer, Value lvlCoords);
-
// Generates code to cast a tensor to a memref.
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f8bcc0fe12a1093..413a835ff14d314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
- // FIXME: `toOrigDim` is deprecated
- return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+ return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}
static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
- // FIXME: `toOrigDim` is deprecated
- return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+ return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}
/// Converts a coordinate relative to the slice to the coordinate relative
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5374ab55c5c0d9f..103908b2cf5bd89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
SmallVector<Value> srcDcvs;
srcDcvs.reserve(srcRank);
for (Dimension d = 0; d < srcRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Level lvl = toStoredDim(encSrc, d);
+ Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
SmallVector<Value> srcDcvs;
srcDcvs.reserve(dimRank);
for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Level lvl = toStoredDim(encSrc, d);
+ Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
return failure();
if (stt.isPermutation()) {
- // FIXME: `toStoredDim` is deprecated
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
- toStoredDim(stt.getEncoding(), *dim));
+ toLvl(stt.getEncoding(), *dim));
return success();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/73575
More information about the Mlir-commits mailing list