[Mlir-commits] [mlir] 86f91e4 - [mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion

wren romano llvmlistbot at llvm.org
Mon Dec 5 16:59:50 PST 2022


Author: wren romano
Date: 2022-12-05T16:59:42-08:00
New Revision: 86f91e45a22bbb981ede3439c7241ee92ea522ec

URL: https://github.com/llvm/llvm-project/commit/86f91e45a22bbb981ede3439c7241ee92ea522ec
DIFF: https://github.com/llvm/llvm-project/commit/86f91e45a22bbb981ede3439c7241ee92ea522ec.diff

LOG: [mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion

This change cleans up the conversion pass with respect to the runtime's "dim"-vs-"lvl" and "sizes"-vs-"shape" distinctions. A quick synopsis:

* Adds new `SparseTensorStorageBase::getDimSize` method, with `sparseDimSize` wrapper in SparseTensorRuntime.h, and `genDimSizeCall` generator in SparseTensorConversion.cpp
* Changes `genLvlSizeCall` to perform no logic, just generate the function call.
* Adds `createOrFold{Dim,Lvl}Call` functions to handle the logic of replacing `gen{Dim,Lvl}SizeCall` with constants whenever possible. The `createOrFoldDimCall` function replaces the old `sizeFromPtrAtDim`.
* Adds `{get,fill}DimSizes` functions for iterating `createOrFoldDimCall` across the whole type. These functions replace the old `sizesFromPtr`.
* Adds `{get,fill}DimShape` functions for lowering a `ShapedType` into constants. These functions replace the old `sizesFromType`.
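* Changes the `DimOp` rewrite to do the right thing: it now looks up the size of the requested dimension (via `createOrFoldDimCall`), rather than permuting the dimension into a level and querying that level's size.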
* Changes the `ExpandOp` rewrite to compute the proper expansion size, namely the size of the innermost level (via `createOrFoldLvlCall`).

Depends On D138365

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D139165

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
    mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
    mlir/test/Dialect/SparseTensor/sparse_expand.mlir
    mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index a8986a86835dd..c5e310937efe4 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -51,6 +51,8 @@ class SparseTensorEnumeratorBase;
 
 // These macros ensure consistent error messages, without risk of incuring
 // an additional method call to do so.
+#define ASSERT_VALID_DIM(d)                                                    \
+  assert(d < getDimRank() && "Dimension index is out of bounds");
 #define ASSERT_VALID_LVL(l)                                                    \
   assert(l < getLvlRank() && "Level index is out of bounds");
 #define ASSERT_COMPRESSED_LVL(l)                                               \
@@ -153,6 +155,12 @@ class SparseTensorStorageBase {
   /// Gets the tensor-dimension sizes array.
   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
 
+  /// Safely looks up the size of the given tensor-dimension.
+  uint64_t getDimSize(uint64_t d) const {
+    ASSERT_VALID_DIM(d);
+    return dimSizes[d];
+  }
+
   /// Gets the storage-level sizes array.
   const std::vector<uint64_t> &getLvlSizes() const { return lvlSizes; }
 
@@ -694,6 +702,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 #undef ASSERT_COMPRESSED_OR_SINGLETON_LVL
 #undef ASSERT_COMPRESSED_LVL
 #undef ASSERT_VALID_LVL
+#undef ASSERT_VALID_DIM
 
 //===----------------------------------------------------------------------===//
 /// A (higher-order) function object for enumerating the elements of some

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index 558799528c4c4..953cbe22804b5 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -137,6 +137,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
 /// Tensor-storage method to get the size of the given level.
 MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l);
 
+/// Tensor-storage method to get the size of the given dimension.
+MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d);
+
 /// Tensor-storage method to finalize lexicographic insertions.
 MLIR_CRUNNERUTILS_EXPORT void endInsert(void *tensor);
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 5017d0e635520..eb2b567a22219 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -57,62 +57,111 @@ static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                                    operands);
 }
 
-/// Generates call to lookup a level-size.
-static Value genLvlSizeCall(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr &enc, Value src,
+/// Generates call to lookup a level-size.  N.B., this only generates
+/// the raw function call, and therefore (intentionally) does not perform
+/// any dim<->lvl conversion or other logic.
+static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                             uint64_t lvl) {
-  // Generate the call.
   StringRef name = "sparseLvlSize";
-  SmallVector<Value, 2> params{  // just two
-      src, constantIndex(builder, loc, toStoredDim(enc, lvl))};
+  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
   Type iTp = builder.getIndexType();
   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
       .getResult(0);
 }
 
-/// Compute the size from type (for static sizes) or from an already-converted
-/// opaque pointer source (for dynamic sizes) at the given dimension.
-//
-// FIXME: Need to rename this function to match `genLvlSizeCall` and hence
-// match the naming convention used in the runtime library.  However, it's
-// not entirely clear that all callsites of this function properly make the
-// "level"-vs-"dimension" distinction; so need to audit each callsite to
-// ensure this still does what they mean (possibly by having two separate
-// functions, one for levels and one for dimensions).  That also means
-// renaming `sizesFromPtr`, `sizesFromType`, etc, to make clear whether
-// they mean to be referring to level-sizes vs dimension-sizes.
-static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
-                              SparseTensorEncodingAttr &enc, ShapedType stp,
-                              Value src, unsigned i) {
-  auto shape = stp.getShape();
-  if (shape[i] == ShapedType::kDynamic)
-    return genLvlSizeCall(builder, loc, enc, src, i);
-  return constantIndex(builder, loc, shape[i]);
+/// Generates call to lookup a dimension-size.  N.B., this only generates
+/// the raw function call, and therefore (intentionally) does not perform
+/// any dim<->lvl conversion or other logic.
+static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
+                            uint64_t dim) {
+  StringRef name = "sparseDimSize";
+  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
+  Type iTp = builder.getIndexType();
+  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
+      .getResult(0);
+}
+
+/// Looks up a level-size by returning a statically-computed constant
+/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
+static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
+                                 SparseTensorEncodingAttr &enc, ShapedType stp,
+                                 Value tensor, unsigned lvl) {
+  // Only sparse tensors have "levels" to query.
+  assert(enc);
+  auto dimOrder = enc.getDimOrdering();
+  // TODO: The following implementation only handles permutations;
+  // we'll need to generalize this to handle arbitrary AffineExpr.
+  //
+  // There's no need to assert `isPermutation` here: because
+  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
+  // which is all we care about (for supporting permutations).
+  unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl;
+  auto s = stp.getShape()[dim];
+  if (s != ShapedType::kDynamic)
+    return constantIndex(builder, loc, s);
+  // If we cannot statically compute the size from the shape, then we
+  // must dynamically query it.  (In principle we could also dynamically
+  // compute it, but since we already did so to construct the `tensor`
+  // in the first place, we might as well query rather than recompute.)
+  return genLvlSizeCall(builder, loc, tensor, lvl);
+}
+
+/// Looks up a dimension-size by returning a constant from the shape
+/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
+/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
+/// of dense tensors).
+static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
+                                 SparseTensorEncodingAttr &enc, ShapedType stp,
+                                 Value tensor, unsigned dim) {
+  auto s = stp.getShape()[dim];
+  if (s != ShapedType::kDynamic)
+    return constantIndex(builder, loc, s);
+  if (enc)
+    return genDimSizeCall(builder, loc, tensor, dim);
+  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
+}
+
+/// Populates the array with the dimension-sizes of the given tensor.
+static void fillDimSizes(OpBuilder &builder, Location loc,
+                         SparseTensorEncodingAttr &enc, ShapedType stp,
+                         Value tensor, SmallVectorImpl<Value> &out) {
+  unsigned dimRank = stp.getRank();
+  out.reserve(dimRank);
+  for (unsigned d = 0; d < dimRank; d++)
+    out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d));
 }
 
-/// Populates given sizes array from type (for static sizes) and from
-/// an already-converted opaque pointer source (for dynamic sizes).
-static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
-                         Location loc, SparseTensorEncodingAttr &enc,
-                         ShapedType stp, Value src) {
-  unsigned rank = stp.getRank();
-  sizes.reserve(rank);
-  for (unsigned i = 0; i < rank; i++)
-    sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
+/// Returns an array with the dimension-sizes of the given tensor.
+static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
+                                      SparseTensorEncodingAttr &enc,
+                                      ShapedType stp, Value tensor) {
+  SmallVector<Value> out;
+  fillDimSizes(builder, loc, enc, stp, tensor, out);
+  return out;
 }
 
-/// Populates given sizes array from type.
-static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
-                          Location loc, ShapedType stp) {
+/// Populates the array with the dimension-shape of the given `ShapedType`,
+/// where dynamic sizes are represented by zero.
+static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp,
+                         SmallVectorImpl<Value> &out) {
   auto shape = stp.getShape();
-  unsigned rank = stp.getRank();
-  sizes.reserve(rank);
-  for (unsigned i = 0; i < rank; i++) {
-    uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
-    sizes.push_back(constantIndex(builder, loc, s));
+  unsigned dimRank = stp.getRank();
+  out.reserve(dimRank);
+  for (unsigned d = 0; d < dimRank; d++) {
+    auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d];
+    out.push_back(constantIndex(builder, loc, s));
   }
 }
 
+/// Returns an array with the dimension-shape of the given `ShapedType`,
+/// where dynamic sizes are represented by zero.
+static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
+                                      ShapedType stp) {
+  SmallVector<Value> out;
+  fillDimShape(builder, loc, stp, out);
+  return out;
+}
+
 /// Populates the given sizes array for concatenation from type (for static
 /// sizes) and from an already-converted opaque pointer source (for dynamic
 /// sizes).
@@ -128,7 +177,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
   // compute the size of the concatenation dimension if necessary.
   if (srcEnc)
     // Reuses sizes from an arbitrary input tensor is fine.
-    sizesFromPtr(builder, sizes, loc, srcEnc, srcTp, srcs[0]);
+    fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes);
   else
     sizesFromSrc(builder, sizes, loc, srcs[0]);
 
@@ -142,8 +191,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
       auto srcTp = srcs[i].getType().cast<ShapedType>();
       auto encSrc = getSparseTensorEncoding(srcTp);
       Value srcSz =
-          encSrc ? sizeFromPtrAtDim(builder, loc, encSrc, srcTp, srcs[i], dim)
-                 : linalg::createOrFoldDimOp(builder, loc, srcs[i], dim);
+          createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim);
       // Sum up all the sizes.
       sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
     }
@@ -489,9 +537,6 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   auto encDst = getSparseTensorEncoding(dstTp);
   if (!encDst || !encSrc)
     return failure();
-
-  unsigned srcRank = srcTp.getRank();
-  unsigned dstRank = dstTp.getRank();
   Type elemTp = srcTp.getElementType();
   assert(elemTp == dstTp.getElementType() &&
          "reshape should not change element type");
@@ -499,26 +544,26 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   auto noPerm = SparseTensorEncodingAttr::get(
       op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-  SmallVector<Value> srcSizes;
-  sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
+  SmallVector<Value> srcDimSizes =
+      getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
   NewCallParams params(rewriter, loc);
-  Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
+  Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp)
                    .genNewCall(Action::kToIterator, adaptor.getSrc());
   // Start a new COO for the destination tensor.
-  SmallVector<Value> dstSizes;
-  if (dstTp.hasStaticShape()) {
-    sizesFromType(rewriter, dstSizes, loc, dstTp);
-  } else {
-    ArrayRef<int64_t> dstShape = dstTp.getShape();
-    genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
-                       op.getReassociationIndices());
-  }
-  Value coo =
-      params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
+  SmallVector<Value> dstDimSizes;
+  if (dstTp.hasStaticShape())
+    // Static "shapes" are in fact "sizes".
+    fillDimShape(rewriter, loc, dstTp, dstDimSizes);
+  else
+    genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
+                       dstTp.getShape(), op.getReassociationIndices());
+  Value coo = params.genBuffers(encDst, dstDimSizes, dstTp)
+                  .genNewCall(Action::kEmptyCOO);
   Value dstPerm = params.getDim2LvlMap();
   // Construct a while loop over the iterator.
-  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
-  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
+  Type iTp = rewriter.getIndexType();
+  Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp);
+  Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp);
   Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
   SmallVector<Value> noArgs;
   SmallVector<Type> noTypes;
@@ -532,7 +577,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
   translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
-                   dstIdx, srcIdx, dstSizes, srcSizes);
+                   dstIdx, srcIdx, dstDimSizes, srcDimSizes);
   genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
@@ -566,10 +611,9 @@ static void genSparseCOOIterationLoop(
   auto noPerm = SparseTensorEncodingAttr::get(
       rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
       enc.getPointerBitWidth(), enc.getIndexBitWidth());
-  SmallVector<Value> sizes;
-  sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
+  SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
   Value iter = NewCallParams(rewriter, loc)
-                   .genBuffers(noPerm, sizes, tensorTp)
+                   .genBuffers(noPerm, dimSizes, tensorTp)
                    .genNewCall(Action::kToIterator, t);
 
   // Construct a while loop over the iterator.
@@ -664,7 +708,7 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
   }
 };
 
-/// Sparse conversion rule for dimension accesses.
+/// Sparse conversion rule for accessing dimension-sizes.
 class SparseTensorToDimSizeConverter
     : public OpConversionPattern<tensor::DimOp> {
 public:
@@ -672,18 +716,19 @@ class SparseTensorToDimSizeConverter
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Only rewrite annotated DimOp with constant index.
-    auto enc = getSparseTensorEncoding(op.getSource().getType());
+    auto stp = op.getSource().getType().cast<ShapedType>();
+    // Only rewrite sparse DimOp.
+    auto enc = getSparseTensorEncoding(stp);
     if (!enc)
       return failure();
-    Optional<int64_t> index = op.getConstantIndex();
-    if (!index)
+    // Only rewrite DimOp with constant index.
+    Optional<int64_t> dim = op.getConstantIndex();
+    if (!dim)
       return failure();
     // Generate the call.
     Value src = adaptor.getOperands()[0];
-    int64_t idx = *index;
-    rewriter.replaceOp(op,
-                       genLvlSizeCall(rewriter, op->getLoc(), enc, src, idx));
+    rewriter.replaceOp(
+        op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim));
     return success();
   }
 };
@@ -734,8 +779,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     const unsigned lvlRank = enc.getDimLevelType().size();
     // Construct the dimShape.
     const auto dimShape = stp.getShape();
-    SmallVector<Value> dimShapeValues;
-    sizesFromType(rewriter, dimShapeValues, loc, stp);
+    SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
     Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
     // Allocate `SparseTensorReader` and perform all initial setup that
     // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
@@ -890,10 +934,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
         rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
         return success();
       }
-      SmallVector<Value> sizes;
       NewCallParams params(rewriter, loc);
       ShapedType stp = srcType.cast<ShapedType>();
-      sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
+      SmallVector<Value> dimSizes =
+          getDimSizes(rewriter, loc, encSrc, stp, src);
       bool useDirectConversion;
       switch (options.sparseToSparseStrategy) {
       case SparseToSparseConversionStrategy::kViaCOO:
@@ -909,7 +953,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
         break;
       }
       if (useDirectConversion) {
-        rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
+        rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp)
                                    .genNewCall(Action::kSparseToSparse, src));
       } else { // use via-COO conversion.
         // Set up encoding with right mix of src and dst so that the two
@@ -922,8 +966,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
         // TODO: This is the only place where `kToCOO` (or `kToIterator`)
         // is called with a non-identity permutation.  Is there any clean
         // way to push the permutation over to the `kFromCOO` side instead?
-        Value coo =
-            params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
+        Value coo = params.genBuffers(enc, dimSizes, stp)
+                        .genNewCall(Action::kToCOO, src);
         Value dst = params.setTemplateTypes(encDst, stp)
                         .genNewCall(Action::kFromCOO, coo);
         genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
@@ -950,17 +994,17 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           op->getContext(),
           SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
           AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-      SmallVector<Value> sizes;
-      sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
+      SmallVector<Value> dimSizes =
+          getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
       Value iter = NewCallParams(rewriter, loc)
-                       .genBuffers(encDst, sizes, dstTensorTp)
+                       .genBuffers(encDst, dimSizes, dstTensorTp)
                        .genNewCall(Action::kToIterator, src);
       Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
       Block *insertionBlock = rewriter.getInsertionBlock();
       // TODO: Dense buffers should be allocated/deallocated via the callback
       // in BufferizationOptions.
-      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
+      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
       SmallVector<Value> noArgs;
       SmallVector<Type> noTypes;
       auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -1196,12 +1240,12 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
     Type idxType = rewriter.getIndexType();
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
-    // Determine the size for access expansion (always the innermost stored
-    // dimension size, translated back to original dimension).
-    auto enc = getSparseTensorEncoding(srcType);
-    unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
-    auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(),
-                               innerDim);
+    // Get the cardinality of valid coordinates for the innermost level.
+    auto srcEnc = getSparseTensorEncoding(srcType);
+    unsigned lvlRank =
+        srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank();
+    Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType,
+                                   adaptor.getTensor(), lvlRank - 1);
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).
@@ -1377,10 +1421,8 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
       }
       // Accumulate offset.
       // TODO: avoid calling sparseDimSize multiple times by caching the result!
-      Value curDim = encSrc ? sizeFromPtrAtDim(rewriter, loc, encSrc, srcTp,
-                                               adaptedOp, concatDim)
-                            : linalg::createOrFoldDimOp(rewriter, loc,
-                                                        adaptedOp, concatDim);
+      Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp,
+                                         adaptedOp, concatDim);
 
       offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
     }
@@ -1410,13 +1452,13 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
     // Convert to default permuted COO.
     Value src = adaptor.getOperands()[0];
     auto encSrc = getSparseTensorEncoding(srcType);
-    SmallVector<Value> sizes;
-    sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
+    SmallVector<Value> dimSizes =
+        getDimSizes(rewriter, loc, encSrc, srcType, src);
     auto enc = SparseTensorEncodingAttr::get(
         op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
         encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
     Value coo = NewCallParams(rewriter, loc)
-                    .genBuffers(enc, sizes, srcType)
+                    .genBuffers(enc, dimSizes, srcType)
                     .genNewCall(Action::kToCOO, src);
     // Then output the tensor to external file with indices in the externally
     // visible lexicographic index order. A sort is required if the source was

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index dbb08d5e2f56a..c9c404ec2ddc7 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -779,8 +779,12 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
 //
 //===----------------------------------------------------------------------===//
 
-index_type sparseLvlSize(void *tensor, index_type x) {
-  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(x);
+index_type sparseLvlSize(void *tensor, index_type l) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
+}
+
+index_type sparseDimSize(void *tensor, index_type d) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
 void endInsert(void *tensor) {

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2264066353169..dc4efae50c01b 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -40,7 +40,7 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
 // CHECK-LABEL: func @sparse_dim1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = arith.constant 0 : index
-//       CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
+//       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
 //       CHECK: return %[[D]] : index
 func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
   %c = arith.constant 0 : index
@@ -48,28 +48,28 @@ func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
   return %0 : index
 }
 
+// Querying the size of dimension 1 should do so; i.e., it should
+// not be permuted into a query for the size of level 2 (even though
+// dimension 1 is stored as level 2).
 // CHECK-LABEL: func @sparse_dim3d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = arith.constant 2 : index
-//       CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
+//       CHECK: %[[C:.*]] = arith.constant 1 : index
+//       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
 //       CHECK: return %[[D]] : index
 func.func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
-  // Querying for dimension 1 in the tensor type needs to be
-  // permuted into querying for dimension 2 in the stored sparse
-  // tensor scheme, since the latter honors the dimOrdering.
   %c = arith.constant 1 : index
   %0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #SparseTensor>
   return %0 : index
 }
 
+// Querying the size of a static dimension should be folded into a
+// constant (and we should be sure to get the size of dimension 1,
+// not dimension 2 nor level 1).
 // CHECK-LABEL: func @sparse_dim3d_const(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = arith.constant 20 : index
 //       CHECK: return %[[C]] : index
 func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> index {
-  // Querying for dimension 1 in the tensor type can be directly
-  // folded into the right value (even though it corresponds
-  // to dimension 2 in the stored sparse tensor scheme).
   %c = arith.constant 1 : index
   %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #SparseTensor>
   return %0 : index
@@ -361,7 +361,7 @@ func.func @sparse_expansion2() -> memref<?xindex> {
 // CHECK-LABEL: func @sparse_expansion3(
 //       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[N:.*]] = call @newSparseTensor
-//       CHECK: %[[S:.*]] = call @sparseLvlSize(%[[N]], %c1) : (!llvm.ptr<i8>, index) -> index
+//       CHECK: %[[S:.*]] = call @sparseLvlSize(%[[N]], %[[C1]])
 //       CHECK: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
 //       CHECK: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
 //       CHECK: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>

diff  --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index 2c5de95d775ef..b847a277859fb 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -70,7 +70,7 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
 //   CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi8>
 //   CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
-//   CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+//   CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
 //   CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<1xindex>
 //   CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
@@ -175,7 +175,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
 //   CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-//   CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+//   CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
 //   CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
 //   CHECK-DAG: memref.store %[[I4]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
 //   CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
@@ -223,7 +223,7 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
 //   CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-//   CHECK-DAG: %[[SizeI1:.*]] = call @sparseLvlSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
+//   CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
 //   CHECK-DAG: memref.store %[[I2]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
 //   CHECK-DAG: memref.store %[[SizeI1]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
 //   CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
@@ -270,8 +270,8 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
 //   CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-//   CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
-//   CHECK-DAG: %[[SizeI1:.*]] = call @sparseLvlSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
+//   CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+//   CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
 //   CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
 //   CHECK-DAG: memref.store %[[SizeI1]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
 //   CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 946a828274b24..785033494bf2b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -46,10 +46,11 @@
 // CHECK-SPARSE: return %[[RET]]
 //
 // CHECK-CONVERT-LABEL: func @kernel(
-// CHECK-CONVERT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-CONVERT: %{{.*}} = call @sparseLvlSize
-// CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[N]], %[[C]])
+// CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-CONVERT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONVERT: %[[N:.*]] = call @sparseDimSize(%[[A]], %[[C0]])
+// CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
+// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
 // CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
 // CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
 // CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index a8da710f0bd0c..56d3168a7634b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -49,24 +49,24 @@
 // CHECK-HIR:         }
 //
 // CHECK-MIR-LABEL:   func @sparse_dynamic_dims(
-// CHECK-MIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-MIR-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-MIR-DAG:       %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK-MIR-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-MIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-MIR-DAG:       %[[VAL_5:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG:       %[[VAL_6:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_3]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG:       %[[VAL_7:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_2]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf32>
-// CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
-// CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-MIR:               %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
-// CHECK-MIR:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index
-// CHECK-MIR:               %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-MIR:                 %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index
-// CHECK-MIR:                 %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK-MIR-SAME:      %[[ARGA:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME:      %[[ARGX:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-MIR-DAG:       %[[I0:.*]] = arith.constant 0 : index
+// CHECK-MIR-DAG:       %[[I1:.*]] = arith.constant 1 : index
+// CHECK-MIR-DAG:       %[[I2:.*]] = arith.constant 2 : index
+// CHECK-MIR-DAG:       %[[DimSize0:.*]] = call @sparseDimSize(%[[ARGA]], %[[I0]])
+// CHECK-MIR-DAG:       %[[DimSize1:.*]] = call @sparseDimSize(%[[ARGA]], %[[I1]])
+// CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseDimSize(%[[ARGA]], %[[I2]])
+// CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr<i8>) -> memref<?xf32>
+// CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref<f32>
+// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-MIR:               %[[VAL_18:.*]] = arith.muli %[[DimSize0]], %[[D2]] : index
+// CHECK-MIR:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index
+// CHECK-MIR:               %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-MIR:                 %[[VAL_23:.*]] = arith.muli %[[DimSize1]], %[[VAL_19]] : index
+// CHECK-MIR:                 %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index
 // CHECK-MIR:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
 // CHECK-MIR:                 %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
 // CHECK-MIR:                 scf.yield %[[VAL_26]] : f32


        


More information about the Mlir-commits mailing list