[Mlir-commits] [mlir] [mlir][sparse] implement sparse_tensor.lvl operation. (PR #69993)

Peiming Liu llvmlistbot at llvm.org
Tue Oct 24 10:16:55 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69993

>From 47d6d705645431bbbafc3fe331bf54afef8df253 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 24 Oct 2023 01:01:04 +0000
Subject: [PATCH 1/3] [mlir][sparse] implement sparse_tensor.lvl operation.

---
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  8 +++-
 .../SparseTensor/Transforms/LoopEmitter.cpp   | 16 ++++---
 .../Transforms/SparseTensorCodegen.cpp        | 46 ++++---------------
 .../Transforms/SparseTensorConversion.cpp     | 20 ++++----
 .../Transforms/SparseTensorRewriting.cpp      | 40 +++++++++++++++-
 mlir/test/Dialect/SparseTensor/codegen.mlir   |  2 +-
 .../test/Dialect/SparseTensor/conversion.mlir |  8 ++--
 mlir/test/Dialect/SparseTensor/sparse_2d.mlir |  4 +-
 mlir/test/Dialect/SparseTensor/sparse_3d.mlir |  4 +-
 .../Dialect/SparseTensor/sparse_expand.mlir   |  8 ++--
 .../Dialect/SparseTensor/sparse_foreach.mlir  |  4 +-
 .../Dialect/SparseTensor/sparse_index.mlir    | 23 +++++-----
 .../Dialect/SparseTensor/sparse_perm.mlir     | 16 +++----
 .../SparseTensor/sparse_perm_lower.mlir       | 22 ++++-----
 .../Dialect/SparseTensor/sparse_reshape.mlir  |  4 +-
 15 files changed, 123 insertions(+), 102 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5455b4fab0c08de..17e6ef53fe596e0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -915,7 +915,7 @@ Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
 // properly handle non-permutations.
 Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) {
   const auto enc = getSparseTensorEncoding(type);
-  assert(l < enc.getLvlRank());
+  assert(l < (enc ? enc.getLvlRank() : static_cast<Level>(type.getRank())));
   return toOrigDim(enc, l);
 }
 
@@ -1208,6 +1208,13 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
+/// Builds a LvlOp, materializing the level `index` as an index constant.
+void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
+                  int64_t index) {
+  Value lvl = builder.create<arith::ConstantIndexOp>(state.location, index);
+  build(builder, state, source, lvl);
+}
+
 LogicalResult LvlOp::verify() {
   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
     auto stt = getSparseTensorType(getSource());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index de7ad3c91d793a0..bb3c6fb56f692d9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -428,12 +428,13 @@ void LoopEmitter::initializeLoopEmit(
     const auto enc = getSparseTensorEncoding(rtp);
     const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
 
-    SmallVector<Value> dimSz;
-    for (Dimension d = 0; d < stt.getDimRank(); d++)
-      dimSz.push_back(linalg::createOrFoldDimOp(builder, loc, tensor, d));
-
-    ValueRange lvlSzs =
-        enc.translateCrds(builder, loc, dimSz, CrdTransDirectionKind::dim2lvl);
+    SmallVector<Value> lvlSzs;
+    for (Level l = 0; l < stt.getLvlRank(); l++) {
+      if (stt.hasEncoding())
+        lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
+      else
+        lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
+    }
 
     // Scan all levels of current tensor.
     for (Level l = 0; l < lvlRank; l++) {
@@ -489,7 +490,8 @@ void LoopEmitter::initializeLoopEmit(
       valBuffer[t] = denseVal;
     } else {
       // Annotated sparse tensors.
-      // We also need the value buffer for all-dense annotated "sparse" tensors.
+      // We also need the value buffer for all-dense annotated "sparse"
+      // tensors.
       valBuffer[t] = genToValues(builder, loc, tensor);
     }
     // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 67a78832ef30422..ecc452a5ba6c1cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -97,32 +97,6 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
   return forOp;
 }
 
-/// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'.
-static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                                 SparseTensorDescriptor desc, Dimension dim) {
-  const SparseTensorType stt(desc.getRankedTensorType());
-  // Access into static dimension can query original type directly.
-  // Note that this is typically already done by DimOp's folding.
-  if (auto sz = stt.getStaticDimSize(dim))
-    return constantIndex(builder, loc, *sz);
-
-  // Any other query can consult the dimSizes array at field DimSizesIdx,
-  // accounting for the reordering applied to the sparse storage.
-  // FIXME: `toStoredDim` is deprecated.
-  const Level lvl = toStoredDim(stt, dim);
-  return desc.getLvlSize(builder, loc, lvl);
-}
-
-// Gets the dimension size at the given stored level 'lvl', either as a
-// constant for a static size, or otherwise dynamically through memSizes.
-static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
-                                 SparseTensorDescriptor desc, Level lvl) {
-  // FIXME: `toOrigDim` is deprecated.
-  return sizeFromTensorAtDim(builder, loc, desc,
-                             toOrigDim(desc.getRankedTensorType(), lvl));
-}
-
 static void createPushback(OpBuilder &builder, Location loc,
                            MutSparseTensorDescriptor desc,
                            SparseTensorFieldKind kind, std::optional<Level> lvl,
@@ -164,7 +138,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDLT(dlt));
-    Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
+    Value size = desc.getLvlSize(builder, loc, l);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
@@ -448,7 +422,7 @@ class SparseInsertGenerator
         // Construct the new position as:
         //   positions[l] = size * positions[l-1] + coords[l]
         //   <insert @ positions[l] at next level l + 1>
