[Mlir-commits] [mlir] [mlir][sparse] implement sparse_tensor.lvl operation. (PR #69993)
Peiming Liu
llvmlistbot at llvm.org
Tue Oct 24 10:16:55 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69993
>From 47d6d705645431bbbafc3fe331bf54afef8df253 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 24 Oct 2023 01:01:04 +0000
Subject: [PATCH 1/3] [mlir][sparse] implement sparse_tensor.lvl operation.
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 8 +++-
.../SparseTensor/Transforms/LoopEmitter.cpp | 16 ++++---
.../Transforms/SparseTensorCodegen.cpp | 46 ++++---------------
.../Transforms/SparseTensorConversion.cpp | 20 ++++----
.../Transforms/SparseTensorRewriting.cpp | 40 +++++++++++++++-
mlir/test/Dialect/SparseTensor/codegen.mlir | 2 +-
.../test/Dialect/SparseTensor/conversion.mlir | 8 ++--
mlir/test/Dialect/SparseTensor/sparse_2d.mlir | 4 +-
mlir/test/Dialect/SparseTensor/sparse_3d.mlir | 4 +-
.../Dialect/SparseTensor/sparse_expand.mlir | 8 ++--
.../Dialect/SparseTensor/sparse_foreach.mlir | 4 +-
.../Dialect/SparseTensor/sparse_index.mlir | 23 +++++-----
.../Dialect/SparseTensor/sparse_perm.mlir | 16 +++----
.../SparseTensor/sparse_perm_lower.mlir | 22 ++++-----
.../Dialect/SparseTensor/sparse_reshape.mlir | 4 +-
15 files changed, 123 insertions(+), 102 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5455b4fab0c08de..17e6ef53fe596e0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -915,7 +915,7 @@ Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
// properly handle non-permutations.
Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) {
const auto enc = getSparseTensorEncoding(type);
- assert(l < enc.getLvlRank());
+ assert(!enc || l < enc.getLvlRank());
return toOrigDim(enc, l);
}
@@ -1208,6 +1208,13 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
return success();
}
+/// Builds an LvlOp whose level operand is materialized as constant `index`.
+void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
+ int64_t index) {
+ Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
+ build(builder, state, source, val);
+}
+
LogicalResult LvlOp::verify() {
if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
auto stt = getSparseTensorType(getSource());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index de7ad3c91d793a0..bb3c6fb56f692d9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -428,12 +428,13 @@ void LoopEmitter::initializeLoopEmit(
const auto enc = getSparseTensorEncoding(rtp);
const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
- SmallVector<Value> dimSz;
- for (Dimension d = 0; d < stt.getDimRank(); d++)
- dimSz.push_back(linalg::createOrFoldDimOp(builder, loc, tensor, d));
-
- ValueRange lvlSzs =
- enc.translateCrds(builder, loc, dimSz, CrdTransDirectionKind::dim2lvl);
+ // Level sizes: sparse tensors are queried through LvlOp; unannotated
+ // tensors use the identity dim2lvl map, so a (folded) DimOp suffices.
+ SmallVector<Value> lvlSzs;
+ for (Level l = 0; l < stt.getLvlRank(); l++)
+ lvlSzs.push_back(stt.hasEncoding()
+ ? Value(builder.create<LvlOp>(loc, tensor, l))
+ : linalg::createOrFoldDimOp(builder, loc, tensor, l));
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
@@ -489,7 +490,8 @@ void LoopEmitter::initializeLoopEmit(
valBuffer[t] = denseVal;
} else {
// Annotated sparse tensors.
- // We also need the value buffer for all-dense annotated "sparse" tensors.
+ // We also need the value buffer for all-dense annotated "sparse"
+ // tensors.
valBuffer[t] = genToValues(builder, loc, tensor);
}
// NOTE: we can also prepare for 0 lvl here in advance, this will hoist
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 67a78832ef30422..ecc452a5ba6c1cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -97,32 +97,6 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
return forOp;
}
-/// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'.
-static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
- SparseTensorDescriptor desc, Dimension dim) {
- const SparseTensorType stt(desc.getRankedTensorType());
- // Access into static dimension can query original type directly.
- // Note that this is typically already done by DimOp's folding.
- if (auto sz = stt.getStaticDimSize(dim))
- return constantIndex(builder, loc, *sz);
-
- // Any other query can consult the dimSizes array at field DimSizesIdx,
- // accounting for the reordering applied to the sparse storage.
- // FIXME: `toStoredDim` is deprecated.
- const Level lvl = toStoredDim(stt, dim);
- return desc.getLvlSize(builder, loc, lvl);
-}
-
-// Gets the dimension size at the given stored level 'lvl', either as a
-// constant for a static size, or otherwise dynamically through memSizes.
-static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
- SparseTensorDescriptor desc, Level lvl) {
- // FIXME: `toOrigDim` is deprecated.
- return sizeFromTensorAtDim(builder, loc, desc,
- toOrigDim(desc.getRankedTensorType(), lvl));
-}
-
static void createPushback(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc,
SparseTensorFieldKind kind, std::optional<Level> lvl,
@@ -164,7 +138,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// at this level. We will eventually reach a compressed level or
// otherwise the values array for the from-here "all-dense" case.
assert(isDenseDLT(dlt));
- Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
+ Value size = desc.getLvlSize(builder, loc, l);
linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached values array so prepare for an insertion.
@@ -448,7 +422,7 @@ class SparseInsertGenerator
// Construct the new position as:
// positions[l] = size * positions[l-1] + coords[l]
// <insert @ positions[l] at next level l + 1>
- Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
+ Value size = desc.getLvlSize(builder, loc, l);
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
}
@@ -659,18 +633,18 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
};
-/// Sparse codegen rule for dimension accesses.
+/// Sparse codegen rule for level accesses.
-class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
+class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+ matchAndRewrite(LvlOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- std::optional<int64_t> dim = op.getConstantIndex();
- if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
+ std::optional<int64_t> lvl = op.getConstantLvlIndex();
+ if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
return failure();
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
- auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
+ auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
rewriter.replaceOp(op, sz);
return success();
@@ -925,9 +899,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
// Determine the size for access expansion (always the innermost stored
- // level size, translated back to original dimension). Note that we
- // recursively rewrite the new DimOp on the **original** tensor.
+ // level size), which is queried directly from the sparse tensor
+ // descriptor.
- // FIXME: `toOrigDim` is deprecated.
- const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1);
- const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
+ const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
// Generate a memref for `sz` elements of type `t`.
const auto genAlloc = [&](Type t) {
const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
@@ -1588,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool createSparseDeallocs, bool enableBufferInitialization) {
patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
- SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+ SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a1f725333530d90..b5937698968eac5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -294,25 +294,27 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
};
-/// Sparse conversion rule for accessing dimension-sizes.
+/// Sparse conversion rule for accessing level-sizes.
-class SparseTensorToDimSizeConverter
- : public OpConversionPattern<tensor::DimOp> {
+class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+ matchAndRewrite(LvlOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto stt = getSparseTensorType(op.getSource());
- // Only rewrite sparse DimOp.
+ // Only rewrite sparse LvlOp.
if (!stt.hasEncoding())
return failure();
+
- // Only rewrite DimOp with constant index.
- std::optional<int64_t> dim = op.getConstantIndex();
- if (!dim)
+ // Only rewrite LvlOp with constant index.
+ std::optional<int64_t> lvl = op.getConstantLvlIndex();
+
+ if (!lvl)
return failure();
- // Generate the call.
+
+ // By now, if the level size is constant, the operation should have already
+ // been folded by LvlOp's folder, so we generate the call unconditionally.
Value src = adaptor.getOperands()[0];
- rewriter.replaceOp(
- op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
+ rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
return success();
}
};
@@ -763,7 +765,7 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns
- .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+ .add<SparseReturnConverter, SparseTensorLvlOpConverter,
SparseCastConverter, SparseTensorNewConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5ce81f932faaf3f..6ef4e09c8fe6de4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -888,6 +888,43 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
}
};
+/// Rewrites `tensor.dim` on sparse tensors into `sparse_tensor.lvl` queries,
+/// translating through the lvl2dim map for non-permutation encodings.
+struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(tensor::DimOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<int64_t> dim = op.getConstantIndex();
+ auto stt = getSparseTensorType(op.getSource());
+ if (!dim || !stt.hasEncoding())
+ return failure();
+
+ if (stt.isPermutation()) {
+ rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
+ toStoredDim(stt, *dim));
+ return success();
+ }
+
+ // Non-permutation dim2lvl/lvl2dim maps: the dimension size is one past
+ // the maximal dimension coordinate, i.e., lvl2dim(lvlSz - 1) + 1.
+ Location loc = op.getLoc();
+ Value one = constantOne(rewriter, loc, rewriter.getIndexType());
+ SmallVector<Value> maxLvlCrds;
+ for (Level l = 0; l < stt.getLvlRank(); l++) {
+ Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
+ maxLvlCrds.push_back(rewriter.create<arith::SubIOp>(loc, lvlSz, one));
+ }
+
+ AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+ Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
+ loc, AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp), maxLvlCrds);
+
+ Value dimSz = rewriter.create<arith::AddIOp>(loc, maxDimCrd, one);
+ rewriter.replaceOp(op, dimSz);
+ return success();
+ }
+};
+
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -1270,7 +1307,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
ReshapeRewriter<tensor::CollapseShapeOp>,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
- TensorReshapeRewriter>(patterns.getContext());
+ SparseTensorDimOpRewriter, TensorReshapeRewriter>(
+ patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
if (enableConvert)
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 84904227a636327..8993333d6e5333d 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2ff4887dae7b8c9..1d7599b3a4edb87 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)
@@ -38,7 +38,7 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
// CHECK-LABEL: func @sparse_dim1d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = arith.constant 0 : index
-// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
+// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
// CHECK: return %[[D]] : index
func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
%c = arith.constant 0 : index
@@ -51,8 +51,8 @@ func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
// dimension 1 is stored as level 2).
// CHECK-LABEL: func @sparse_dim3d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-// CHECK: %[[C:.*]] = arith.constant 1 : index
-// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
+// CHECK: %[[C:.*]] = arith.constant 2 : index
+// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
// CHECK: return %[[D]] : index
func.func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
%c = arith.constant 1 : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 83c0ef390ae7a03..076e9201a1c053e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1511,7 +1511,7 @@ func.func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) ->
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf64>
// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_11]] : memref<?x?xf64>)
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
@@ -1641,7 +1641,7 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
// CHECK-DAG: %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
// CHECK-DAG: %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?xf32>
// CHECK-DAG: %[[VAL_21:.*]] = bufferization.to_memref %[[VAL_4]] : memref<f32>
-// CHECK-DAG: %[[VAL_22:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32,
+// CHECK-DAG: %[[VAL_22:.*]] = sparse_tensor.lvl %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32,
// CHECK-DAG: %[[VAL_24:.*]] = bufferization.to_memref %[[VAL_5]] : memref<?xf32>
// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_21]][] : memref<f32>
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index 7b5d7420ff4c736..c245e612be37f12 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1126,10 +1126,10 @@ func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK-DAG: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
-// CHECK-DAG: %[[VAL_13:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_14:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
// CHECK-DAG: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 9d8db10aa423022..3ee6e84a2382a9e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -4,7 +4,8 @@
// RUN: FileCheck %s --check-prefix=CHECK-SPARSE
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
-// RUN: --sparsification --sparse-tensor-conversion --cse | \
+// RUN: --sparsification --post-sparsification-rewrite \
+// RUN: --sparse-tensor-conversion --cse | \
// RUN: FileCheck %s --check-prefix=CHECK-CONVERT
#CSR = #sparse_tensor.encoding<{
@@ -45,8 +46,9 @@
//
// CHECK-CONVERT-LABEL: func @kernel(
// CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK-CONVERT: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-CONVERT: %[[N:.*]] = call @sparseDimSize(%[[A]], %[[C0]])
+// CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONVERT: %[[N:.*]] = call @sparseLvlSize(%[[A]], %[[C1]])
// CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index 8c1836b1c2ef8f1..bbce42c100641ab 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -43,12 +43,12 @@ func.func @sparse_foreach_constant() -> () {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_3:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 0 : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 0 : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_10:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 1 : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 1 : tensor<?x?xf64,
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index 674225b9e935910..34e04c03529036f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -21,13 +21,13 @@
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_24:.*]] = tensor.dim %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
@@ -50,8 +50,8 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
-> tensor<?x?xi64, #DenseMatrix> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 0 : index
- %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
- %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
+ %0 = sparse_tensor.lvl %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
+ %1 = sparse_tensor.lvl %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
%init = tensor.empty(%0, %1) : tensor<?x?xi64, #DenseMatrix>
%r = linalg.generic #trait
ins(%arga: tensor<?x?xi64, #DenseMatrix>)
@@ -73,8 +73,8 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
@@ -107,8 +107,8 @@ func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
-> tensor<?x?xi64, #SparseMatrix> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 0 : index
- %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
- %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
+ %0 = sparse_tensor.lvl %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
+ %1 = sparse_tensor.lvl %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
%init = tensor.empty(%0, %1) : tensor<?x?xi64, #SparseMatrix>
%r = linalg.generic #trait
ins(%arga: tensor<?x?xi64, #SparseMatrix>)
@@ -124,4 +124,3 @@ func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
} -> tensor<?x?xi64, #SparseMatrix>
return %r : tensor<?x?xi64, #SparseMatrix>
}
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index 26712f2c7b001b3..186030aa2ca0873 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -59,17 +59,17 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
-// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_7]], %[[VAL_11]] : index
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_11]] : index
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
-// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
-// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_8]], %[[VAL_14]] : index
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
+// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_6]], %[[VAL_14]] : index
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_17]]] : memref<?xf32>
// CHECK: memref.store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_12]]] : memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index fd8aaf28697e42d..42726d998ac7a72 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -21,9 +21,9 @@
// CHECK-HIR-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-HIR-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-HIR-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-HIR-DAG: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG: %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-HIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
// CHECK-HIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
@@ -53,18 +53,18 @@
// CHECK-MIR-DAG: %[[I0:.*]] = arith.constant 0 : index
// CHECK-MIR-DAG: %[[I1:.*]] = arith.constant 1 : index
// CHECK-MIR-DAG: %[[I2:.*]] = arith.constant 2 : index
-// CHECK-MIR-DAG: %[[DimSize0:.*]] = call @sparseDimSize(%[[ARGA]], %[[I0]])
-// CHECK-MIR-DAG: %[[DimSize1:.*]] = call @sparseDimSize(%[[ARGA]], %[[I1]])
-// CHECK-MIR-DAG: %[[DimSize2:.*]] = call @sparseDimSize(%[[ARGA]], %[[I2]])
+// CHECK-MIR-DAG: %[[DimSize0:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I0]])
+// CHECK-MIR-DAG: %[[DimSize1:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I1]])
+// CHECK-MIR-DAG: %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr<i8>) -> memref<?xf32>
// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref<f32>
// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
-// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[DimSize0]], %[[D2]] : index
+// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[DimSize1]], %[[D2]] : index
// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index
-// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[DimSize1]], %[[VAL_19]] : index
+// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[DimSize2]], %[[VAL_19]] : index
// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index
// CHECK-MIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK-MIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 543fcaededf4965..4f105f3e19b3e75 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -103,7 +103,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[SD:.*]] = tensor.dim %[[S]], %[[C0]]
+// CHECK: %[[SD:.*]] = sparse_tensor.lvl %[[S]], %[[C0]]
// CHECK: %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
// CHECK: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK: %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
@@ -146,7 +146,7 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[SD1:.*]] = tensor.dim %[[S]], %[[C1]]
+// CHECK: %[[SD1:.*]] = sparse_tensor.lvl %[[S]], %[[C1]]
// CHECK: %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
// CHECK: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK: %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
>From 74d69b9c9a4f46fc575f5791da68f8e7f6be6a41 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 24 Oct 2023 16:40:20 +0000
Subject: [PATCH 2/3] add some comments.
---
.../SparseTensor/Transforms/SparseTensorRewriting.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6ef4e09c8fe6de4..6c25cba91e27bb8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -904,6 +904,11 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
}
// Non-permutation dim2lvl/lvl2dim maps.
+ // Computed as follows:
+ // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
+ // Note that this is not the most efficient way (but a more general one) for
+ // the lvl-to-dim translation; e.g., for BSR, the dimension size can be
+ // computed simply as lvl_size * block_size.
Location loc = op.getLoc();
SmallVector<Value> maxLvlCrds;
for (Level l = 0; l < stt.getLvlRank(); l++) {
>From dbe65786a56a8a5487d1d925476c0097f9339582 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 24 Oct 2023 17:16:40 +0000
Subject: [PATCH 3/3] address comments.
---
.../Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index ecc452a5ba6c1cf..2452de8b1ec00f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -896,9 +896,9 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
Type idxType = rewriter.getIndexType();
// All initialization should be done on entry of the loop nest.
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
+
// Determine the size for access expansion (always the innermost stored
- // level size, translated back to original dimension). Note that we
- // recursively rewrite the new DimOp on the **original** tensor.
+ // level size).
const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
// Generate a memref for `sz` elements of type `t`.
const auto genAlloc = [&](Type t) {
More information about the Mlir-commits
mailing list