[Mlir-commits] [mlir] [mlir][sparse] support CSR/BSR conversion (PR #69800)
Peiming Liu
llvmlistbot at llvm.org
Fri Oct 20 17:06:43 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69800
(No PR description was provided.)
>From 8dd2c0aa81c6587f934f4e8040fd864323ee6116 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Sat, 21 Oct 2023 00:06:03 +0000
Subject: [PATCH] [mlir][sparse] support CSR/BSR conversion
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 5 -----
.../SparseTensor/Transforms/LoopEmitter.cpp | 18 +++++++++++-------
.../Transforms/SparseTensorRewriting.cpp | 4 ++--
3 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fe2eaebbfa45d5a..56214c2b41c387b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -847,11 +847,6 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
AffineMap lvlPerm,
bool ordered) {
const SparseTensorType src(rtt);
- // TODO: This assertion is to match the behavior from before we merged
- // dimOrdering and higherOrdering into dimToLvl. However, there's no
- // in-principle reason to require this. (wrengr has a commit in the
- // wings to fix this.)
- assert(src.isPermutation());
const Level lvlRank = src.getLvlRank();
SmallVector<DimLevelType> lvlTypes;
lvlTypes.reserve(lvlRank);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 2d4f40eceba3b0d..de7ad3c91d793a0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -422,10 +422,19 @@ void LoopEmitter::initializeLoopEmit(
// FIXME: the definition of `lvlRank` looks more like a dim-rank;
// but the variable is used as a level everywhere below, which
// suggests there may be some dim/lvl confusion going on here.
- const Level lvlRank = rtp.getRank();
+ auto stt = getSparseTensorType(tensor);
+ const Level lvlRank = stt.getLvlRank();
const auto shape = rtp.getShape();
const auto enc = getSparseTensorEncoding(rtp);
const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
+
+ SmallVector<Value> dimSz;
+ for (Dimension d = 0; d < stt.getDimRank(); d++)
+ dimSz.push_back(linalg::createOrFoldDimOp(builder, loc, tensor, d));
+
+ ValueRange lvlSzs =
+ enc.translateCrds(builder, loc, dimSz, CrdTransDirectionKind::dim2lvl);
+
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
// This should be called only once at beginning.
@@ -447,13 +456,8 @@ void LoopEmitter::initializeLoopEmit(
assert(isDenseDLT(lvlTp));
}
- // FIXME: `toOrigDim` is deprecated. For now this relies on the
- // 1:1 mapping between levels and dimensions, since nowhere else
- // in the code supports non-permutations yet either.
- Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
- toOrigDim(enc, l));
// Find upper bound in current dimension.
- highs[t][l] = lvlSizes[t][l] = lvlSz;
+ highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
if (isSparseSlices[t]) {
sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c168f10dadbbfb5..5ce81f932faaf3f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1069,7 +1069,6 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
Value input = op.getTensor();
SmallVector<Value> reduc = op.getInitArgs();
const auto stt = getSparseTensorType(input);
- const Dimension dimRank = stt.getDimRank();
const Level lvlRank = stt.getLvlRank();
// Special-case: for each over a sparse constant uses its own rewriting
@@ -1103,6 +1102,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
if (op.getOrder()) {
// FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
+ const Dimension dimRank = stt.getDimRank();
SmallVector<Value> dcvs = lcvs; // keep a copy
for (Dimension d = 0; d < dimRank; d++) {
auto l = op.getOrder()->getDimPosition(d);
@@ -1143,7 +1143,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
args);
}
- for (Dimension d = 0; d < dimRank; d++) {
+ for (Level l = 0; l < lvlRank; l++) {
// Link the reduction chain. Note that loop emitter update the reducValue
// in place.
loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
More information about the Mlir-commits mailing list