[Mlir-commits] [mlir] [mlir][sparse] implement sparse_tensor.reorder_coo (PR #68916)
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 12 10:39:15 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/68916
>From 2e25769c88e642d6d214145047d22670efa9759a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 11 Oct 2023 23:29:54 +0000
Subject: [PATCH 1/5] [mlir][sparse] implement sparse_tensor.reorder_coo
operation
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 2 +
.../SparseTensor/IR/SparseTensorOps.td | 4 -
.../ExecutionEngine/SparseTensor/Storage.h | 75 +++++++
.../SparseTensor/IR/SparseTensorDialect.cpp | 19 +-
.../Transforms/SparseTensorCodegen.cpp | 29 +--
.../Transforms/SparseTensorConversion.cpp | 195 ++----------------
.../Transforms/SparseTensorRewriting.cpp | 22 +-
.../Transforms/StageSparseOperations.cpp | 10 +-
.../ExecutionEngine/SparseTensorRuntime.cpp | 6 +
.../CPU/sparse_conversion_sparse2sparse.mlir | 40 +---
10 files changed, 133 insertions(+), 269 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 2920ef79f461c6a..ca9555248130f08 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -151,6 +151,8 @@ enum class Action : uint32_t {
kToCOO = 5,
kToIterator = 6,
kPack = 7,
+ // Sort an unordered COO in place.
+ kSortCOOInPlace = 8,
};
/// This enum defines all the sparse representations supportable by
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index afbabb97eb71fc5..9016634fa3be8dd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -200,10 +200,6 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
// Whether the convert can be done by a single step (either a sort or a foreach),
// or it would require a tmp buffer (sort, then foreach).
bool directConvertable();
-
- // Whether the convert is actually a sort coo
- // TODO: The method will be removed when sort_coo operation is introduced.
- bool isSortCOOConvert();
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 303a41bc471d5d9..751ee8d1b17dc37 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -349,6 +349,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
~SparseTensorStorage() final = default;
+ void sortInPlace();
+
/// Partially specialize these getter methods based on template types.
void getPositions(std::vector<P> **out, uint64_t lvl) final {
assert(out && "Received nullptr for out parameter");
@@ -374,6 +376,24 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
/// Partially specialize lexicographical insertions based on template types.
void lexInsert(const uint64_t *lvlCoords, V val) final {
assert(lvlCoords && "Received nullptr for level-coordinates");
+ // TODO: get rid of this! canonicalize all-dense "sparse" array into dense
+ // tensors.
+ bool allDense = true;
+ for (DimLevelType lt : getLvlTypes()) {
+ if (!isDenseDLT(lt)) {
+ allDense = false;
+ break;
+ }
+ }
+ if (allDense) {
+ uint64_t lvlRank = getLvlRank();
+ uint64_t valIdx = 0;
+ // Linearize the address
+ for (size_t lvl = 0; lvl < lvlRank; lvl++)
+ valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+ values[valIdx] = val;
+ return;
+ }
// First, wrap up pending insertion path.
uint64_t diffLvl = 0;
uint64_t full = 0;
@@ -956,6 +976,61 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
return tensor;
}
+/// Sorts the elements of this storage in place into ascending lexicographic
+/// order of their level-coordinates (level 0 is the most significant key).
+///
+/// A sorted index permutation is computed first and is then applied to every
+/// per-level coordinate array together with the values array, so that each
+/// coordinate tuple stays paired with its value.
+///
+/// NOTE(review): every level's `coordinates[l]` is indexed with the same
+/// element index, so this presumably expects a COO-like layout -- confirm
+/// against the users of `Action::kSortCOOInPlace`.
+template <typename P, typename C, typename V>
+void SparseTensorStorage<P, C, V>::sortInPlace() {
+  const uint64_t nnz = values.size();
+  const uint64_t lvlRank = getLvlRank();
+#ifndef NDEBUG
+  for (uint64_t l = 0; l < lvlRank; l++)
+    assert(nnz == coordinates[l].size());
+#endif
+
+  // Applies `perm` in place: afterwards, position `i` holds the element that
+  // used to be at `perm[i]`. Each permutation cycle is rotated exactly once;
+  // visited slots are marked by resetting `perm[k] = k`, so `perm` is the
+  // identity when this returns.
+  auto applyPerm = [this, lvlRank](std::vector<uint64_t> &perm) {
+    const size_t length = perm.size();
+    // Cache for the coordinates of the element that starts a cycle. The
+    // cache must use the coordinate type `C`: `P` is the position type (see
+    // getPositions) and is an independent template parameter, so caching
+    // coordinates in `P` could silently truncate them.
+    std::vector<C> lvlCrds(lvlRank);
+    for (size_t i = 0; i < length; i++) {
+      size_t current = i;
+      if (i != perm[current]) {
+        for (size_t l = 0; l < lvlRank; l++)
+          lvlCrds[l] = coordinates[l][i];
+        V val = values[i];
+        // Deals with a permutation cycle.
+        while (i != perm[current]) {
+          size_t next = perm[current];
+          // Moves the level coordinates and value of `next` into `current`.
+          for (size_t l = 0; l < lvlRank; l++)
+            coordinates[l][current] = coordinates[l][next];
+          values[current] = values[next];
+          perm[current] = current;
+          current = next;
+        }
+        // Closes the cycle with the element cached at its start.
+        for (size_t l = 0; l < lvlRank; l++)
+          coordinates[l][current] = lvlCrds[l];
+        values[current] = val;
+        perm[current] = current;
+      }
+    }
+  };
+
+  std::vector<uint64_t> sortedIdx(nnz, 0);
+  for (uint64_t i = 0; i < nnz; i++)
+    sortedIdx[i] = i;
+
+  std::sort(sortedIdx.begin(), sortedIdx.end(),
+            [this, lvlRank](uint64_t lhs, uint64_t rhs) {
+              for (uint64_t l = 0; l < lvlRank; l++) {
+                if (coordinates[l][lhs] == coordinates[l][rhs])
+                  continue;
+                return coordinates[l][lhs] < coordinates[l][rhs];
+              }
+              // All coordinates are equal. Two distinct elements must not
+              // share coordinates, but the comparator still has to return a
+              // value on every path: flowing off the end of a bool lambda is
+              // undefined behavior once the assert is compiled out, and
+              // returning "not less" keeps the ordering strict-weak (also
+              // when an element is compared against itself).
+              assert(lhs == rhs && "duplicate coordinates");
+              return false;
+            });
+
+  applyPerm(sortedIdx);
+}
+
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ef9d4fea68628b9..61522fb0dcd24b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1060,20 +1060,12 @@ LogicalResult ConvertOp::verify() {
}
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
- Type dstType = getType();
- // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
- // convert for codegen to remove. This is because we use trivial
- // sparse-to-sparse convert to tell bufferization that the sparse codegen
- // will expand the tensor buffer into sparse tensor storage.
- if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
+ if (getType() == getSource().getType())
return getSource();
return {};
}
bool ConvertOp::directConvertable() {
- if (isSortCOOConvert())
- return false;
-
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());
@@ -1099,15 +1091,6 @@ bool ConvertOp::directConvertable() {
return false;
}
-bool ConvertOp::isSortCOOConvert() {
- // TODO: we should instead use a different sort_coo operation to handle
- // the conversion between COOs (but with different ordering).
- return isUniqueCOOType(getSource().getType()) &&
- isUniqueCOOType(getDest().getType()) &&
- !getSparseTensorType(getSource()).isAllOrdered() &&
- getSparseTensorType(getDest()).isAllOrdered();
-}
-
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 78f5562b392a682..378dd9128839d7f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -680,31 +680,26 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
};
-// TODO: use a new SortCOO operation here instead of reusing convert op.
+/// Sparse codegen rule for the reorder_coo operator (sorts the COO in place).
-struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
+struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
+ matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Direct conversion should have already been lowered.
- if (!op.isSortCOOConvert())
- return failure();
-
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
- SparseTensorType srcStt = getSparseTensorType(op.getSource());
- SparseTensorType dstStt = getSparseTensorType(op.getDest());
+ SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
+ SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
- // TODO: This should be verification rules for sort_coo operation.
+ // Should have been verified.
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
isUniqueCOOType(srcStt.getRankedTensorType()) &&
isUniqueCOOType(dstStt.getRankedTensorType()));
-
assert(dstStt.hasSameDimToLvl(srcStt));
// We don't need a mutable descriptor here as we perform sorting in-place.
- auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
- auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();
@@ -715,12 +710,11 @@ struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
- rewriter.getIndexAttr(0),
- SparseTensorSortKind::HybridQuickSort);
+ rewriter.getIndexAttr(0), op.getAlgorithm());
// Since we do in-place sorting, the destinate tensor will have the same set
// of memrefs as the source tensor.
- rewriter.replaceOp(op, adaptor.getSource());
+ rewriter.replaceOp(op, adaptor.getInputCoo());
return success();
}
};
@@ -1147,9 +1141,6 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.isSortCOOConvert())
- return failure();
-
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
@@ -1603,7 +1594,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
- SparseSortCOOConverter,
+ SparseReorderCOOConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d2d7b46ab834e71..53928803b9b880b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -540,179 +540,27 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
};
-/// Sparse conversion rule for the convert operator.
+/// Sparse conversion rule for the reorder_coo operator.
-class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
+class SparseTensorReorderCOOConverter
+ : public OpConversionPattern<ReorderCOOOp> {
public:
using OpConversionPattern::OpConversionPattern;
- SparseTensorConvertConverter(MLIRContext *context,
- SparseTensorConversionOptions o)
- : OpConversionPattern<ConvertOp>(context), options(o) {}
- SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
- SparseTensorConversionOptions o)
- : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
LogicalResult
- matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
- const auto srcTp = getSparseTensorType(op.getSource());
+ const auto srcTp = getSparseTensorType(op.getInputCoo());
const auto dstTp = getSparseTensorType(op);
- if (!srcTp.hasEncoding() && !dstTp.hasEncoding())
- return failure();
- const Dimension dimRank = srcTp.getDimRank();
- const Type elemTp = srcTp.getElementType();
- const Value src = adaptor.getOperands()[0];
- if (srcTp.hasEncoding() && dstTp.hasEncoding()) {
- const auto srcEnc = srcTp.getEncoding();
- const auto dstEnc = dstTp.getEncoding();
- // This is a sparse => sparse conversion, which is handled as follows:
- // t = src->toCOO(); ; src to COO in dst order
- // dst = newSparseTensor(t)
- // Using the coordinate scheme as an intermediate does not always
- // yield the fastest conversion but avoids the need for a full
- // O(N^2) conversion matrix.
- if (dstEnc == srcEnc) {
- rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
- return success();
- }
- NewCallParams params(rewriter, loc);
- SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
- bool useDirectConversion;
- switch (options.sparseToSparseStrategy) {
- case SparseToSparseConversionStrategy::kViaCOO:
- useDirectConversion = false;
- break;
- case SparseToSparseConversionStrategy::kDirect:
- useDirectConversion = true;
- assert(canUseDirectConversion(dstEnc.getLvlTypes()) &&
- "Unsupported target for direct sparse-to-sparse conversion");
- break;
- case SparseToSparseConversionStrategy::kAuto:
- useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes());
- break;
- }
- if (useDirectConversion) {
- rewriter.replaceOp(
- op, params.genBuffers(srcTp.withEncoding(dstEnc), dimSizes)
- .genNewCall(Action::kSparseToSparse, src));
- } else { // use via-COO conversion.
- // Set up encoding with right mix of src and dst so that the two
- // method calls can share most parameters, while still providing
- // the correct sparsity information to either of them.
- const auto mixedEnc =
- dstEnc.withBitWidths(srcEnc.getPosWidth(), srcEnc.getCrdWidth());
- // TODO: This is the only place where `kToCOO` (or `kToIterator`)
- // is called with a non-identity permutation. Is there any clean
- // way to push the permutation over to the `kFromCOO` side instead?
- Value coo = params.genBuffers(srcTp.withEncoding(mixedEnc), dimSizes)
- .genNewCall(Action::kToCOO, src);
- Value dst = params.setTemplateTypes(srcTp.withEncoding(dstEnc))
- .genNewCall(Action::kFromCOO, coo);
- genDelCOOCall(rewriter, loc, elemTp, coo);
- rewriter.replaceOp(op, dst);
- }
- return success();
- }
- if (srcTp.hasEncoding() && !dstTp.hasEncoding()) {
- const auto srcEnc = srcTp.getEncoding();
- // This is sparse => dense conversion, which is handled as follows:
- // dst = new Tensor(0);
- // iter = new SparseTensorIterator(src);
- // while (elem = iter->getNext()) {
- // dst[elem.coords] = elem.value;
- // }
- // delete iter;
- //
- // Fabricate a no-permutation encoding for NewCallParams
- // The position/coordinate types must be those of `src`.
- // The dimLevelTypes aren't actually used by Action::kToIterator.
- const auto dstEnc = SparseTensorEncodingAttr::get(
- op->getContext(),
- SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
- AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
- SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
- Value iter = NewCallParams(rewriter, loc)
- .genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
- .genNewCall(Action::kToIterator, src);
- const Type iTp = rewriter.getIndexType();
- Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
- Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
- // TODO: Dense buffers should be allocated/deallocated via the callback
- // in BufferizationOptions.
- Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
- const SmallVector<Value> noArgs;
- const SmallVector<Type> noTypes;
- auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
- rewriter.setInsertionPointToEnd(before);
- Value cond = genGetNextCall(rewriter, loc, iter, dimCoords, elemPtr);
- rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
- Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
- rewriter.setInsertionPointToStart(after);
- const auto dcvs = loadAll(rewriter, loc, dimRank, dimCoords);
- insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, dcvs);
- rewriter.create<scf::YieldOp>(loc);
- rewriter.setInsertionPointAfter(whileOp);
- genDelIteratorCall(rewriter, loc, elemTp, iter);
- rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
- op, dstTp.getRankedTensorType(), dst);
- return success();
- }
- assert(!srcTp.hasEncoding() && dstTp.hasEncoding());
- // This is a dense => sparse conversion or a sparse constant in COO =>
- // sparse conversion, which is handled as follows:
- // t = newSparseCOO()
- // ...code to fill the COO tensor t...
- // s = newSparseTensor(t)
- //
- // To fill the COO tensor from a dense tensor:
- // for i1 in dim1
- // ..
- // for ik in dimk
- // val = a[i1,..,ik]
- // if val != 0
- // t->add(val, [i1,..,ik], [p1,..,pk])
- //
- // To fill the COO tensor from a sparse constant in COO format:
- // for i in range(NNZ)
- // val = values[i]
- // [i1,..,ik] = coordinates[i]
- // t->add(val, [i1,..,ik], [p1,..,pk])
- //
- // Note that the dense tensor traversal code is actually implemented
- // using MLIR IR to avoid having to expose too much low-level
- // memref traversal details to the runtime support library.
- // Also note that the code below only generates the "new" ops and
- // the loop-nest per se; whereas the entire body of the innermost
- // loop is generated by genAddElt().
- SmallVector<Value> dimSizes;
- sizesFromSrc(rewriter, dimSizes, loc, src);
+ const Value src = adaptor.getInputCoo();
+
NewCallParams params(rewriter, loc);
- Value coo =
- params.genBuffers(dstTp, dimSizes).genNewCall(Action::kEmptyCOO);
- const Type iTp = rewriter.getIndexType();
- Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
- Value dimToLvl = params.getDimToLvl();
- Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
- genDenseTensorOrSparseConstantIterLoop(
- rewriter, loc, src, dimRank,
- [&](OpBuilder &builder, Location loc, Value val, ValueRange dcvs) {
- assert(dcvs.size() == static_cast<size_t>(dimRank));
- storeAll(builder, loc, dimCoords, dcvs);
- builder.create<memref::StoreOp>(loc, val, elemPtr);
- genAddEltCall(builder, loc, elemTp, coo, elemPtr, dimCoords,
- dimToLvl);
- });
- // Final call to construct sparse tensor storage.
- Value dst = params.genNewCall(Action::kFromCOO, coo);
- genDelCOOCall(rewriter, loc, elemTp, coo);
- rewriter.replaceOp(op, dst);
+ SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
+ rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
+ .genNewCall(Action::kSortCOOInPlace, src));
+
return success();
}
-
-private:
- /// Options to control sparse code generation.
- SparseTensorConversionOptions options;
};
/// Sparse conversion rule for the dealloc operator.
@@ -1016,16 +864,15 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
void mlir::populateSparseTensorConversionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
const SparseTensorConversionOptions &options) {
- patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
- SparseCastConverter, SparseTensorNewConverter,
- SparseTensorAllocConverter, SparseTensorEmptyConverter,
- SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
- SparseTensorToCoordinatesConverter,
- SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
- SparseTensorLoadConverter, SparseTensorInsertConverter,
- SparseTensorExpandConverter, SparseTensorCompressConverter,
- SparseTensorOutConverter, SparseTensorAssembleConverter>(
- typeConverter, patterns.getContext());
- patterns.add<SparseTensorConvertConverter>(typeConverter,
- patterns.getContext(), options);
+ patterns
+ .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+ SparseCastConverter, SparseTensorNewConverter,
+ SparseTensorAllocConverter, SparseTensorEmptyConverter,
+ SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
+ SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
+ SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
+ SparseTensorLoadConverter, SparseTensorInsertConverter,
+ SparseTensorExpandConverter, SparseTensorCompressConverter,
+ SparseTensorOutConverter, SparseTensorAssembleConverter>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 592852f87ba1e04..f16d08b86a1a14f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -971,7 +971,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
Value tmpCoo = dst;
Type dstCooTp = getCOOType(dstRTT, true);
- // TODO: this should be a sort_coo operation.
+ // Sort the temporary COO with a reorder_coo operation.
- dst = rewriter.create<ConvertOp>(loc, dstCooTp, tmpCoo).getResult();
+ dst = rewriter
+ .create<ReorderCOOOp>(loc, dstCooTp, tmpCoo,
+ SparseTensorSortKind::HybridQuickSort)
+ .getResult();
dst = rewriter.create<ConvertOp>(loc, dstRTT, dst).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
}
@@ -1028,11 +1031,8 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
- if (!op.directConvertable() && !op.isSortCOOConvert())
- return op.emitError("ConvertOp not in conanical form.");
-
- if (op.isSortCOOConvert())
- return failure();
+ if (!op.directConvertable())
+ return op.emitError("ConvertOp not staged.");
// TODO: Maybe we want a different operation for this too.
auto encDst = getSparseTensorEncoding(op.getType());
@@ -1338,12 +1338,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
TensorReshapeRewriter>(patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
-
- if (!enableRT) {
+ if (enableConvert)
+ patterns.add<DirectConvertRewriter>(patterns.getContext());
+ if (!enableRT)
patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
- // TODO: Move this to a common path for both lib/codegen when libgen support
- // lowering sort_coo.
- if (enableConvert)
- patterns.add<DirectConvertRewriter>(patterns.getContext());
- }
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 4ab4b05a7a420c7..acdb888abe2b028 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -22,8 +22,7 @@ struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
PatternRewriter &rewriter) const override {
// TODO: Implement it as an Interface, this can be reused from other
// operations too (e.g., concatenate, reshape, etc).
-
- if (op.directConvertable() || op.isSortCOOConvert())
+ if (op.directConvertable())
return failure();
Location loc = op.getLoc();
@@ -40,13 +39,16 @@ struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
Type srcCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
- Value srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
+ Value srcCOO = op.getSource();
+ if (srcCOO.getType() != srcCOOTp)
+ srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
// -> sort
Type dstCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
- // TODO: this should be a sort_coo operation.
+ // Sort the unordered COO into its ordered counterpart with reorder_coo.
- Value dstCOO = rewriter.create<ConvertOp>(loc, dstCOOTp, srcCOO);
+ Value dstCOO = rewriter.create<ReorderCOOOp>(
+ loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
// -> dest.
if (dstCOO.getType() == op.getType()) {
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index bc6d4ad2c740189..83ceecaf5a30ee1 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -213,6 +213,12 @@ extern "C" {
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
dimRank, buffers); \
} \
+ case Action::kSortCOOInPlace: { \
+ assert(ptr && "Received nullptr for SparseTensorStorage object"); \
+ auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
+ tensor.sortInPlace(); \
+ return ptr; \
+ } \
} \
MLIR_SPARSETENSOR_FATAL("unknown action: %d\n", \
static_cast<uint32_t>(action)); \
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
index 1c74b6827d980ef..7d7dd5fd2d42d98 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
@@ -49,13 +49,7 @@
}>
#SingletonTensor1 = #sparse_tensor.encoding<{
- map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : singleton)
-
-}>
-
-// This also checks the compressed->dense conversion (when there are zeros).
-#SingletonTensor2 = #sparse_tensor.encoding<{
- map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : singleton)
+ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed(nonunique), d2 : singleton)
}>
@@ -97,44 +91,34 @@ module {
// Convert dense tensor directly to various sparse tensors.
//
%s1 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor1>
- %s2 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor2>
%s3 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor3>
//
// Convert sparse tensor directly to another sparse format.
//
%t13 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64, #Tensor3>
- %t21 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor1>
- %t23 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor3>
%t31 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64, #Tensor1>
//
// Convert sparse tensor back to dense.
//
%d13 = sparse_tensor.convert %t13 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64>
- %d21 = sparse_tensor.convert %t21 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64>
- %d23 = sparse_tensor.convert %t23 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64>
%d31 = sparse_tensor.convert %t31 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64>
//
// Check round-trip equality. And release dense tensors.
//
- // CHECK-COUNT-5: ( ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ), ( ( 13, 14, 15, 16 ), ( 17, 18, 19, 20 ), ( 21, 22, 23, 24 ) ) )
+ // CHECK-COUNT-3: ( ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ), ( ( 13, 14, 15, 16 ), ( 17, 18, 19, 20 ), ( 21, 22, 23, 24 ) ) )
call @dump(%src) : (tensor<2x3x4xf64>) -> ()
call @dump(%d13) : (tensor<2x3x4xf64>) -> ()
- call @dump(%d21) : (tensor<2x3x4xf64>) -> ()
- call @dump(%d23) : (tensor<2x3x4xf64>) -> ()
call @dump(%d31) : (tensor<2x3x4xf64>) -> ()
//
// Release sparse tensors.
//
bufferization.dealloc_tensor %t13 : tensor<2x3x4xf64, #Tensor3>
- bufferization.dealloc_tensor %t21 : tensor<2x3x4xf64, #Tensor1>
- bufferization.dealloc_tensor %t23 : tensor<2x3x4xf64, #Tensor3>
bufferization.dealloc_tensor %t31 : tensor<2x3x4xf64, #Tensor1>
bufferization.dealloc_tensor %s1 : tensor<2x3x4xf64, #Tensor1>
- bufferization.dealloc_tensor %s2 : tensor<2x3x4xf64, #Tensor2>
bufferization.dealloc_tensor %s3 : tensor<2x3x4xf64, #Tensor3>
return
@@ -160,52 +144,34 @@ module {
// Convert dense tensor directly to various sparse tensors.
//
%s1 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #SingletonTensor1>
- %s2 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #SingletonTensor2>
%s3 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #SingletonTensor3>
//
// Convert sparse tensor directly to another sparse format.
//
- %t12 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64, #SingletonTensor2>
%t13 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64, #SingletonTensor3>
- %t21 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64, #SingletonTensor1>
- %t23 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64, #SingletonTensor3>
%t31 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64, #SingletonTensor1>
- %t32 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64, #SingletonTensor2>
//
// Convert sparse tensor back to dense.
//
- %d12 = sparse_tensor.convert %t12 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64>
%d13 = sparse_tensor.convert %t13 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64>
- %d21 = sparse_tensor.convert %t21 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64>
- %d23 = sparse_tensor.convert %t23 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64>
%d31 = sparse_tensor.convert %t31 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64>
- %d32 = sparse_tensor.convert %t32 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64>
//
// Check round-trip equality. And release dense tensors.
//
- // CHECK-COUNT-7: ( ( ( 1, 0, 0, 0 ), ( 0, 6, 0, 0 ), ( 0, 0, 11, 0 ) ), ( ( 0, 14, 0, 0 ), ( 0, 0, 0, 20 ), ( 21, 0, 0, 0 ) ) )
+ // CHECK-COUNT-3: ( ( ( 1, 0, 0, 0 ), ( 0, 6, 0, 0 ), ( 0, 0, 11, 0 ) ), ( ( 0, 14, 0, 0 ), ( 0, 0, 0, 20 ), ( 21, 0, 0, 0 ) ) )
call @dump(%src) : (tensor<2x3x4xf64>) -> ()
- call @dump(%d12) : (tensor<2x3x4xf64>) -> ()
call @dump(%d13) : (tensor<2x3x4xf64>) -> ()
- call @dump(%d21) : (tensor<2x3x4xf64>) -> ()
- call @dump(%d23) : (tensor<2x3x4xf64>) -> ()
call @dump(%d31) : (tensor<2x3x4xf64>) -> ()
- call @dump(%d32) : (tensor<2x3x4xf64>) -> ()
//
// Release sparse tensors.
//
- bufferization.dealloc_tensor %t12 : tensor<2x3x4xf64, #SingletonTensor2>
bufferization.dealloc_tensor %t13 : tensor<2x3x4xf64, #SingletonTensor3>
- bufferization.dealloc_tensor %t21 : tensor<2x3x4xf64, #SingletonTensor1>
- bufferization.dealloc_tensor %t23 : tensor<2x3x4xf64, #SingletonTensor3>
bufferization.dealloc_tensor %t31 : tensor<2x3x4xf64, #SingletonTensor1>
- bufferization.dealloc_tensor %t32 : tensor<2x3x4xf64, #SingletonTensor2>
bufferization.dealloc_tensor %s1 : tensor<2x3x4xf64, #SingletonTensor1>
- bufferization.dealloc_tensor %s2 : tensor<2x3x4xf64, #SingletonTensor2>
bufferization.dealloc_tensor %s3 : tensor<2x3x4xf64, #SingletonTensor3>
return
>From 58799d60e03d6acb9b4ad0592a779ef7655d59c6 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 12 Oct 2023 17:10:25 +0000
Subject: [PATCH 2/5] pass check tests
---
.../SparseTensor/convert_dense2sparse.mlir | 327 ++---------------
.../SparseTensor/convert_sparse2dense.mlir | 341 +++---------------
.../SparseTensor/convert_sparse2sparse.mlir | 177 ++-------
.../Dialect/SparseTensor/sparse_concat.mlir | 173 +--------
4 files changed, 115 insertions(+), 903 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 1a69c80f7ecadfd..4dba16df39f5c65 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)
@@ -16,187 +16,45 @@
map = (d0, d1, d2) -> (d2 : dense, d0 : compressed, d1 : compressed)
}>
-// CHECK-LABEL: func.func @sparse_convert_1d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xi32>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_7]] : tensor<?xi32>
-// CHECK: %[[VAL_9:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<1xi8>
-// CHECK: %[[VAL_11:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_11]]{{\[}}%[[VAL_7]]] : memref<1xindex>
-// CHECK: %[[VAL_13:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<1xindex>
-// CHECK: %[[VAL_15:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[VAL_16:.*]] = call @newSparseTensor(%[[VAL_12]], %[[VAL_12]], %[[VAL_10]], %[[VAL_14]], %[[VAL_14]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_15]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_17:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref<i32>
-// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_5]] {
-// CHECK: %[[VAL_21:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_20]]] : tensor<?xi32>
-// CHECK: %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_21]], %[[VAL_4]] : i32
-// CHECK: scf.if %[[VAL_22]] {
-// CHECK: memref.store %[[VAL_20]], %[[VAL_17]]{{\[}}%[[VAL_7]]] : memref<1xindex>
-// CHECK: memref.store %[[VAL_21]], %[[VAL_19]][] : memref<i32>
-// CHECK: %[[VAL_23:.*]] = func.call @addEltI32(%[[VAL_16]], %[[VAL_19]], %[[VAL_18]], %[[VAL_14]]) : (!llvm.ptr<i8>, memref<i32>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_24:.*]] = call @newSparseTensor(%[[VAL_12]], %[[VAL_12]], %[[VAL_10]], %[[VAL_14]], %[[VAL_14]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_1]], %[[VAL_16]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOI32(%[[VAL_16]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_24]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_1d
+// CHECK: sparse_tensor.foreach
+// CHECK: scf.if
+// CHECK: sparse_tensor.insert
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.load
func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
return %0 : tensor<?xi32, #SparseVector>
}
-// CHECK-LABEL: func.func @sparse_convert_complex(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<100xcomplex<f64>>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 9 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 100 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_10]]{{\[}}%[[VAL_6]]] : memref<1xi8>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<1xindex>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<1xindex>
-// CHECK: %[[VAL_16:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[VAL_17:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_5]], %[[VAL_5]], %[[VAL_4]], %[[VAL_3]], %[[VAL_16]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_18:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = memref.alloca() : memref<complex<f64>>
-// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_9]] {
-// CHECK: %[[VAL_22:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_21]]] : tensor<100xcomplex<f64>>
-// CHECK: %[[VAL_23:.*]] = complex.neq %[[VAL_22]], %[[VAL_2]] : complex<f64>
-// CHECK: scf.if %[[VAL_23]] {
-// CHECK: memref.store %[[VAL_21]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<1xindex>
-// CHECK: memref.store %[[VAL_22]], %[[VAL_20]][] : memref<complex<f64>>
-// CHECK: %[[VAL_24:.*]] = func.call @addEltC64(%[[VAL_17]], %[[VAL_20]], %[[VAL_19]], %[[VAL_15]]) : (!llvm.ptr<i8>, memref<complex<f64>>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_25:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_5]], %[[VAL_5]], %[[VAL_4]], %[[VAL_1]], %[[VAL_17]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOC64(%[[VAL_17]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_25]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_complex
+// CHECK: sparse_tensor.foreach
+// CHECK: scf.if
+// CHECK: sparse_tensor.insert
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.load
func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100xcomplex<f64>, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<100xcomplex<f64>> to tensor<100xcomplex<f64>, #SparseVector>
return %0 : tensor<100xcomplex<f64>, #SparseVector>
}
-// CHECK-LABEL: func.func @sparse_convert_2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<2xi8>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_16]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_5]], %[[VAL_5]], %[[VAL_4]], %[[VAL_3]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_20:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] {
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_6]] to %[[VAL_9]] step %[[VAL_8]] {
-// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_23]], %[[VAL_24]]] : tensor<2x4xf64>
-// CHECK: %[[VAL_26:.*]] = arith.cmpf une, %[[VAL_25]], %[[VAL_2]] : f64
-// CHECK: scf.if %[[VAL_26]] {
-// CHECK: memref.store %[[VAL_23]], %[[VAL_20]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_25]], %[[VAL_22]][] : memref<f64>
-// CHECK: %[[VAL_27:.*]] = func.call @addEltF64(%[[VAL_19]], %[[VAL_22]], %[[VAL_21]], %[[VAL_17]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_28:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_5]], %[[VAL_5]], %[[VAL_4]], %[[VAL_1]], %[[VAL_19]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOF64(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_28]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_2d
+// CHECK: sparse_tensor.foreach
+// CHECK: scf.if
+// CHECK: sparse_tensor.insert
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.load
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
return %0 : tensor<2x4xf64, #CSR>
}
-// CHECK-LABEL: func.func @sparse_constant() -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 7 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<2xi8>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_14]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_16]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_20:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = memref.alloca() : memref<f32>
-// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_5]] to %[[VAL_11]] step %[[VAL_7]] {
-// CHECK: %[[VAL_24:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_5]]] : tensor<2x2xi64>
-// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index
-// CHECK: %[[VAL_26:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_7]]] : tensor<2x2xi64>
-// CHECK: %[[VAL_27:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
-// CHECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_23]]] : tensor<2xf32>
-// CHECK: memref.store %[[VAL_25]], %[[VAL_20]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_27]], %[[VAL_20]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_28]], %[[VAL_22]][] : memref<f32>
-// CHECK: %[[VAL_29:.*]] = func.call @addEltF32(%[[VAL_19]], %[[VAL_22]], %[[VAL_21]], %[[VAL_17]]) : (!llvm.ptr<i8>, memref<f32>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: %[[VAL_30:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_3]], %[[VAL_19]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOF32(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_30]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_constant
+// CHECK: sparse_tensor.foreach
+// CHECK-NOT: scf.if
+// CHECK: sparse_tensor.insert
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.load
func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
// Initialize a tensor.
%0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
@@ -205,59 +63,12 @@ func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
return %1 : tensor<8x7xf32, #CSR>
}
-// CHECK-LABEL: func.func @sparse_constant_csc() -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 7 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<2xi8>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_14]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_20:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_20]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_20]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_22:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[VAL_23:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_21]], %[[VAL_13]], %[[VAL_17]], %[[VAL_19]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_22]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_24:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_26:.*]] = memref.alloca() : memref<f32>
-// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_5]] to %[[VAL_11]] step %[[VAL_7]] {
-// CHECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_27]], %[[VAL_5]]] : tensor<2x2xi64>
-// CHECK: %[[VAL_29:.*]] = arith.index_cast %[[VAL_28]] : i64 to index
-// CHECK: %[[VAL_30:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_27]], %[[VAL_7]]] : tensor<2x2xi64>
-// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : i64 to index
-// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_27]]] : tensor<2xf32>
-// CHECK: memref.store %[[VAL_29]], %[[VAL_24]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_31]], %[[VAL_24]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_32]], %[[VAL_26]][] : memref<f32>
-// CHECK: %[[VAL_33:.*]] = func.call @addEltF32(%[[VAL_23]], %[[VAL_26]], %[[VAL_25]], %[[VAL_17]]) : (!llvm.ptr<i8>, memref<f32>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: %[[VAL_34:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_21]], %[[VAL_13]], %[[VAL_17]], %[[VAL_19]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_3]], %[[VAL_23]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOF32(%[[VAL_23]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_34]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_constant_csc
+// CHECK: sparse_tensor.foreach
+// CHECK-NOT: scf.if
+// CHECK: sparse_tensor.insert
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.load
func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
// Initialize a tensor.
%0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
@@ -266,73 +77,15 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
return %1 : tensor<8x7xf32, #CSC>
}
-// CHECK-LABEL: func.func @sparse_convert_3d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 4 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_11:.*]] = tensor.dim %[[VAL_0]], %[[VAL_10]] : tensor<?x?x?xf64>
-// CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_0]], %[[VAL_9]] : tensor<?x?x?xf64>
-// CHECK: %[[VAL_13:.*]] = tensor.dim %[[VAL_0]], %[[VAL_8]] : tensor<?x?x?xf64>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<3xi8>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<3xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_14]]{{\[}}%[[VAL_10]]] : memref<3xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_9]]] : memref<3xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_8]]] : memref<3xi8>
-// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_16]]{{\[}}%[[VAL_10]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_12]], %[[VAL_16]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_13]], %[[VAL_16]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_10]]] : memref<3xindex>
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: %[[VAL_21:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_22:.*]] = memref.cast %[[VAL_21]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_21]]{{\[}}%[[VAL_10]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_21]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_21]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_23:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_24:.*]] = memref.cast %[[VAL_23]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_23]]{{\[}}%[[VAL_10]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_23]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_23]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_25:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_26:.*]] = memref.cast %[[VAL_25]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_18]], %[[VAL_25]]{{\[}}%[[VAL_10]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_19]], %[[VAL_25]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_20]], %[[VAL_25]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_27:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
-// CHECK: %[[VAL_28:.*]] = call @newSparseTensor(%[[VAL_17]], %[[VAL_26]], %[[VAL_15]], %[[VAL_22]], %[[VAL_24]], %[[VAL_5]], %[[VAL_5]], %[[VAL_4]], %[[VAL_3]], %[[VAL_27]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_29:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_30:.*]] = memref.cast %[[VAL_29]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[VAL_31:.*]] = memref.alloca() : memref<f64>
-// CHECK: scf.for %[[VAL_32:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_9]] {
-// CHECK: scf.for %[[VAL_33:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_9]] {
-// CHECK: scf.for %[[VAL_34:.*]] = %[[VAL_10]] to %[[VAL_13]] step %[[VAL_9]] {
-// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_32]], %[[VAL_33]], %[[VAL_34]]] : tensor<?x?x?xf64>
-// CHECK: %[[VAL_36:.*]] = arith.cmpf une, %[[VAL_35]], %[[VAL_2]] : f64
-// CHECK: scf.if %[[VAL_36]] {
-// CHECK: memref.store %[[VAL_32]], %[[VAL_29]]{{\[}}%[[VAL_10]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_33]], %[[VAL_29]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_34]], %[[VAL_29]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_35]], %[[VAL_31]][] : memref<f64>
-// CHECK: %[[VAL_37:.*]] = func.call @addEltF64(%[[VAL_28]], %[[VAL_31]], %[[VAL_30]], %[[VAL_22]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_38:.*]] = call @newSparseTensor(%[[VAL_17]], %[[VAL_26]], %[[VAL_15]], %[[VAL_22]], %[[VAL_24]], %[[VAL_5]], %[[VAL_5]], %[[VAL_4]], %[[VAL_1]], %[[VAL_28]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOF64(%[[VAL_28]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_38]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_3d
+// CHECK: sparse_tensor.foreach
+// CHECK: scf.if
+// CHECK: sparse_tensor.insert
+// CHECK: sparse_tensor.load
+// CHECK: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.foreach
+// CHECK: sparse_tensor.insert
+// CHECK: sparse_tensor.load
func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
return %0 : tensor<?x?x?xf64, #SparseTensor>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index ffc0f57a231103e..c22f051a0d5854d 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)
@@ -12,326 +12,85 @@
map = (d0, d1, d2) -> (d2 : dense, d0 : compressed, d1 : compressed)
}>
-// CHECK-LABEL: func.func @sparse_convert_1d(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<13xi32> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 13 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 4 : i8
-// CHECK: %[[VAL_6:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<1xi8>
-// CHECK: %[[VAL_8:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<1xindex>
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_3]], %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref<1xindex>
-// CHECK: %[[VAL_12:.*]] = call @newSparseTensor(%[[VAL_9]], %[[VAL_9]], %[[VAL_7]], %[[VAL_11]], %[[VAL_11]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_13:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = memref.alloca() : memref<i32>
-// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<13xi32>
-// CHECK: linalg.fill ins(%[[VAL_2]] : i32) outs(%[[VAL_16]] : memref<13xi32>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_17:.*]] = func.call @getNextI32(%[[VAL_12]], %[[VAL_14]], %[[VAL_15]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
-// CHECK: scf.condition(%[[VAL_17]])
-// CHECK: } do {
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_3]]] : memref<1xindex>
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_15]][] : memref<i32>
-// CHECK: memref.store %[[VAL_19]], %[[VAL_16]]{{\[}}%[[VAL_18]]] : memref<13xi32>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorI32(%[[VAL_12]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<13xi32>
-// CHECK: return %[[VAL_20]] : tensor<13xi32>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_1d
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
%0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
return %0 : tensor<13xi32>
}
-// CHECK-LABEL: func.func @sparse_convert_1d_dyn(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<?xi32> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_6:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_3]], %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<1xi8>
-// CHECK: %[[VAL_8:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_12:.*]] = call @newSparseTensor(%[[VAL_9]], %[[VAL_9]], %[[VAL_7]], %[[VAL_11]], %[[VAL_11]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_13:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = memref.alloca() : memref<i32>
-// CHECK: %[[VAL_16:.*]] = memref.alloc(%[[VAL_5]]) : memref<?xi32>
-// CHECK: linalg.fill ins(%[[VAL_2]] : i32) outs(%[[VAL_16]] : memref<?xi32>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_17:.*]] = func.call @getNextI32(%[[VAL_12]], %[[VAL_14]], %[[VAL_15]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
-// CHECK: scf.condition(%[[VAL_17]])
-// CHECK: } do {
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_15]][] : memref<i32>
-// CHECK: memref.store %[[VAL_19]], %[[VAL_16]]{{\[}}%[[VAL_18]]] : memref<?xi32>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorI32(%[[VAL_12]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<?xi32>
-// CHECK: return %[[VAL_20]] : tensor<?xi32>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_1d_dyn
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
return %0 : tensor<?xi32>
}
-// CHECK-LABEL: func.func @sparse_convert_2d(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<2x4xf64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : i8
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_10]]{{\[}}%[[VAL_6]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_10]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_17:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<2x4xf64>
-// CHECK: linalg.fill ins(%[[VAL_1]] : f64) outs(%[[VAL_20]] : memref<2x4xf64>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_21:.*]] = func.call @getNextF64(%[[VAL_16]], %[[VAL_18]], %[[VAL_19]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[VAL_21]])
-// CHECK: } do {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_19]][] : memref<f64>
-// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_22]], %[[VAL_23]]] : memref<2x4xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[VAL_16]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_25:.*]] = bufferization.to_tensor %[[VAL_20]] : memref<2x4xf64>
-// CHECK: return %[[VAL_25]] : tensor<2x4xf64>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_2d
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
return %0 : tensor<2x4xf64>
}
-// CHECK-LABEL: func.func @sparse_convert_2d_dyn0(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<?x4xf64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_9:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_8]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_10]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_14]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_17:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_20:.*]] = memref.alloc(%[[VAL_9]]) : memref<?x4xf64>
-// CHECK: linalg.fill ins(%[[VAL_1]] : f64) outs(%[[VAL_20]] : memref<?x4xf64>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_21:.*]] = func.call @getNextF64(%[[VAL_16]], %[[VAL_18]], %[[VAL_19]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[VAL_21]])
-// CHECK: } do {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_19]][] : memref<f64>
-// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_22]], %[[VAL_23]]] : memref<?x4xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[VAL_16]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_25:.*]] = bufferization.to_tensor %[[VAL_20]] : memref<?x4xf64>
-// CHECK: return %[[VAL_25]] : tensor<?x4xf64>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_2d_dyn0
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
return %0 : tensor<?x4xf64>
}
-// CHECK-LABEL: func.func @sparse_convert_2d_dyn1(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<2x?xf64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_9:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_8]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_10]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<2xi8>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_14]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_17:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_20:.*]] = memref.alloc(%[[VAL_9]]) : memref<2x?xf64>
-// CHECK: linalg.fill ins(%[[VAL_1]] : f64) outs(%[[VAL_20]] : memref<2x?xf64>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_21:.*]] = func.call @getNextF64(%[[VAL_16]], %[[VAL_18]], %[[VAL_19]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[VAL_21]])
-// CHECK: } do {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_8]]] : memref<2xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_19]][] : memref<f64>
-// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_22]], %[[VAL_23]]] : memref<2x?xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[VAL_16]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_25:.*]] = bufferization.to_tensor %[[VAL_20]] : memref<2x?xf64>
-// CHECK: return %[[VAL_25]] : tensor<2x?xf64>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_2d_dyn1
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
return %0 : tensor<2x?xf64>
}
-// CHECK-LABEL: func.func @sparse_convert_2d_dyn2(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<?x?xf64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_7]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_9:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_10]]{{\[}}%[[VAL_7]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_10]]{{\[}}%[[VAL_6]]] : memref<2xi8>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_14]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_17:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_20:.*]] = memref.alloc(%[[VAL_8]], %[[VAL_9]]) : memref<?x?xf64>
-// CHECK: linalg.fill ins(%[[VAL_1]] : f64) outs(%[[VAL_20]] : memref<?x?xf64>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_21:.*]] = func.call @getNextF64(%[[VAL_16]], %[[VAL_18]], %[[VAL_19]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[VAL_21]])
-// CHECK: } do {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_7]]] : memref<2xindex>
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_6]]] : memref<2xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_19]][] : memref<f64>
-// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_22]], %[[VAL_23]]] : memref<?x?xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[VAL_16]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_25:.*]] = bufferization.to_tensor %[[VAL_20]] : memref<?x?xf64>
-// CHECK: return %[[VAL_25]] : tensor<?x?xf64>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_2d_dyn2
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
return %0 : tensor<?x?xf64>
}
-// CHECK-LABEL: func.func @sparse_convert_3d(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> tensor<2x3x4xf64> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : i8
-// CHECK: %[[VAL_11:.*]] = memref.alloca() : memref<3xi8>
-// CHECK: %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<3xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<3xi8>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<3xi8>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_11]]{{\[}}%[[VAL_7]]] : memref<3xi8>
-// CHECK: %[[VAL_13:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_13]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_15:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_15]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_15]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_15]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_17:.*]] = call @newSparseTensor(%[[VAL_14]], %[[VAL_14]], %[[VAL_12]], %[[VAL_16]], %[[VAL_16]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_21:.*]] = memref.alloc() : memref<2x3x4xf64>
-// CHECK: linalg.fill ins(%[[VAL_1]] : f64) outs(%[[VAL_21]] : memref<2x3x4xf64>)
-// CHECK: scf.while : () -> () {
-// CHECK: %[[VAL_22:.*]] = func.call @getNextF64(%[[VAL_17]], %[[VAL_19]], %[[VAL_20]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
-// CHECK: scf.condition(%[[VAL_22]])
-// CHECK: } do {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_20]][] : memref<f64>
-// CHECK: memref.store %[[VAL_26]], %[[VAL_21]]{{\[}}%[[VAL_23]], %[[VAL_24]], %[[VAL_25]]] : memref<2x3x4xf64>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: call @delSparseTensorIteratorF64(%[[VAL_17]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: %[[VAL_27:.*]] = bufferization.to_tensor %[[VAL_21]] : memref<2x3x4xf64>
-// CHECK: return %[[VAL_27]] : tensor<2x3x4xf64>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_3d
+// CHECK-NOT: sparse_tensor.reorder_coo
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: sparse_tensor.foreach
+// CHECK: memref.store
+// CHECK: bufferization.to_tensor
func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
return %0 : tensor<2x3x4xf64>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index e8e69dc861015ce..658e8aa40022eb2 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
#SparseVector64 = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed),
@@ -33,185 +33,56 @@
map = (d0 : #sparse_tensor<slice(2, 2, 1)>, d1 : #sparse_tensor<slice(12, 13, 1)>) -> (d0 : compressed(nonunique), d1 : singleton)
}>
-// CHECK-LABEL: func.func @sparse_nop_convert(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK: return %[[VAL_0]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_nop_convert
+// CHECK-NEXT: return
func.func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<64xf32, #SparseVector> to tensor<64xf32, #SparseVector>
return %0 : tensor<64xf32, #SparseVector>
}
-// CHECK-LABEL: func.func @sparse_hidden_nop_cast(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK: return %[[VAL_0]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_hidden_nop_cast
+// TODO: The following convert should be a cast instead.
+// CHECK: sparse_tensor.convert
+// CHECK: return
func.func @sparse_hidden_nop_cast(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
}
// CHECK-LABEL: func.func @sparse_convert_1d_ss(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 3 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_6:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_3]], %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<1xi8>
-// CHECK: %[[VAL_8:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_12:.*]] = call @newSparseTensor(%[[VAL_9]], %[[VAL_9]], %[[VAL_7]], %[[VAL_11]], %[[VAL_11]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: return %[[VAL_12]] : !llvm.ptr<i8>
-// CHECK: }
+// TODO: libgen path needs to support efficient format conversion (e.g., 32-bit pos -> 64-bit pos).
+// Maybe we should use a different operator as well to be clear.
func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
return %0 : tensor<?xf32, #SparseVector32>
}
// CHECK-LABEL: func.func @sparse_convert(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 3 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_6:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_3]], %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<1xi8>
-// CHECK: %[[VAL_8:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_12:.*]] = call @newSparseTensor(%[[VAL_9]], %[[VAL_9]], %[[VAL_7]], %[[VAL_11]], %[[VAL_11]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: return %[[VAL_12]] : !llvm.ptr<i8>
-// CHECK: }
+// TODO: the libgen path needs to support efficient format conversion (e.g., 32-bit pos -> 64-bit pos).
+// Maybe we should also use a different operator to make this clear.
func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
return %0 : tensor<?xf32, #SparseVector32>
}
-#SparseSingleton64 = #sparse_tensor.encoding<{
- map = (d0) -> (d0 : singleton),
- posWidth = 64,
- crdWidth = 64
-}>
-
-#SparseSingleton32 = #sparse_tensor.encoding<{
- map = (d0) -> (d0 : singleton),
- posWidth = 32,
- crdWidth = 32
-}>
-
-//
-// CHECK-LABEL: func.func @sparse_convert_singleton(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 3 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 16 : i8
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_6:.*]] = memref.alloca() : memref<1xi8>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<1xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_3]], %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<1xi8>
-// CHECK: %[[VAL_8:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<1xindex>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<1xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<1xindex>
-// CHECK: %[[VAL_12:.*]] = call @newSparseTensor(%[[VAL_9]], %[[VAL_9]], %[[VAL_7]], %[[VAL_11]], %[[VAL_11]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: return %[[VAL_12]] : !llvm.ptr<i8>
-// CHECK: }
-func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) -> tensor<?xf32, #SparseSingleton32> {
- %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseSingleton64> to tensor<?xf32, #SparseSingleton32>
- return %0 : tensor<?xf32, #SparseSingleton32>
-}
-
-// CHECK-LABEL: func.func @sparse_convert_permuted(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_7]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_9:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_10:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK: %[[VAL_11:.*]] = memref.alloca() : memref<3xi8>
-// CHECK: %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<3xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_11]]{{\[}}%[[VAL_7]]] : memref<3xi8>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<3xi8>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<3xi8>
-// CHECK: %[[VAL_13:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_13]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: %[[VAL_20:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_20]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_20]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_20]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: %[[VAL_22:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<3xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_15]], %[[VAL_22]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_16]], %[[VAL_22]]{{\[}}%[[VAL_6]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_17]], %[[VAL_22]]{{\[}}%[[VAL_5]]] : memref<3xindex>
-// CHECK: %[[VAL_24:.*]] = call @newSparseTensor(%[[VAL_14]], %[[VAL_23]], %[[VAL_12]], %[[VAL_19]], %[[VAL_21]], %[[VAL_3]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_25:.*]] = call @newSparseTensor(%[[VAL_14]], %[[VAL_23]], %[[VAL_12]], %[[VAL_19]], %[[VAL_21]], %[[VAL_3]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_24]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: call @delSparseTensorCOOF32(%[[VAL_24]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[VAL_25]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_permuted
+// CHECK: sparse_tensor.foreach
+// CHECK: sparse_tensor.insert
+// CHECK: sparse_tensor.load
+// CHECK: sparse_tensor.reorder_coo
+// CHECK: sparse_tensor.foreach
+// CHECK: sparse_tensor.insert
+// CHECK: sparse_tensor.load
func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> tensor<?x?x?xf32, #TsssPermuted> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf32, #SortedCOO3D> to tensor<?x?x?xf32, #TsssPermuted>
return %0 : tensor<?x?x?xf32, #TsssPermuted>
}
-// CHECK-LABEL: func.func @sparse_convert_slice(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 3 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 13 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 9 : i8
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 16 : i8
-// CHECK: %[[VAL_10:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_10]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_9]], %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<2xi8>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_4]]] : memref<2xindex>
-// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[VAL_5]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_4]], %[[VAL_14]]{{\[}}%[[VAL_4]]] : memref<2xindex>
-// CHECK: %[[VAL_16:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_13]], %[[VAL_11]], %[[VAL_15]], %[[VAL_15]], %[[VAL_3]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]], %[[VAL_0]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: return %[[VAL_16]] : !llvm.ptr<i8>
-// CHECK: }
+// CHECK-LABEL: func.func @sparse_convert_slice
+// CHECK: sparse_tensor.foreach
+// CHECK: sparse_tensor.insert
+// CHECK: sparse_tensor.load
+// CHECK-NOT: sparse_tensor.reorder_coo
func.func @sparse_convert_slice(%arg0: tensor<2x13xi32, #COOSlice>) -> (tensor<2x13xi32, #SortedCOO2D>) {
%0 = sparse_tensor.convert %arg0 : tensor<2x13xi32, #COOSlice> to tensor<2x13xi32, #SortedCOO2D>
return %0 : tensor<2x13xi32, #SortedCOO2D>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index 2fb4529e5695e58..bdfab54dc6daeb5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -5,8 +5,7 @@
#DCSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
-#DENSE = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : dense)}>
-#DENSE_P = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : dense, d0 : dense)}>
+
// CHECK-LABEL: @concat_sparse_sparse(
// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
@@ -258,173 +257,3 @@ func.func @concat_sparse_sparse_dense(%arg0: tensor<2x4xf64, #DCSR>,
tensor<4x4xf64, #DCSR> to tensor<?x?xf64>
return %0 : tensor<?x?xf64>
}
-
-// CHECK-LABEL: @concat_sparse_sparse_annotated_dense(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[VAL_0:.*]] = sparse_tensor.values %[[TMP_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: memref.store %[[TMP_c9]], %[[DIM_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]]) : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[TMP_0]]
-// CHECK: return %[[R]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
-func.func @concat_sparse_sparse_annotated_dense(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<?x?xf64, #DENSE> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DENSE>
- return %0 : tensor<?x?xf64, #DENSE>
-}
-
-// CHECK-LABEL: @concat_sparse_sparse_annotated_dense_permute(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index
-// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[VAL_0:.*]] = sparse_tensor.values %[[TMP_0]] : tensor<?x?xf64, #sparse_tensor
-// CHECK: %[[DIM_0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: memref.store %[[TMP_c4]], %[[DIM_0]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_c9]], %[[DIM_0]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[VAL_1:.*]] = memref.reshape %[[VAL_0]](%[[DIM_0]]) : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
-// CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_23]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_29]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: memref.store %[[TMP_28]], %[[VAL_1]][%[[TMP_27]], %[[TMP_29]]] : memref<?x?xf64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[TMP_0]]
-// CHECK: return %[[R]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
-func.func @concat_sparse_sparse_annotated_dense_permute(%arg0: tensor<2x4xf64, #DCSR>,
- %arg1: tensor<3x4xf64, #DCSR>,
- %arg2: tensor<4x4xf64, #DCSR>)
- -> tensor<?x?xf64, #DENSE_P> {
- %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
- : tensor<2x4xf64, #DCSR>,
- tensor<3x4xf64, #DCSR>,
- tensor<4x4xf64, #DCSR> to tensor<?x?xf64, #DENSE_P>
- return %0 : tensor<?x?xf64, #DENSE_P>
-}
>From 5f8cbbbb9673b665d7c2cb23842b2132ca1b2d9b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 12 Oct 2023 17:20:40 +0000
Subject: [PATCH 3/5] remove deadcode
---
.../Transforms/SparseTensorConversion.cpp | 70 -------------------
1 file changed, 70 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 53928803b9b880b..5d9ee6906749ec1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -299,76 +299,6 @@ static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}
-/// Generates a call to release/delete a `SparseTensorIterator`.
-static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
- Value iter) {
- SmallString<26> name{"delSparseTensorIterator",
- primaryTypeFunctionSuffix(elemTp)};
- createFuncCall(builder, loc, name, {}, iter, EmitCInterface::Off);
-}
-
-/// Generates a call that adds one element to a coordinate scheme.
-/// In particular, this generates code like the following:
-/// val = a[i1,..,ik];
-/// if val != 0
-/// t->add(&val, [i1,..,ik], [p1,..,pk]);
-static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
- Value lvlCOO, Value valPtr, Value dimCoords,
- Value dimToLvl) {
- SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
- SmallVector<Value, 4> params{lvlCOO, valPtr, dimCoords, dimToLvl};
- Type pTp = getOpaquePointerType(builder);
- createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On);
-}
-
-/// Generates a call to `iter->getNext()`. If there is a next element,
-/// then it is copied into the out-parameters `coords` and `elemPtr`,
-/// and the return value is true. If there isn't a next element, then
-/// the return value is false.
-///
-/// The `coords` argument uses the same coordinate-space as the `iter`
-/// (which can be either dim- or lvl-coords, depending on context).
-static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
- Value coords, Value elemPtr) {
- Type elemTp = cast<ShapedType>(elemPtr.getType()).getElementType();
- SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
- SmallVector<Value, 3> params{iter, coords, elemPtr};
- Type i1 = builder.getI1Type();
- return createFuncCall(builder, loc, name, i1, params, EmitCInterface::On)
- .getResult(0);
-}
-
-/// Loads the value stored in `elemPtr`, and stores it at the coordinates
-/// `cvs` into a dense tensor created by `allocDenseTensor`.
-static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
- Value elemPtr, Value tensor,
- ValueRange cvs) {
- Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
- builder.create<memref::StoreOp>(loc, elemV, tensor, cvs);
-}
-
-/// Determine if the runtime library supports direct conversion to the
-/// given target `dimTypes`.
-static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
- bool alreadyCompressed = false;
- for (const auto dlt : dimTypes) {
- if (isCompressedDLT(dlt)) {
- if (alreadyCompressed)
- return false; // Multiple compressed dimensions not yet supported.
- alreadyCompressed = true;
- } else if (isDenseDLT(dlt)) {
- if (alreadyCompressed)
- return false; // Dense after Compressed not yet supported.
- } else if (isSingletonDLT(dlt)) {
- // Direct conversion doesn't have any particular problems with
- // singleton after compressed.
- } else { // TODO: investigate
- return false;
- }
- }
- return true;
-}
-
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
>From 061951596dd966e6d6c6788a190c8aa85968f2f2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 12 Oct 2023 17:36:00 +0000
Subject: [PATCH 4/5] address comments
---
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h | 9 ++-------
1 file changed, 2 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 751ee8d1b17dc37..12dd65cb08bb1c0 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -378,13 +378,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
assert(lvlCoords && "Received nullptr for level-coordinates");
// TODO: get rid of this! canonicalize all-dense "sparse" array into dense
// tensors.
- bool allDense = true;
- for (DimLevelType lt : getLvlTypes()) {
- if (!isDenseDLT(lt)) {
- allDense = false;
- break;
- }
- }
+ bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
+ [](DimLevelType lt) { return isDenseDLT(lt); });
if (allDense) {
uint64_t lvlRank = getLvlRank();
uint64_t valIdx = 0;
>From 9921b8e8f422d75f49d9c7ed19dd5d2a2116a6f5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 12 Oct 2023 17:39:00 +0000
Subject: [PATCH 5/5] remove outdated comments
---
.../Dialect/SparseTensor/Transforms/StageSparseOperations.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index acdb888abe2b028..4c163ea6e067ba6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -46,7 +46,6 @@ struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
// -> sort
Type dstCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
- // TODO: this should be a sort_coo operation.
Value dstCOO = rewriter.create<ReorderCOOOp>(
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
More information about the Mlir-commits
mailing list