[Mlir-commits] [mlir] e9fa1fd - [mlir][sparse] support CSR/BSR conversion (#69800)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 20 17:24:38 PDT 2023
Author: Peiming Liu
Date: 2023-10-20T17:24:34-07:00
New Revision: e9fa1fdec9b1e4fcf1320502b40954ed2f558c06
URL: https://github.com/llvm/llvm-project/commit/e9fa1fdec9b1e4fcf1320502b40954ed2f558c06
DIFF: https://github.com/llvm/llvm-project/commit/e9fa1fdec9b1e4fcf1320502b40954ed2f558c06.diff
LOG: [mlir][sparse] support CSR/BSR conversion (#69800)
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fe2eaebbfa45d5a..56214c2b41c387b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -847,11 +847,6 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
AffineMap lvlPerm,
bool ordered) {
const SparseTensorType src(rtt);
- // TODO: This assertion is to match the behavior from before we merged
- // dimOrdering and higherOrdering into dimToLvl. However, there's no
- // in-principle reason to require this. (wrengr has a commit in the
- // wings to fix this.)
- assert(src.isPermutation());
const Level lvlRank = src.getLvlRank();
SmallVector<DimLevelType> lvlTypes;
lvlTypes.reserve(lvlRank);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 2d4f40eceba3b0d..de7ad3c91d793a0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -422,10 +422,19 @@ void LoopEmitter::initializeLoopEmit(
// FIXME: the definition of `lvlRank` looks more like a dim-rank;
// but the variable is used as a level everywhere below, which
// suggests there may be some dim/lvl confusion going on here.
- const Level lvlRank = rtp.getRank();
+ auto stt = getSparseTensorType(tensor);
+ const Level lvlRank = stt.getLvlRank();
const auto shape = rtp.getShape();
const auto enc = getSparseTensorEncoding(rtp);
const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
+
+ SmallVector<Value> dimSz;
+ for (Dimension d = 0; d < stt.getDimRank(); d++)
+ dimSz.push_back(linalg::createOrFoldDimOp(builder, loc, tensor, d));
+
+ ValueRange lvlSzs =
+ enc.translateCrds(builder, loc, dimSz, CrdTransDirectionKind::dim2lvl);
+
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
// This should be called only once at beginning.
@@ -447,13 +456,8 @@ void LoopEmitter::initializeLoopEmit(
assert(isDenseDLT(lvlTp));
}
- // FIXME: `toOrigDim` is deprecated. For now this relies on the
- // 1:1 mapping between levels and dimensions, since nowhere else
- // in the code supports non-permutations yet either.
- Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
- toOrigDim(enc, l));
// Find upper bound in current dimension.
- highs[t][l] = lvlSizes[t][l] = lvlSz;
+ highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
if (isSparseSlices[t]) {
sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c168f10dadbbfb5..5ce81f932faaf3f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1069,7 +1069,6 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
Value input = op.getTensor();
SmallVector<Value> reduc = op.getInitArgs();
const auto stt = getSparseTensorType(input);
- const Dimension dimRank = stt.getDimRank();
const Level lvlRank = stt.getLvlRank();
// Special-case: for each over a sparse constant uses its own rewriting
@@ -1103,6 +1102,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
if (op.getOrder()) {
// FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
+ const Dimension dimRank = stt.getDimRank();
SmallVector<Value> dcvs = lcvs; // keep a copy
for (Dimension d = 0; d < dimRank; d++) {
auto l = op.getOrder()->getDimPosition(d);
@@ -1143,7 +1143,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
args);
}
- for (Dimension d = 0; d < dimRank; d++) {
+ for (Level l = 0; l < lvlRank; l++) {
// Link the reduction chain. Note that loop emitter update the reducValue
// in place.
loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
More information about the Mlir-commits mailing list