-        Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
+        Value size = desc.getLvlSize(builder, loc, l);
         Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
         parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
       }
@@ -659,18 +633,18 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
 };

-/// Sparse codegen rule for dimension accesses.
-class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
+/// Sparse codegen rule for level-size accesses (sparse_tensor.lvl).
+class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    std::optional<int64_t> dim = op.getConstantIndex();
-    if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    std::optional<int64_t> lvl = op.getConstantLvlIndex();
+    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
       return failure();

     auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
-    auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
+    auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);

     rewriter.replaceOp(op, sz);
     return success();
@@ -925,9 +899,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
-    // Determine the size for access expansion (always the innermost stored
-    // level size, translated back to original dimension). Note that we
-    // recursively rewrite the new DimOp on the **original** tensor.
-    // FIXME: `toOrigDim` is deprecated.
-    const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1);
-    const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
+    // Determine the size for access expansion. This is always the size of
+    // the innermost stored level, queried directly from the sparse tensor
+    // descriptor (no level-to-dimension translation is involved).
+    const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
     // Generate a memref for `sz` elements of type `t`.
     const auto genAlloc = [&](Type t) {
       const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
@@ -1588,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool createSparseDeallocs, bool enableBufferInitialization) {
   patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
-               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
                SparseCastConverter, SparseExtractSliceConverter,
                SparseTensorLoadConverter, SparseExpandConverter,
                SparseCompressConverter, SparseInsertConverter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a1f725333530d90..b5937698968eac5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -294,25 +294,27 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 };

-/// Sparse conversion rule for accessing dimension-sizes.
-class SparseTensorToDimSizeConverter
-    : public OpConversionPattern<tensor::DimOp> {
+/// Sparse conversion rule for accessing level-sizes (sparse_tensor.lvl).
+/// tensor.dim on sparse tensors is expected to be rewritten into LvlOp first.
+class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     const auto stt = getSparseTensorType(op.getSource());
-    // Only rewrite sparse DimOp.
+    // Only rewrite LvlOp on sparse tensors.
     if (!stt.hasEncoding())
       return failure();
+
-    // Only rewrite DimOp with constant index.
-    std::optional<int64_t> dim = op.getConstantIndex();
-    if (!dim)
+    // Only rewrite LvlOp with a constant level index.
+    std::optional<int64_t> lvl = op.getConstantLvlIndex();
+    if (!lvl)
       return failure();
-    // Generate the call.
+
+    // By now, if the level size is constant, the operation should have already
+    // been folded by LvlOp's folder, so we generate the call unconditionally.
     Value src = adaptor.getOperands()[0];
-    rewriter.replaceOp(
-        op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
+    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
     return success();
   }
 };
@@ -763,7 +765,7 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
   patterns
-      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
            SparseCastConverter, SparseTensorNewConverter,
            SparseTensorAllocConverter, SparseTensorEmptyConverter,
            SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5ce81f932faaf3f..6ef4e09c8fe6de4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -888,6 +888,43 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
   }
 };
 
+/// Rewrites a constant-index tensor.dim on a sparse tensor into level-size
+/// queries (sparse_tensor.lvl), going through lvl2dim for non-permutations.
+struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::DimOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<int64_t> dim = op.getConstantIndex();
+    auto stt = getSparseTensorType(op.getSource());
+    if (!dim || !stt.hasEncoding())
+      return failure();
+
+    if (stt.isPermutation()) {
+      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
+                                         toStoredDim(stt, *dim));
+      return success();
+    }
+
+    // Non-permutation dim2lvl/lvl2dim maps: evaluate lvl2dim at the maximal
+    // level coordinates (lvlSz - 1) and add one to obtain the dimension size.
+    Location loc = op.getLoc();
+    Value one = constantOne(rewriter, loc, rewriter.getIndexType());
+    SmallVector<Value> maxLvlCrds;
+    for (Level l = 0; l < stt.getLvlRank(); l++) {
+      Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
+      maxLvlCrds.push_back(rewriter.create<arith::SubIOp>(loc, lvlSz, one));
+    }
+
+    AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+    Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
+        loc, AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp), maxLvlCrds);
+
+    Value dimSz = rewriter.create<arith::AddIOp>(loc, maxDimCrd, one);
+    rewriter.replaceOp(op, dimSz);
+    return success();
+  }
+};
+
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -1270,7 +1307,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
-               TensorReshapeRewriter>(patterns.getContext());
+               SparseTensorDimOpRewriter, TensorReshapeRewriter>(
+      patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   if (enableConvert)
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 84904227a636327..8993333d6e5333d 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen  --canonicalize -cse | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen  --canonicalize -cse | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2ff4887dae7b8c9..1d7599b3a4edb87 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)
@@ -38,7 +38,7 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
 // CHECK-LABEL: func @sparse_dim1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = arith.constant 0 : index
