[Mlir-commits] [mlir] 86f91e4 - [mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion
wren romano
llvmlistbot at llvm.org
Mon Dec 5 16:59:50 PST 2022
Author: wren romano
Date: 2022-12-05T16:59:42-08:00
New Revision: 86f91e45a22bbb981ede3439c7241ee92ea522ec
URL: https://github.com/llvm/llvm-project/commit/86f91e45a22bbb981ede3439c7241ee92ea522ec
DIFF: https://github.com/llvm/llvm-project/commit/86f91e45a22bbb981ede3439c7241ee92ea522ec.diff
LOG: [mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion
This change cleans up the conversion pass regarding the "dim"-vs-"lvl" and "sizes"-vs-"shape" distinctions of the runtime. A quick synopsis:
* Adds a new `SparseTensorStorageBase::getDimSize` method, with a `sparseDimSize` wrapper in SparseTensorRuntime.h and a `genDimSizeCall` generator in SparseTensorConversion.cpp.
* Changes `genLvlSizeCall` to perform no logic, just generate the function call.
* Adds `createOrFold{Dim,Lvl}Call` functions to handle the logic of replacing `gen{Dim,Lvl}SizeCall` with constants whenever possible. The `createOrFoldDimCall` function replaces the old `sizeFromPtrAtDim`.
* Adds `{get,fill}DimSizes` functions for iterating `createOrFoldDimCall` across the whole type. These functions replace the old `sizesFromPtr`.
* Adds `{get,fill}DimShape` functions for lowering a `ShapedType` into constants. These functions replace the old `sizesFromType`.
* Changes the `DimOp` rewrite to query the requested dimension-size directly, rather than pre-permuting the request into a level-size query (see the sketch after this list).
* Changes the `ExpandOp` rewrite to compute the proper expansion size, namely the size of the innermost storage level.
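To make the dim/lvl bookkeeping concrete, here is a minimal standalone C++ sketch; this is an editorial illustration, not code from this patch. It simplifies the real `toOrigDim`/`toStoredDim` helpers to a plain level-to-dimension permutation vector, and assumes an ordering such as (i,j,k) -> (k,i,j), which stores dimension 1 at level 2 as in the test comments below.

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-ins for the real toOrigDim/toStoredDim helpers, using a
// plain level-to-dimension permutation vector in place of the encoding.
static uint64_t toOrigDim(const std::vector<uint64_t> &lvl2dim, uint64_t lvl) {
  assert(lvl < lvl2dim.size() && "Level index is out of bounds");
  return lvl2dim[lvl]; // which dimension this storage level holds
}

static uint64_t toStoredDim(const std::vector<uint64_t> &lvl2dim,
                            uint64_t dim) {
  for (uint64_t l = 0, e = lvl2dim.size(); l < e; ++l)
    if (lvl2dim[l] == dim)
      return l; // which storage level holds this dimension
  assert(false && "Dimension index is out of bounds");
  return 0;
}

int main() {
  // Assumed ordering (i,j,k) -> (k,i,j): level 0 stores k (dim 2),
  // level 1 stores i (dim 0), level 2 stores j (dim 1).
  const std::vector<uint64_t> lvl2dim = {2, 0, 1};
  assert(toOrigDim(lvl2dim, 2) == 1);   // level 2 holds dimension 1
  assert(toStoredDim(lvl2dim, 1) == 2); // dimension 1 is stored at level 2
  // Hence `tensor.dim %t, 1` now lowers to sparseDimSize(t, 1), letting the
  // runtime resolve the indexing, instead of the conversion pass
  // pre-permuting the request into sparseLvlSize(t, 2).
  return 0;
}

Keeping the permutation on exactly one side of the runtime ABI is what lets `sparseDimSize` and `sparseLvlSize` each mean exactly what they say.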
Depends On D138365
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D139165
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
mlir/test/Dialect/SparseTensor/sparse_expand.mlir
mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index a8986a86835dd..c5e310937efe4 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -51,6 +51,8 @@ class SparseTensorEnumeratorBase;
// These macros ensure consistent error messages, without risk of incurring
// an additional method call to do so.
+#define ASSERT_VALID_DIM(d) \
+ assert(d < getDimRank() && "Dimension index is out of bounds");
#define ASSERT_VALID_LVL(l) \
assert(l < getLvlRank() && "Level index is out of bounds");
#define ASSERT_COMPRESSED_LVL(l) \
@@ -153,6 +155,12 @@ class SparseTensorStorageBase {
/// Gets the tensor-dimension sizes array.
const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
+ /// Safely looks up the size of the given tensor-dimension.
+ uint64_t getDimSize(uint64_t d) const {
+ ASSERT_VALID_DIM(d);
+ return dimSizes[d];
+ }
+
/// Gets the storage-level sizes array.
const std::vector<uint64_t> &getLvlSizes() const { return lvlSizes; }
@@ -694,6 +702,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
#undef ASSERT_COMPRESSED_OR_SINGLETON_LVL
#undef ASSERT_COMPRESSED_LVL
#undef ASSERT_VALID_LVL
+#undef ASSERT_VALID_DIM
//===----------------------------------------------------------------------===//
/// A (higher-order) function object for enumerating the elements of some
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index 558799528c4c4..953cbe22804b5 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -137,6 +137,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
/// Tensor-storage method to get the size of the given level.
MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l);
+/// Tensor-storage method to get the size of the given dimension.
+MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d);
+
/// Tensor-storage method to finalize lexicographic insertions.
MLIR_CRUNNERUTILS_EXPORT void endInsert(void *tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 5017d0e635520..eb2b567a22219 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -57,62 +57,111 @@ static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
operands);
}
-/// Generates call to lookup a level-size.
-static Value genLvlSizeCall(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr &enc, Value src,
+/// Generates call to lookup a level-size. N.B., this only generates
+/// the raw function call, and therefore (intentionally) does not perform
+/// any dim<->lvl conversion or other logic.
+static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
uint64_t lvl) {
- // Generate the call.
StringRef name = "sparseLvlSize";
- SmallVector<Value, 2> params{ // just two
- src, constantIndex(builder, loc, toStoredDim(enc, lvl))};
+ SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
Type iTp = builder.getIndexType();
return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
.getResult(0);
}
-/// Compute the size from type (for static sizes) or from an already-converted
-/// opaque pointer source (for dynamic sizes) at the given dimension.
-//
-// FIXME: Need to rename this function to match `genLvlSizeCall` and hence
-// match the naming convention used in the runtime library. However, it's
-// not entirely clear that all callsites of this function properly make the
-// "level"-vs-"dimension" distinction; so need to audit each callsite to
-// ensure this still does what they mean (possibly by having two separate
-// functions, one for levels and one for dimensions). That also means
-// renaming `sizesFromPtr`, `sizesFromType`, etc, to make clear whether
-// they mean to be referring to level-sizes vs dimension-sizes.
-static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr &enc, ShapedType stp,
- Value src, unsigned i) {
- auto shape = stp.getShape();
- if (shape[i] == ShapedType::kDynamic)
- return genLvlSizeCall(builder, loc, enc, src, i);
- return constantIndex(builder, loc, shape[i]);
+/// Generates call to lookup a dimension-size. N.B., this only generates
+/// the raw function call, and therefore (intentionally) does not perform
+/// any dim<->lvl conversion or other logic.
+static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
+ uint64_t dim) {
+ StringRef name = "sparseDimSize";
+ SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
+ Type iTp = builder.getIndexType();
+ return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
+ .getResult(0);
+}
+
+/// Looks up a level-size by returning a statically-computed constant
+/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
+static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr &enc, ShapedType stp,
+ Value tensor, unsigned lvl) {
+ // Only sparse tensors have "levels" to query.
+ assert(enc);
+ auto dimOrder = enc.getDimOrdering();
+ // TODO: The following implementation only handles permutations;
+ // we'll need to generalize this to handle arbitrary AffineExpr.
+ //
+ // There's no need to assert `isPermutation` here: because
+ // `getDimPosition` checks that the expr isa `AffineDimExpr`,
+ // which is all we care about (for supporting permutations).
+ unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl;
+ auto s = stp.getShape()[dim];
+ if (s != ShapedType::kDynamic)
+ return constantIndex(builder, loc, s);
+ // If we cannot statically compute the size from the shape, then we
+ // must dynamically query it. (In principle we could also dynamically
+ // compute it, but since we already did so to construct the `tensor`
+ // in the first place, we might as well query rather than recompute.)
+ return genLvlSizeCall(builder, loc, tensor, lvl);
+}
+
+/// Looks up a dimension-size by returning a constant from the shape
+/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
+/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
+/// of dense tensors).
+static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr &enc, ShapedType stp,
+ Value tensor, unsigned dim) {
+ auto s = stp.getShape()[dim];
+ if (s != ShapedType::kDynamic)
+ return constantIndex(builder, loc, s);
+ if (enc)
+ return genDimSizeCall(builder, loc, tensor, dim);
+ return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
+}
+
+/// Populates the array with the dimension-sizes of the given tensor.
+static void fillDimSizes(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr &enc, ShapedType stp,
+ Value tensor, SmallVectorImpl<Value> &out) {
+ unsigned dimRank = stp.getRank();
+ out.reserve(dimRank);
+ for (unsigned d = 0; d < dimRank; d++)
+ out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d));
}
-/// Populates given sizes array from type (for static sizes) and from
-/// an already-converted opaque pointer source (for dynamic sizes).
-static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
- Location loc, SparseTensorEncodingAttr &enc,
- ShapedType stp, Value src) {
- unsigned rank = stp.getRank();
- sizes.reserve(rank);
- for (unsigned i = 0; i < rank; i++)
- sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
+/// Returns an array with the dimension-sizes of the given tensor.
+static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr &enc,
+ ShapedType stp, Value tensor) {
+ SmallVector<Value> out;
+ fillDimSizes(builder, loc, enc, stp, tensor, out);
+ return out;
}
-/// Populates given sizes array from type.
-static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
- Location loc, ShapedType stp) {
+/// Populates the array with the dimension-shape of the given `ShapedType`,
+/// where dynamic sizes are represented by zero.
+static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp,
+ SmallVectorImpl<Value> &out) {
auto shape = stp.getShape();
- unsigned rank = stp.getRank();
- sizes.reserve(rank);
- for (unsigned i = 0; i < rank; i++) {
- uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
- sizes.push_back(constantIndex(builder, loc, s));
+ unsigned dimRank = stp.getRank();
+ out.reserve(dimRank);
+ for (unsigned d = 0; d < dimRank; d++) {
+ auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d];
+ out.push_back(constantIndex(builder, loc, s));
}
}
+/// Returns an array with the dimension-shape of the given `ShapedType`,
+/// where dynamic sizes are represented by zero.
+static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
+ ShapedType stp) {
+ SmallVector<Value> out;
+ fillDimShape(builder, loc, stp, out);
+ return out;
+}
+
/// Populates the given sizes array for concatenation from type (for static
/// sizes) and from an already-converted opaque pointer source (for dynamic
/// sizes).
@@ -128,7 +177,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
// compute the size of the concatenation dimension if necessary.
if (srcEnc)
// Reusing sizes from an arbitrary input tensor is fine.
- sizesFromPtr(builder, sizes, loc, srcEnc, srcTp, srcs[0]);
+ fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes);
else
sizesFromSrc(builder, sizes, loc, srcs[0]);
@@ -142,8 +191,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
auto srcTp = srcs[i].getType().cast<ShapedType>();
auto encSrc = getSparseTensorEncoding(srcTp);
Value srcSz =
- encSrc ? sizeFromPtrAtDim(builder, loc, encSrc, srcTp, srcs[i], dim)
- : linalg::createOrFoldDimOp(builder, loc, srcs[i], dim);
+ createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim);
// Sum up all the sizes.
sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
}
@@ -489,9 +537,6 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
auto encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc)
return failure();
-
- unsigned srcRank = srcTp.getRank();
- unsigned dstRank = dstTp.getRank();
Type elemTp = srcTp.getElementType();
assert(elemTp == dstTp.getElementType() &&
"reshape should not change element type");
@@ -499,26 +544,26 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
auto noPerm = SparseTensorEncodingAttr::get(
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
- SmallVector<Value> srcSizes;
- sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
+ SmallVector<Value> srcDimSizes =
+ getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
NewCallParams params(rewriter, loc);
- Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
+ Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp)
.genNewCall(Action::kToIterator, adaptor.getSrc());
// Start a new COO for the destination tensor.
- SmallVector<Value> dstSizes;
- if (dstTp.hasStaticShape()) {
- sizesFromType(rewriter, dstSizes, loc, dstTp);
- } else {
- ArrayRef<int64_t> dstShape = dstTp.getShape();
- genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
- op.getReassociationIndices());
- }
- Value coo =
- params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
+ SmallVector<Value> dstDimSizes;
+ if (dstTp.hasStaticShape())
+ // Static "shapes" are in fact "sizes".
+ fillDimShape(rewriter, loc, dstTp, dstDimSizes);
+ else
+ genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
+ dstTp.getShape(), op.getReassociationIndices());
+ Value coo = params.genBuffers(encDst, dstDimSizes, dstTp)
+ .genNewCall(Action::kEmptyCOO);
Value dstPerm = params.getDim2LvlMap();
// Construct a while loop over the iterator.
- Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
- Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
+ Type iTp = rewriter.getIndexType();
+ Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp);
+ Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp);
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
SmallVector<Value> noArgs;
SmallVector<Type> noTypes;
@@ -532,7 +577,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
rewriter.setInsertionPointToStart(after);
translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
- dstIdx, srcIdx, dstSizes, srcSizes);
+ dstIdx, srcIdx, dstDimSizes, srcDimSizes);
genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
rewriter.create<scf::YieldOp>(loc);
// Final call to construct sparse tensor storage and free temporary resources.
@@ -566,10 +611,9 @@ static void genSparseCOOIterationLoop(
auto noPerm = SparseTensorEncodingAttr::get(
rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
enc.getPointerBitWidth(), enc.getIndexBitWidth());
- SmallVector<Value> sizes;
- sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
+ SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
Value iter = NewCallParams(rewriter, loc)
- .genBuffers(noPerm, sizes, tensorTp)
+ .genBuffers(noPerm, dimSizes, tensorTp)
.genNewCall(Action::kToIterator, t);
// Construct a while loop over the iterator.
@@ -664,7 +708,7 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
}
};
-/// Sparse conversion rule for dimension accesses.
+/// Sparse conversion rule for accessing dimension-sizes.
class SparseTensorToDimSizeConverter
: public OpConversionPattern<tensor::DimOp> {
public:
@@ -672,18 +716,19 @@ class SparseTensorToDimSizeConverter
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Only rewrite annotated DimOp with constant index.
- auto enc = getSparseTensorEncoding(op.getSource().getType());
+ auto stp = op.getSource().getType().cast<ShapedType>();
+ // Only rewrite sparse DimOp.
+ auto enc = getSparseTensorEncoding(stp);
if (!enc)
return failure();
- Optional<int64_t> index = op.getConstantIndex();
- if (!index)
+ // Only rewrite DimOp with constant index.
+ Optional<int64_t> dim = op.getConstantIndex();
+ if (!dim)
return failure();
// Generate the call.
Value src = adaptor.getOperands()[0];
- int64_t idx = *index;
- rewriter.replaceOp(op,
- genLvlSizeCall(rewriter, op->getLoc(), enc, src, idx));
+ rewriter.replaceOp(
+ op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim));
return success();
}
};
@@ -734,8 +779,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
const unsigned lvlRank = enc.getDimLevelType().size();
// Construct the dimShape.
const auto dimShape = stp.getShape();
- SmallVector<Value> dimShapeValues;
- sizesFromType(rewriter, dimShapeValues, loc, stp);
+ SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
// Allocate `SparseTensorReader` and perform all initial setup that
// does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
@@ -890,10 +934,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
return success();
}
- SmallVector<Value> sizes;
NewCallParams params(rewriter, loc);
ShapedType stp = srcType.cast<ShapedType>();
- sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
+ SmallVector<Value> dimSizes =
+ getDimSizes(rewriter, loc, encSrc, stp, src);
bool useDirectConversion;
switch (options.sparseToSparseStrategy) {
case SparseToSparseConversionStrategy::kViaCOO:
@@ -909,7 +953,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
break;
}
if (useDirectConversion) {
- rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
+ rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp)
.genNewCall(Action::kSparseToSparse, src));
} else { // use via-COO conversion.
// Set up encoding with right mix of src and dst so that the two
@@ -922,8 +966,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
// TODO: This is the only place where `kToCOO` (or `kToIterator`)
// is called with a non-identity permutation. Is there any clean
// way to push the permutation over to the `kFromCOO` side instead?
- Value coo =
- params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
+ Value coo = params.genBuffers(enc, dimSizes, stp)
+ .genNewCall(Action::kToCOO, src);
Value dst = params.setTemplateTypes(encDst, stp)
.genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
@@ -950,17 +994,17 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
op->getContext(),
SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
- SmallVector<Value> sizes;
- sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
+ SmallVector<Value> dimSizes =
+ getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
Value iter = NewCallParams(rewriter, loc)
- .genBuffers(encDst, sizes, dstTensorTp)
+ .genBuffers(encDst, dimSizes, dstTensorTp)
.genNewCall(Action::kToIterator, src);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
Block *insertionBlock = rewriter.getInsertionBlock();
// TODO: Dense buffers should be allocated/deallocated via the callback
// in BufferizationOptions.
- Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
+ Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
SmallVector<Value> noArgs;
SmallVector<Type> noTypes;
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -1196,12 +1240,12 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
Type idxType = rewriter.getIndexType();
// All initialization should be done on entry of the loop nest.
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
- // Determine the size for access expansion (always the innermost stored
- // dimension size, translated back to original dimension).
- auto enc = getSparseTensorEncoding(srcType);
- unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
- auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(),
- innerDim);
+ // Get the cardinality of valid coordinates for the innermost level.
+ auto srcEnc = getSparseTensorEncoding(srcType);
+ unsigned lvlRank =
+ srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank();
+ Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType,
+ adaptor.getTensor(), lvlRank - 1);
// Allocate temporary buffers for values, filled-switch, and indices.
// We do not use stack buffers for this, since the expanded size may
// be rather large (as it envelops a single expanded dense dimension).
@@ -1377,10 +1421,8 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
}
// Accumulate offset.
// TODO: avoid calling sparseDimSize multiple times by caching the result!
- Value curDim = encSrc ? sizeFromPtrAtDim(rewriter, loc, encSrc, srcTp,
- adaptedOp, concatDim)
- : linalg::createOrFoldDimOp(rewriter, loc,
- adaptedOp, concatDim);
+ Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp,
+ adaptedOp, concatDim);
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
}
@@ -1410,13 +1452,13 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
// Convert to default permuted COO.
Value src = adaptor.getOperands()[0];
auto encSrc = getSparseTensorEncoding(srcType);
- SmallVector<Value> sizes;
- sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
+ SmallVector<Value> dimSizes =
+ getDimSizes(rewriter, loc, encSrc, srcType, src);
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
Value coo = NewCallParams(rewriter, loc)
- .genBuffers(enc, sizes, srcType)
+ .genBuffers(enc, dimSizes, srcType)
.genNewCall(Action::kToCOO, src);
// Then output the tensor to external file with indices in the externally
// visible lexicographic index order. A sort is required if the source was
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index dbb08d5e2f56a..c9c404ec2ddc7 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -779,8 +779,12 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
//
//===----------------------------------------------------------------------===//
-index_type sparseLvlSize(void *tensor, index_type x) {
- return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(x);
+index_type sparseLvlSize(void *tensor, index_type l) {
+ return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
+}
+
+index_type sparseDimSize(void *tensor, index_type d) {
+ return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}
void endInsert(void *tensor) {
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2264066353169..dc4efae50c01b 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -40,7 +40,7 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
// CHECK-LABEL: func @sparse_dim1d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = arith.constant 0 : index
-// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
+// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
// CHECK: return %[[D]] : index
func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
%c = arith.constant 0 : index
@@ -48,28 +48,28 @@ func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
return %0 : index
}
+// Querying the size of dimension 1 should do so; i.e., it should
+// not be permuted into a query for the size of level 2 (even though
+// dimension 1 is stored as level 2).
// CHECK-LABEL: func @sparse_dim3d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-// CHECK: %[[C:.*]] = arith.constant 2 : index
-// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
+// CHECK: %[[C:.*]] = arith.constant 1 : index
+// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
// CHECK: return %[[D]] : index
func.func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
- // Querying for dimension 1 in the tensor type needs to be
- // permuted into querying for dimension 2 in the stored sparse
- // tensor scheme, since the latter honors the dimOrdering.
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #SparseTensor>
return %0 : index
}
+// Querying the size of a static dimension should be folded into a
+// constant (and we should be sure to get the size of dimension 1,
+// not dimension 2 nor level 1).
// CHECK-LABEL: func @sparse_dim3d_const(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = arith.constant 20 : index
// CHECK: return %[[C]] : index
func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> index {
- // Querying for dimension 1 in the tensor type can be directly
- // folded into the right value (even though it corresponds
- // to dimension 2 in the stored sparse tensor scheme).
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #SparseTensor>
return %0 : index
@@ -361,7 +361,7 @@ func.func @sparse_expansion2() -> memref<?xindex> {
// CHECK-LABEL: func @sparse_expansion3(
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[N:.*]] = call @newSparseTensor
-// CHECK: %[[S:.*]] = call @sparseLvlSize(%[[N]], %c1) : (!llvm.ptr<i8>, index) -> index
+// CHECK: %[[S:.*]] = call @sparseLvlSize(%[[N]], %[[C1]])
// CHECK: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
// CHECK: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
// CHECK: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index 2c5de95d775ef..b847a277859fb 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -70,7 +70,7 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<1xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
@@ -175,7 +175,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I4]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
@@ -223,7 +223,7 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI1:.*]] = call @sparseLvlSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[I2]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[SizeI1]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
@@ -270,8 +270,8 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-DAG: %[[SizeI1:.*]] = call @sparseLvlSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[SizeI1]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 946a828274b24..785033494bf2b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -46,10 +46,11 @@
// CHECK-SPARSE: return %[[RET]]
//
// CHECK-CONVERT-LABEL: func @kernel(
-// CHECK-CONVERT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-CONVERT: %{{.*}} = call @sparseLvlSize
-// CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[N]], %[[C]])
+// CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-CONVERT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONVERT: %[[N:.*]] = call @sparseDimSize(%[[A]], %[[C0]])
+// CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
+// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index a8da710f0bd0c..56d3168a7634b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -49,24 +49,24 @@
// CHECK-HIR: }
//
// CHECK-MIR-LABEL: func @sparse_dynamic_dims(
-// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-MIR-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK-MIR-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-MIR-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-MIR-DAG: %[[VAL_5:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG: %[[VAL_6:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_3]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG: %[[VAL_7:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_2]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf32>
-// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
-// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
-// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index
-// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index
-// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK-MIR-SAME: %[[ARGA:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME: %[[ARGX:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-MIR-DAG: %[[I0:.*]] = arith.constant 0 : index
+// CHECK-MIR-DAG: %[[I1:.*]] = arith.constant 1 : index
+// CHECK-MIR-DAG: %[[I2:.*]] = arith.constant 2 : index
+// CHECK-MIR-DAG: %[[DimSize0:.*]] = call @sparseDimSize(%[[ARGA]], %[[I0]])
+// CHECK-MIR-DAG: %[[DimSize1:.*]] = call @sparseDimSize(%[[ARGA]], %[[I1]])
+// CHECK-MIR-DAG: %[[DimSize2:.*]] = call @sparseDimSize(%[[ARGA]], %[[I2]])
+// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr<i8>) -> memref<?xf32>
+// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref<f32>
+// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[DimSize0]], %[[D2]] : index
+// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index
+// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[DimSize1]], %[[VAL_19]] : index
+// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index
// CHECK-MIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK-MIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
// CHECK-MIR: scf.yield %[[VAL_26]] : f32
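As a closing illustration, here is a self-contained C++ mock of the runtime pair this patch completes. The `getDimSize`/`getLvlSize` methods and the two entry points follow the shapes shown in the diff above, but this remains an editorial sketch: the mock constructor, the example sizes, and `main` are made up, and the same (i,j,k) -> (k,i,j) ordering is assumed as in the earlier sketch.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

using index_type = uint64_t;

// Minimal mock of the storage base class, keeping both size arrays.
class SparseTensorStorageBase {
public:
  SparseTensorStorageBase(std::vector<index_type> dimSz,
                          std::vector<index_type> lvlSz)
      : dimSizes(std::move(dimSz)), lvlSizes(std::move(lvlSz)) {}
  // Safely looks up the size of the given tensor-dimension (new in this patch).
  index_type getDimSize(index_type d) const {
    assert(d < dimSizes.size() && "Dimension index is out of bounds");
    return dimSizes[d];
  }
  // Safely looks up the size of the given storage-level (preexisting).
  index_type getLvlSize(index_type l) const {
    assert(l < lvlSizes.size() && "Level index is out of bounds");
    return lvlSizes[l];
  }

private:
  std::vector<index_type> dimSizes, lvlSizes;
};

// The paired entry points, mirroring SparseTensorRuntime.cpp above.
extern "C" index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}
extern "C" index_type sparseLvlSize(void *tensor, index_type l) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
}

int main() {
  // A 10x20x30 tensor under the assumed ordering: the level sizes are the
  // dimension sizes permuted into storage order.
  SparseTensorStorageBase t({10, 20, 30}, {30, 10, 20});
  // The new DimOp lowering asks for dimension 1 by name; the old lowering
  // pre-permuted the request into level 2. For a pure permutation both
  // return 20, but only the former keeps the dim/lvl translation in one
  // place (and stays correct once non-permutation maps are supported).
  std::printf("sparseDimSize(t, 1) = %llu\n",
              (unsigned long long)sparseDimSize(&t, 1));
  std::printf("sparseLvlSize(t, 2) = %llu\n",
              (unsigned long long)sparseLvlSize(&t, 2));
  return 0;
}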