-//       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
+//       CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
 //       CHECK: return %[[D]] : index
 func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
   %c = arith.constant 0 : index
@@ -51,8 +51,8 @@ func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
 // dimension 1 is stored as level 2).
 // CHECK-LABEL: func @sparse_dim3d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = arith.constant 1 : index
-//       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
+//       CHECK: %[[C:.*]] = arith.constant 2 : index
+//       CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
 //       CHECK: return %[[D]] : index
 func.func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
   %c = arith.constant 1 : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 83c0ef390ae7a03..076e9201a1c053e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1511,7 +1511,7 @@ func.func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) ->
 // CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf64>
 // CHECK:           linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_11]] : memref<?x?xf64>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
@@ -1641,7 +1641,7 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 // CHECK-DAG:       %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?xf32>
 // CHECK-DAG:       %[[VAL_21:.*]] = bufferization.to_memref %[[VAL_4]] : memref<f32>
-// CHECK-DAG:       %[[VAL_22:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32,
+// CHECK-DAG:       %[[VAL_22:.*]] = sparse_tensor.lvl %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32,
 // CHECK-DAG:       %[[VAL_24:.*]] = bufferization.to_memref %[[VAL_5]] : memref<?xf32>
 // CHECK:           %[[VAL_25:.*]] = memref.load %[[VAL_21]][] : memref<f32>
 // CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index 7b5d7420ff4c736..c245e612be37f12 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1126,10 +1126,10 @@ func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_14:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
 // CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
 // CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 9d8db10aa423022..3ee6e84a2382a9e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -4,7 +4,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-SPARSE
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --linalg-fuse-elementwise-ops \
-// RUN:             --sparsification --sparse-tensor-conversion --cse | \
+// RUN:             --sparsification --post-sparsification-rewrite \
+// RUN:             --sparse-tensor-conversion --cse | \
 // RUN:   FileCheck %s --check-prefix=CHECK-CONVERT
 
 #CSR = #sparse_tensor.encoding<{
@@ -45,8 +46,9 @@
 //
 // CHECK-CONVERT-LABEL: func @kernel(
 // CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK-CONVERT: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-CONVERT: %[[N:.*]] = call @sparseDimSize(%[[A]], %[[C0]])
+// CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONVERT: %[[N:.*]] = call @sparseLvlSize(%[[A]], %[[C1]])
 // CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
 // CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
 // CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index 8c1836b1c2ef8f1..bbce42c100641ab 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -43,12 +43,12 @@ func.func @sparse_foreach_constant() -> () {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
-// CHECK-DAG:       %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 0 : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 0 : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
-// CHECK-DAG:       %[[VAL_10:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 1 : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 1 : tensor<?x?xf64,
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index 674225b9e935910..34e04c03529036f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -21,13 +21,13 @@
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_24:.*]] = tensor.dim %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
 // CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
@@ -50,8 +50,8 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
                       -> tensor<?x?xi64, #DenseMatrix> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 0 : index
-  %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
-  %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
+  %0 = sparse_tensor.lvl %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
+  %1 = sparse_tensor.lvl %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
   %init = tensor.empty(%0, %1) : tensor<?x?xi64, #DenseMatrix>
   %r = linalg.generic #trait
       ins(%arga: tensor<?x?xi64, #DenseMatrix>)
@@ -73,8 +73,8 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
@@ -107,8 +107,8 @@ func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
                        -> tensor<?x?xi64, #SparseMatrix> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 0 : index
-  %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
-  %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
+  %0 = sparse_tensor.lvl %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
+  %1 = sparse_tensor.lvl %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
   %init = tensor.empty(%0, %1) : tensor<?x?xi64, #SparseMatrix>
   %r = linalg.generic #trait
       ins(%arga: tensor<?x?xi64, #SparseMatrix>)
@@ -124,4 +124,3 @@ func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
   } -> tensor<?x?xi64, #SparseMatrix>
   return %r : tensor<?x?xi64, #SparseMatrix>
 }
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index 26712f2c7b001b3..186030aa2ca0873 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -59,17 +59,17 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
-// CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
-// CHECK:               %[[VAL_13:.*]] = arith.muli %[[VAL_7]], %[[VAL_11]] : index
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
+// CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
+// CHECK:               %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_11]] : index
 // CHECK:               %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
-// CHECK:               scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
-// CHECK:                 %[[VAL_16:.*]] = arith.muli %[[VAL_8]], %[[VAL_14]] : index
+// CHECK:               scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
+// CHECK:                 %[[VAL_16:.*]] = arith.muli %[[VAL_6]], %[[VAL_14]] : index
 // CHECK:                 %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index
 // CHECK:                 %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_17]]] : memref<?xf32>
 // CHECK:                 memref.store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_12]]] : memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index fd8aaf28697e42d..42726d998ac7a72 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -21,9 +21,9 @@
 // CHECK-HIR-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-HIR-DAG:       %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG:       %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
 // CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
@@ -53,18 +53,18 @@
 // CHECK-MIR-DAG:       %[[I0:.*]] = arith.constant 0 : index
 // CHECK-MIR-DAG:       %[[I1:.*]] = arith.constant 1 : index
 // CHECK-MIR-DAG:       %[[I2:.*]] = arith.constant 2 : index
-// CHECK-MIR-DAG:       %[[DimSize0:.*]] = call @sparseDimSize(%[[ARGA]], %[[I0]])
-// CHECK-MIR-DAG:       %[[DimSize1:.*]] = call @sparseDimSize(%[[ARGA]], %[[I1]])
-// CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseDimSize(%[[ARGA]], %[[I2]])
+// CHECK-MIR-DAG:       %[[DimSize0:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I0]])
+// CHECK-MIR-DAG:       %[[DimSize1:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I1]])
+// CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr<i8>) -> memref<?xf32>
 // CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref<f32>
 // CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
-// CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-MIR:               %[[VAL_18:.*]] = arith.muli %[[DimSize0]], %[[D2]] : index
+// CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-MIR:               %[[VAL_18:.*]] = arith.muli %[[DimSize1]], %[[D2]] : index
 // CHECK-MIR:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index
-// CHECK-MIR:               %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-MIR:                 %[[VAL_23:.*]] = arith.muli %[[DimSize1]], %[[VAL_19]] : index
+// CHECK-MIR:               %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-MIR:                 %[[VAL_23:.*]] = arith.muli %[[DimSize2]], %[[VAL_19]] : index
 // CHECK-MIR:                 %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index
 // CHECK-MIR:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
 // CHECK-MIR:                 %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 543fcaededf4965..4f105f3e19b3e75 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -103,7 +103,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[SD:.*]] = tensor.dim %[[S]], %[[C0]]
+// CHECK:         %[[SD:.*]] = sparse_tensor.lvl %[[S]], %[[C0]]
 // CHECK:         %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
 // CHECK:         %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
 // CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
@@ -146,7 +146,7 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[SD1:.*]] = tensor.dim %[[S]], %[[C1]]
+// CHECK:         %[[SD1:.*]] = sparse_tensor.lvl %[[S]], %[[C1]]
 // CHECK:         %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
 // CHECK:         %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
 // CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}

>From 74d69b9c9a4f46fc575f5791da68f8e7f6be6a41 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 24 Oct 2023 16:40:20 +0000
Subject: [PATCH 2/3] add some comments.

---
 .../SparseTensor/Transforms/SparseTensorRewriting.cpp        | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6ef4e09c8fe6de4..6c25cba91e27bb8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -904,6 +904,11 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
     }
 
     // Non-permutation dim2lvl/lvl2dim maps.
+    // Computed as follows:
+    //   affine.apply #map (l0 - 1, l1 - 1, ...) + 1
+    // Note that this is not the most efficient way (but a more general one) for
+    // the lvl-to-dim translation; e.g., for BSR, the dimension size can be
+    // computed simply as lvl_size * block_size.
     Location loc = op.getLoc();
     SmallVector<Value> maxLvlCrds;
     for (Level l = 0; l < stt.getLvlRank(); l++) {

>From dbe65786a56a8a5487d1d925476c0097f9339582 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 24 Oct 2023 17:16:40 +0000
Subject: [PATCH 3/3] address comments.

---
 .../Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp   | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index ecc452a5ba6c1cf..2452de8b1ec00f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -896,9 +896,9 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     Type idxType = rewriter.getIndexType();
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
+
     // Determine the size for access expansion (always the innermost stored
-    // level size, translated back to original dimension). Note that we
-    // recursively rewrite the new DimOp on the **original** tensor.
+    // level size).
     const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
     // Generate a memref for `sz` elements of type `t`.
     const auto genAlloc = [&](Type t) {



More information about the Mlir-commits mailing list