[Mlir-commits] [mlir] 70633a8 - [mlir][sparse] first general insertion implementation with pure codegen
Aart Bik
llvmlistbot at llvm.org
Tue Nov 8 13:10:14 PST 2022
Author: Aart Bik
Date: 2022-11-08T13:10:05-08:00
New Revision: 70633a8d55a543eff892cc3316eaa3605d084637
URL: https://github.com/llvm/llvm-project/commit/70633a8d55a543eff892cc3316eaa3605d084637
DIFF: https://github.com/llvm/llvm-project/commit/70633a8d55a543eff892cc3316eaa3605d084637.diff
LOG: [mlir][sparse] first general insertion implementation with pure codegen
This revision generalizes lowering the sparse_tensor.insert op into actual code that directly operates on the memrefs of a sparse storage scheme. The current insertion strategy does *not* rely on a cursor anymore, which introduces some testing overhead for each insertion (but still proportional to the rank, as before). Over time, we can optimize the code generation, but this version enables us to finish the effort to migrate from library to actual codegen.
Things to do:
(1) carefully deal with (un)ordered and (non-)unique dimension levels
(2) omit overhead when not needed
(3) optimize and specialize
(4) try to avoid the pointer "cleanup" (at HasInserts), and make sure the storage scheme is consistent at every insertion point (so that it can "escape" without concerns).
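For illustration, a minimal sketch of the kind of insertion this revision now lowers through pure codegen (modeled on the updated codegen tests below; the #CSR encoding and the %t/%row/%col/%val names are only illustrative):

  // Insert %val at (%row, %col); the lowering operates directly on the
  // pointer/index/value memrefs of the storage scheme, without a cursor.
  %1 = sparse_tensor.insert %val into %t[%row, %col] : tensor<8x8xf64, #CSR>
  // Finalize pending insertions; this triggers the pointer "cleanup"
  // mentioned in to-do (4).
  %2 = sparse_tensor.load %1 hasInserts : tensor<8x8xf64, #CSR>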
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D137457
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_2d.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_1d.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 944139f38626a..a35f97c5c60fa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -32,8 +32,8 @@ using namespace mlir::sparse_tensor;
namespace {
static constexpr uint64_t DimSizesIdx = 0;
-static constexpr uint64_t MemSizesIdx = 2;
-static constexpr uint64_t FieldsIdx = 3;
+static constexpr uint64_t MemSizesIdx = 1;
+static constexpr uint64_t FieldsIdx = 2;
//===----------------------------------------------------------------------===//
// Helper methods.
@@ -45,9 +45,9 @@ static UnrealizedConversionCastOp getTuple(Value tensor) {
}
/// Packs the given values as a "tuple" value.
-static Value genTuple(OpBuilder &rewriter, Location loc, Type tp,
+static Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
- return rewriter.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+ return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
.getResult(0);
}
@@ -71,9 +71,47 @@ static void flattenOperands(ValueRange operands,
}
}
-/// Gets the dimension size for the given sparse tensor at the given dim.
-/// Returns None if no sparse encoding is attached to the tensor type.
-static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
+/// Adds index conversions where needed.
+static Value toType(OpBuilder &builder, Location loc, Value value, Type tp) {
+ if (value.getType() != tp)
+ return builder.create<arith::IndexCastOp>(loc, tp, value);
+ return value;
+}
+
+/// Generates a load with proper index typing.
+static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
+ idx = toType(builder, loc, idx, builder.getIndexType());
+ return builder.create<memref::LoadOp>(loc, mem, idx);
+}
+
+/// Generates a store with proper index typing and (for indices) proper value.
+static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
+ Value idx) {
+ idx = toType(builder, loc, idx, builder.getIndexType());
+ val = toType(builder, loc, val,
+ mem.getType().cast<ShapedType>().getElementType());
+ builder.create<memref::StoreOp>(loc, val, mem, idx);
+}
+
+/// Creates a straightforward counting for-loop.
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
+ SmallVectorImpl<Value> &fields,
+ Value lower = Value()) {
+ Type indexType = builder.getIndexType();
+ if (!lower)
+ lower = constantZero(builder, loc, indexType);
+ Value one = constantOne(builder, loc, indexType);
+ scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
+ for (unsigned i = 0, e = fields.size(); i < e; i++)
+ fields[i] = forOp.getRegionIterArg(i);
+ builder.setInsertionPointToStart(forOp.getBody());
+ return forOp;
+}
+
+/// Gets the dimension size for the given sparse tensor at the given
+/// original dimension 'dim'. Returns None if no sparse encoding is
+/// attached to the given tensor type.
+static Optional<Value> sizeFromTensorAtDim(OpBuilder &builder, Location loc,
RankedTensorType tensorTp,
Value adaptedValue, unsigned dim) {
auto enc = getSparseTensorEncoding(tensorTp);
@@ -84,40 +122,63 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
// Note that this is typically already done by DimOp's folding.
auto shape = tensorTp.getShape();
if (!ShapedType::isDynamic(shape[dim]))
- return constantIndex(rewriter, loc, shape[dim]);
+ return constantIndex(builder, loc, shape[dim]);
// Any other query can consult the dimSizes array at field DimSizesIdx,
// accounting for the reordering applied to the sparse storage.
auto tuple = getTuple(adaptedValue);
- Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
- return rewriter
+ Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim));
+ return builder
.create<memref::LoadOp>(loc, tuple.getInputs()[DimSizesIdx], idx)
.getResult();
}
+// Gets the dimension size at the given stored dimension 'd', either as a
+// constant for a static size, or otherwise dynamically through memSizes.
+Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields, unsigned d) {
+ unsigned dim = toOrigDim(rtp, d);
+ auto shape = rtp.getShape();
+ if (!ShapedType::isDynamic(shape[dim]))
+ return constantIndex(builder, loc, shape[dim]);
+ return genLoad(builder, loc, fields[DimSizesIdx],
+ constantIndex(builder, loc, d));
+}
+
/// Translates field index to memSizes index.
static unsigned getMemSizesIndex(unsigned field) {
assert(FieldsIdx <= field);
return field - FieldsIdx;
}
+/// Creates a pushback op for given field and updates the fields array
+/// accordingly. This operation also updates the memSizes contents.
+static void createPushback(OpBuilder &builder, Location loc,
+ SmallVectorImpl<Value> &fields, unsigned field,
+ Value value, Value repeat = Value()) {
+ assert(FieldsIdx <= field && field < fields.size());
+ Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+ fields[field] = builder.create<PushBackOp>(
+ loc, fields[field].getType(), fields[MemSizesIdx], fields[field],
+ toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)),
+ repeat);
+}
+
/// Returns field index of sparse tensor type for pointers/indices, when set.
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
assert(getSparseTensorEncoding(type));
RankedTensorType rType = type.cast<RankedTensorType>();
unsigned field = FieldsIdx; // start past header
- unsigned ptr = 0;
- unsigned idx = 0;
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
if (isCompressedDim(rType, r)) {
- if (ptr++ == ptrDim)
+ if (r == ptrDim)
return field;
field++;
- if (idx++ == idxDim)
+ if (r == idxDim)
return field;
field++;
} else if (isSingletonDim(rType, r)) {
- if (idx++ == idxDim)
+ if (r == idxDim)
return field;
field++;
} else {
@@ -144,14 +205,13 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
Type eltType = rType.getElementType();
//
- // Sparse tensor storage for rank-dimensional tensor is organized as a
- // single compound type with the following fields. Note that every
+ // Sparse tensor storage scheme for rank-dimensional tensor is organized
+ // as a single compound type with the following fields. Note that every
// memref with ? size actually behaves as a "vector", i.e. the stored
// size is the capacity and the used size resides in the memSizes array.
//
// struct {
// memref<rank x index> dimSizes ; size in each dimension
- // memref<rank x index> dimCursor ; cursor in each dimension
// memref<n x index> memSizes ; sizes of ptrs/inds/values
// ; per-dimension d:
// ; if dense:
@@ -166,8 +226,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
//
unsigned rank = rType.getShape().size();
unsigned lastField = getFieldIndex(type, -1u, -1u);
- // The dimSizes array, dimCursor array, and memSizes array.
- fields.push_back(MemRefType::get({rank}, indexType));
+ // The dimSizes array and memSizes array.
fields.push_back(MemRefType::get({rank}, indexType));
fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
// Per-dimension storage.
@@ -191,6 +250,41 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
return success();
}
+/// Generates code that allocates a sparse storage scheme for given rank.
+static void allocSchemeForRank(OpBuilder &builder, Location loc,
+ RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields, unsigned field,
+ unsigned r0) {
+ unsigned rank = rtp.getShape().size();
+ Value linear = constantIndex(builder, loc, 1);
+ for (unsigned r = r0; r < rank; r++) {
+ if (isCompressedDim(rtp, r)) {
+ // Append linear x pointers, initialized to zero. Since each compressed
+ // dimension initially already has a single zero entry, this maintains
+ // the desired "linear + 1" length property at all times.
+ unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
+ Type indexType = builder.getIndexType();
+ Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+ Value ptrZero = constantZero(builder, loc, ptrType);
+ createPushback(builder, loc, fields, field, ptrZero, linear);
+ return;
+ } else if (isSingletonDim(rtp, r)) {
+ return; // nothing to do
+ } else {
+ // Keep compounding the size, but nothing needs to be initialized
+ // at this level. We will eventually reach a compressed level or
+ // otherwise the values array for the from-here "all-dense" case.
+ assert(isDenseDim(rtp, r));
+ Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
+ linear = builder.create<arith::MulIOp>(loc, linear, size);
+ }
+ }
+ // Reached values array so prepare for an insertion.
+ Value valZero = constantZero(builder, loc, rtp.getElementType());
+ createPushback(builder, loc, fields, field, valZero, linear);
+ assert(fields.size() == ++field);
+}
+
/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc, Type type,
Value sz) {
@@ -203,7 +297,7 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
/// the "vector", while the actual size resides in the sizes array.
///
/// TODO: for efficiency, we will need heuristics to make educated guesses
-/// on the required capacities
+/// on the required capacities (see heuristic variable).
///
static void createAllocFields(OpBuilder &builder, Location loc, Type type,
ValueRange dynSizes,
@@ -213,17 +307,14 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
// Construct the basic types.
unsigned idxWidth = enc.getIndexBitWidth();
unsigned ptrWidth = enc.getPointerBitWidth();
- RankedTensorType rType = type.cast<RankedTensorType>();
+ RankedTensorType rtp = type.cast<RankedTensorType>();
Type indexType = builder.getIndexType();
Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
- Type eltType = rType.getElementType();
- auto shape = rType.getShape();
+ Type eltType = rtp.getElementType();
+ auto shape = rtp.getShape();
unsigned rank = shape.size();
- bool allDense = true;
- Value one = constantIndex(builder, loc, 1);
- Value linear = one;
- Value heuristic = one; // FIX, see TODO above
+ Value heuristic = constantIndex(builder, loc, 16);
// Build original sizes.
SmallVector<Value, 8> sizes;
for (unsigned r = 0, o = 0; r < rank; r++) {
@@ -232,114 +323,242 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
else
sizes.push_back(constantIndex(builder, loc, shape[r]));
}
- // The dimSizes array, dimCursor array, and memSizes array.
+ // The dimSizes array and memSizes array.
unsigned lastField = getFieldIndex(type, -1u, -1u);
Value dimSizes =
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
- Value dimCursor =
- builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
Value memSizes = builder.create<memref::AllocOp>(
loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
fields.push_back(dimSizes);
- fields.push_back(dimCursor);
fields.push_back(memSizes);
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
- // Get the original dimension (ro) for the current stored dimension.
- unsigned ro = toOrigDim(rType, r);
- builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
- constantIndex(builder, loc, r));
- linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
- // Allocate fields.
- if (isCompressedDim(rType, r)) {
+ if (isCompressedDim(rtp, r)) {
fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
- allDense = false;
- } else if (isSingletonDim(rType, r)) {
+ } else if (isSingletonDim(rtp, r)) {
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
- allDense = false;
} else {
- assert(isDenseDim(rType, r)); // no fields
+ assert(isDenseDim(rtp, r)); // no fields
}
}
- // The values array. For all-dense, the full length is required.
- // In all other case, we resort to the heuristical initial value.
- Value valuesSz = allDense ? linear : heuristic;
- fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
- // Set memSizes.
- if (allDense)
- builder.create<memref::StoreOp>(
- loc, valuesSz, memSizes,
- constantIndex(builder, loc, 0)); // TODO: avoid memSizes in this case?
- else
- builder.create<linalg::FillOp>(
- loc, ValueRange{constantZero(builder, loc, indexType)},
- ValueRange{memSizes});
+ // The values array.
+ fields.push_back(createAllocation(builder, loc, eltType, heuristic));
assert(fields.size() == lastField);
+ // Initialize the storage scheme to an empty tensor. Initialized memSizes
+ // to all zeros, sets the dimSizes to known values and gives all pointer
+ // fields an initial zero entry, so that it is easier to maintain the
+ // "linear + 1" length property.
+ builder.create<linalg::FillOp>(
+ loc, ValueRange{constantZero(builder, loc, indexType)},
+ ValueRange{memSizes}); // zero memSizes
+ Value ptrZero = constantZero(builder, loc, ptrType);
+ for (unsigned r = 0, field = FieldsIdx; r < rank; r++) {
+ unsigned ro = toOrigDim(rtp, r);
+ genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r));
+ if (isCompressedDim(rtp, r)) {
+ createPushback(builder, loc, fields, field, ptrZero);
+ field += 2;
+ } else if (isSingletonDim(rtp, r)) {
+ field += 1;
+ }
+ }
+ allocSchemeForRank(builder, loc, rtp, fields, FieldsIdx, /*rank=*/0);
}
-/// Creates a straightforward counting for-loop.
-static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
- SmallVectorImpl<Value> &fields) {
+/// Helper method that generates block specific to compressed case:
+///
+/// plo = pointers[d][pos[d-1]]
+/// phi = pointers[d][pos[d-1]+1]
+/// msz = indices[d].size()
+/// if (plo < phi) {
+/// present = indices[d][phi-1] == i[d]
+/// } else { // first insertion
+/// present = false
+/// pointers[d][pos[d-1]] = msz
+/// }
+/// if (present) { // index already present
+/// next = phi-1
+/// } else {
+/// indices[d].push_back(i[d])
+/// pointers[d][pos[d-1]+1] = msz+1
+/// next = msz
+/// <prepare dimension d + 1>
+/// }
+/// pos[d] = next
+static Value genCompressed(OpBuilder &builder, Location loc,
+ RankedTensorType rtp, SmallVectorImpl<Value> &fields,
+ SmallVectorImpl<Value> &indices, Value value,
+ Value pos, unsigned field, unsigned d) {
+ unsigned rank = rtp.getShape().size();
+ SmallVector<Type, 4> types;
Type indexType = builder.getIndexType();
- Value zero = constantZero(builder, loc, indexType);
- Value one = constantOne(builder, loc, indexType);
- scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one, fields);
+ Type boolType = builder.getIntegerType(1);
+ Value one = constantIndex(builder, loc, 1);
+ Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
+ Value plo = genLoad(builder, loc, fields[field], pos);
+ Value phi = genLoad(builder, loc, fields[field], pp1);
+ Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1));
+ Value msz = genLoad(builder, loc, fields[MemSizesIdx], psz);
+ Value phim1 = builder.create<arith::SubIOp>(
+ loc, toType(builder, loc, phi, indexType), one);
+ // Conditional expression.
+ Value lt =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, plo, phi);
+ types.push_back(boolType);
+ scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
+ types.pop_back();
+ builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
+ Value crd = genLoad(builder, loc, fields[field + 1], phim1);
+ Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ toType(builder, loc, crd, indexType),
+ indices[d]);
+ builder.create<scf::YieldOp>(loc, eq);
+ builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
+ if (d > 0)
+ genStore(builder, loc, msz, fields[field], pos);
+ builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
+ builder.setInsertionPointAfter(ifOp1);
+ Value p = ifOp1.getResult(0);
+ // If present construct. Note that for a non-unique dimension level, we simply
+ // set the condition to false and rely on CSE/DCE to clean up the IR.
+ //
+ // TODO: generate less temporary IR?
+ //
for (unsigned i = 0, e = fields.size(); i < e; i++)
- fields[i] = forOp.getRegionIterArg(i);
- builder.setInsertionPointToStart(forOp.getBody());
- return forOp;
-}
-
-/// Creates a pushback op for given field and updates the fields array
-/// accordingly.
-static void createPushback(OpBuilder &builder, Location loc,
- SmallVectorImpl<Value> &fields, unsigned field,
- Value value) {
- assert(FieldsIdx <= field && field < fields.size());
- Type etp = fields[field].getType().cast<ShapedType>().getElementType();
- if (value.getType() != etp)
- value = builder.create<arith::IndexCastOp>(loc, etp, value);
- fields[field] = builder.create<PushBackOp>(
- loc, fields[field].getType(), fields[MemSizesIdx], fields[field], value,
- APInt(64, getMemSizesIndex(field)));
+ types.push_back(fields[i].getType());
+ types.push_back(indexType);
+ if (!isUniqueDim(rtp, d))
+ p = constantI1(builder, loc, false);
+ scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
+ // If present (fields unaffected, update next to phim1).
+ builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+ fields.push_back(phim1);
+ builder.create<scf::YieldOp>(loc, fields);
+ fields.pop_back();
+ // If !present (changes fields, update next).
+ builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+ Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
+ genStore(builder, loc, mszp1, fields[field], pp1);
+ createPushback(builder, loc, fields, field + 1, indices[d]);
+ // Prepare the next dimension "as needed".
+ if ((d + 1) < rank)
+ allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1);
+ fields.push_back(msz);
+ builder.create<scf::YieldOp>(loc, fields);
+ fields.pop_back();
+ // Update fields and return next pos.
+ builder.setInsertionPointAfter(ifOp2);
+ unsigned o = 0;
+ for (unsigned i = 0, e = fields.size(); i < e; i++)
+ fields[i] = ifOp2.getResult(o++);
+ return ifOp2.getResult(o);
}
-/// Generates insertion code.
-//
-// TODO: generalize this for any rank and format currently it is just sparse
-// vectors as a proof of concept that we have everything in place!
-//
+/// Generates code along an insertion path without the need for a "cursor".
+/// This current insertion strategy comes at the expense of some testing
+/// overhead for each insertion. The strategy will be optimized later for
+/// common insertion patterns. The current insertion strategy also assumes
+/// insertions occur in "a reasonable order" that enables building the
+/// storage scheme in an appending/inserting kind of fashion (i.e. no
+/// in-between insertions that need data movement). The implementation
+/// relies on CSE/DCE to clean up all bookkeeping that is not needed.
+///
+/// TODO: better unord/not-unique; also generalize, optimize, specialize!
+///
static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
SmallVectorImpl<Value> &fields,
SmallVectorImpl<Value> &indices, Value value) {
- unsigned rank = indices.size();
- assert(rtp.getShape().size() == rank);
- if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
- !isOrderedDim(rtp, 0))
- return; // TODO: add codegen
- // push_back memSizes indices-0 index
- // push_back memSizes values value
- createPushback(builder, loc, fields, FieldsIdx + 1, indices[0]);
- createPushback(builder, loc, fields, FieldsIdx + 2, value);
+ unsigned rank = rtp.getShape().size();
+ assert(rank == indices.size());
+ unsigned field = FieldsIdx; // start past header
+ Value pos = constantZero(builder, loc, builder.getIndexType());
+ // Generate code for every dimension.
+ for (unsigned d = 0; d < rank; d++) {
+ if (isCompressedDim(rtp, d)) {
+ // Create:
+ // if (!present) {
+ // indices[d].push_back(i[d])
+ // <update pointers and prepare dimension d + 1>
+ // }
+ // pos[d] = indices.size() - 1
+ // <insert @ pos[d] at next dimension d + 1>
+ pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field,
+ d);
+ field += 2;
+ } else if (isSingletonDim(rtp, d)) {
+ // Create:
+ // indices[d].push_back(i[d])
+ // pos[d] = pos[d-1]
+ // <insert @ pos[d] at next dimension d + 1>
+ createPushback(builder, loc, fields, field, indices[d]);
+ field += 1;
+ } else {
+ assert(isDenseDim(rtp, d));
+ // Construct the new position as:
+ // pos[d] = size * pos[d-1] + i[d]
+ // <insert @ pos[d] at next dimension d + 1>
+ Value size = sizeAtStoredDim(builder, loc, rtp, fields, d);
+ Value mult = builder.create<arith::MulIOp>(loc, size, pos);
+ pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
+ }
+ }
+ // Reached the actual value append/insert.
+ if (!isDenseDim(rtp, rank - 1))
+ createPushback(builder, loc, fields, field++, value);
+ else
+ genStore(builder, loc, value, fields[field++], pos);
+ assert(fields.size() == field);
}
/// Generates insertion finalization code.
-//
-// TODO: this too only works for the very simple case
-//
static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
SmallVectorImpl<Value> &fields) {
- if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
- !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
- return; // TODO: add codegen
- // push_back memSizes pointers-0 0
- // push_back memSizes pointers-0 memSizes[2]
- Value zero = constantIndex(builder, loc, 0);
- Value two = constantIndex(builder, loc, 2);
- Value size = builder.create<memref::LoadOp>(loc, fields[MemSizesIdx], two);
- createPushback(builder, loc, fields, FieldsIdx, zero);
- createPushback(builder, loc, fields, FieldsIdx, size);
+ unsigned rank = rtp.getShape().size();
+ unsigned field = FieldsIdx; // start past header
+ for (unsigned d = 0; d < rank; d++) {
+ if (isCompressedDim(rtp, d)) {
+ // Compressed dimensions need a pointer cleanup for all entries
+ // that were not visited during the insertion pass.
+ //
+ // TODO: avoid cleanup and keep compressed scheme consistent at all times?
+ //
+ if (d > 0) {
+ unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
+ Type indexType = builder.getIndexType();
+ Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+ Value mz = constantIndex(builder, loc, getMemSizesIndex(field));
+ Value hi = genLoad(builder, loc, fields[MemSizesIdx], mz);
+ Value zero = constantIndex(builder, loc, 0);
+ Value one = constantIndex(builder, loc, 1);
+ SmallVector<Value, 1> inits;
+ inits.push_back(genLoad(builder, loc, fields[field], zero));
+ scf::ForOp loop = createFor(builder, loc, hi, inits, one);
+ Value i = loop.getInductionVar();
+ Value oldv = loop.getRegionIterArg(0);
+ Value newv = genLoad(builder, loc, fields[field], i);
+ Value ptrZero = constantZero(builder, loc, ptrType);
+ Value cond = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, newv, ptrZero);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(ptrType),
+ cond, /*else*/ true);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ genStore(builder, loc, oldv, fields[field], i);
+ builder.create<scf::YieldOp>(loc, oldv);
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, newv);
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+ builder.setInsertionPointAfter(loop);
+ }
+ field += 2;
+ } else if (isSingletonDim(rtp, d)) {
+ field++;
+ } else {
+ assert(isDenseDim(rtp, d));
+ }
+ }
+ assert(fields.size() == ++field);
}
//===----------------------------------------------------------------------===//
@@ -624,14 +843,14 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
// }
scf::ForOp loop = createFor(rewriter, loc, count, fields);
Value i = loop.getInductionVar();
- Value index = rewriter.create<memref::LoadOp>(loc, added, i);
- Value value = rewriter.create<memref::LoadOp>(loc, values, index);
+ Value index = genLoad(rewriter, loc, added, i);
+ Value value = genLoad(rewriter, loc, values, index);
indices.push_back(index);
+ // TODO: faster for subsequent insertions?
genInsert(rewriter, loc, dstType, fields, indices, value);
- rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
- values, index);
- rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
- filled, index);
+ genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
+ index);
+ genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
rewriter.create<scf::YieldOp>(loc, fields);
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = getTop(op);
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 98009fff845f0..660cd7552db8a 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -48,33 +48,30 @@
// CHECK-LABEL: func @sparse_nop(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] :
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
return %arg0 : tensor<?xf64, #SparseVector>
}
// CHECK-LABEL: func @sparse_nop_multi_ret(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<1xindex>,
-// CHECK-SAME: %[[A7:.*7]]: memref<1xindex>,
-// CHECK-SAME: %[[A8:.*8]]: memref<3xindex>,
-// CHECK-SAME: %[[A9:.*9]]: memref<?xi32>,
-// CHECK-SAME: %[[A10:.*10]]: memref<?xi64>,
-// CHECK-SAME: %[[A11:.*11]]: memref<?xf64>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]], %[[A10]], %[[A11]] :
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
+// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
+// CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
+// CHECK-SAME: %[[A9:.*9]]: memref<?xf64>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] :
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
%arg1: tensor<?xf64, #SparseVector>) ->
(tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -83,21 +80,19 @@ func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
// CHECK-LABEL: func @sparse_nop_call(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<1xindex>,
-// CHECK-SAME: %[[A7:.*7]]: memref<1xindex>,
-// CHECK-SAME: %[[A8:.*8]]: memref<3xindex>,
-// CHECK-SAME: %[[A9:.*9]]: memref<?xi32>,
-// CHECK-SAME: %[[A10:.*10]]: memref<?xi64>,
-// CHECK-SAME: %[[A11:.*11]]: memref<?xf64>)
-// CHECK: %[[T:.*]]:12 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]], %[[A10]], %[[A11]])
-// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9, %[[T]]#10, %[[T]]#11
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
+// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
+// CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
+// CHECK-SAME: %[[A9:.*9]]: memref<?xf64>)
+// CHECK: %[[T:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]])
+// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9 :
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
%arg1: tensor<?xf64, #SparseVector>) ->
(tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -109,12 +104,11 @@ func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
// CHECK-LABEL: func @sparse_nop_cast(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] :
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
@@ -123,11 +117,10 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
// CHECK-LABEL: func @sparse_nop_cast_3d(
// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] :
-// CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref<?xf32>
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]] :
+// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf32>
func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
%0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
return %0 : tensor<?x?x?xf32, #Dense3D>
@@ -135,9 +128,8 @@ func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?
// CHECK-LABEL: func @sparse_dense_2d(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
// CHECK: return
func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
return
@@ -145,11 +137,10 @@ func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
// CHECK-LABEL: func @sparse_row(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
// CHECK: return
func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
return
@@ -157,11 +148,10 @@ func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
// CHECK-LABEL: func @sparse_csr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
// CHECK: return
func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
@@ -169,13 +159,12 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK-LABEL: func @sparse_dcsr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
// CHECK: return
func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
return
@@ -187,9 +176,8 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
//
// CHECK-LABEL: func @sparse_dense_3d(
// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
// CHECK: %[[C:.*]] = arith.constant 20 : index
// CHECK: return %[[C]] : index
func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -205,9 +193,8 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
//
// CHECK-LABEL: func @sparse_dense_3d_dyn(
// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
// CHECK: %[[C:.*]] = arith.constant 2 : index
// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
// CHECK: return %[[L]] : index
@@ -219,14 +206,13 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
// CHECK-LABEL: func @sparse_pointers_dcsr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
-// CHECK: return %[[A5]] : memref<?xi32>
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK: return %[[A4]] : memref<?xi32>
func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
%0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
return %0 : memref<?xi32>
@@ -234,14 +220,13 @@ func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32>
// CHECK-LABEL: func @sparse_indices_dcsr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
-// CHECK: return %[[A6]] : memref<?xi64>
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK: return %[[A5]] : memref<?xi64>
func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
%0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
return %0 : memref<?xi64>
@@ -249,14 +234,13 @@ func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
// CHECK-LABEL: func @sparse_values_dcsr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
-// CHECK: return %[[A7]] : memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK: return %[[A6]] : memref<?xf64>
func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
return %0 : memref<?xf64>
@@ -264,13 +248,12 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
// CHECK-LABEL: func @sparse_noe(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[NOE:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
// CHECK: return %[[NOE]] : index
func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
%0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
@@ -279,17 +262,15 @@ func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
// CHECK-LABEL: func @sparse_dealloc_csr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
// CHECK: memref.dealloc %[[A0]] : memref<2xindex>
-// CHECK: memref.dealloc %[[A1]] : memref<2xindex>
-// CHECK: memref.dealloc %[[A2]] : memref<3xindex>
-// CHECK: memref.dealloc %[[A3]] : memref<?xi32>
-// CHECK: memref.dealloc %[[A4]] : memref<?xi64>
-// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A1]] : memref<3xindex>
+// CHECK: memref.dealloc %[[A2]] : memref<?xi32>
+// CHECK: memref.dealloc %[[A3]] : memref<?xi64>
+// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
// CHECK: return
func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
@@ -298,24 +279,25 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK-LABEL: func @sparse_alloc_csc(
// CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-SAME: memref<2xindex>, memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: %[[CC:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
-// CHECK: %[[T2:.*]] = memref.alloc() : memref<1xindex>
-// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[T4:.*]] = memref.alloc() : memref<1xindex>
-// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[T6:.*]] = memref.alloc() : memref<1xf64>
-// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<1xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
-// CHECK: return %[[T0]], %[[CC]], %[[T1]], %[[T3]], %[[T5]], %[[T7]] :
-// CHECK-SAME: memref<2xindex>, memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
+// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]]
+// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] :
+// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
%1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
@@ -323,7 +305,8 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
}
// CHECK-LABEL: func @sparse_alloc_3d() ->
-// CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref<?xf64>
+// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf64>
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -332,16 +315,16 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
// CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index
// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: %[[CC:.*]] = memref.alloc() : memref<3xindex>
// CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[AV:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[A2:.*]] = memref.cast %[[AV]] : memref<16xf64> to memref<?xf64>
+// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[A1]] : memref<1xindex>)
// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
-// CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
-// CHECK: %[[A2:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-// CHECK: memref.store %[[C6000]], %[[A1]][%[[C0]]] : memref<1xindex>
-// CHECK: return %[[A0]], %[[CC]], %[[A1]], %[[A2]] :
-// CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref<?xf64>
+// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[F0]], %[[C6000]]
+// CHECK: return %[[A0]], %[[A1]], %[[P]] :
+// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf64>
func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
%0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
@@ -400,37 +383,31 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-LABEL: func @sparse_compression_1d(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
-// CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[A9:.*9]]: index)
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index)
// CHECK-DAG: %[[B0:.*]] = arith.constant false
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: sparse_tensor.sort %[[A9]], %[[A8]] : memref<?xindex>
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A4]], %[[P1:.*]] = %[[A5]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[T1:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
-// CHECK: %[[T2:.*]] = memref.load %[[A6]][%[[T1]]] : memref<?xf64>
-// CHECK: %[[T3:.*]] = sparse_tensor.push_back %[[A2]], %[[P0]], %[[T1]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[P1]], %[[T2]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: memref.store %[[F0]], %[[A6]][%[[T1]]] : memref<?xf64>
-// CHECK: memref.store %[[B0]], %[[A7]][%[[T1]]] : memref<?xi1>
-// CHECK: scf.yield %[[T3]], %[[T4]] : memref<?xindex>, memref<?xf64>
+// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
+// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
// CHECK: }
-// CHECK: memref.dealloc %[[A6]] : memref<?xf64>
-// CHECK: memref.dealloc %[[A7]] : memref<?xi1>
-// CHECK: memref.dealloc %[[A8]] : memref<?xindex>
-// CHECK: %[[LL:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
-// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A2]], %[[P1]], %[[LL]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[P2]], %[[R]]#0, %[[R]]#1
+// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
%values: memref<?xf64>,
@@ -445,31 +422,33 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK-LABEL: func @sparse_compression(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
-// CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[A9:.*9]]: index,
-// CHECK-SAME: %[[A10:.*10]]: index)
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index,
+// CHECK-SAME: %[[A9:.*9]]: index)
// CHECK-DAG: %[[B0:.*]] = arith.constant false
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: sparse_tensor.sort %[[A9]], %[[A8]] : memref<?xindex>
-// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
-// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
-// TODO: 2D-insert
-// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
-// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
-// CHECK-NEXT: }
-// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
-// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
-// CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
-// CHECK: return
+// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
+// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xi64>, memref<?xf64>) {
+// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xi64>, memref<?xf64>
+// CHECK: }
+// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
@@ -484,31 +463,33 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
// CHECK-LABEL: func @sparse_compression_unordered(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
-// CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[A9:.*9]]: index,
-// CHECK-SAME: %[[A10:.*10]]: index)
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index,
+// CHECK-SAME: %[[A9:.*9]]: index)
// CHECK-DAG: %[[B0:.*]] = arith.constant false
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NOT: sparse_tensor.sort
-// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
-// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
-// TODO: 2D-insert
-// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
-// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
-// CHECK-NEXT: }
-// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
-// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
-// CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
-// CHECK: return
+// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+// CHECK: }
+// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
@@ -523,21 +504,14 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xindex>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: index,
-// CHECK-SAME: %[[A7:.*7]]: f64)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A2]], %[[A4]], %[[A6]]
-// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A2]], %[[A5]], %[[A7]]
-// CHECK: %[[T3:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
-// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]]
-// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[T0]], %[[T3]]
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[T4]], %[[T1]], %[[T2]] :
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
@@ -547,24 +521,15 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
// CHECK-LABEL: func @sparse_insert_typed(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: index,
-// CHECK-SAME: %[[A7:.*7]]: f64)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[S1:.*]] = arith.index_cast %[[A6]] : index to i64
-// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A2]], %[[A4]], %[[S1]]
-// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A2]], %[[A5]], %[[A7]]
-// CHECK: %[[T3:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
-// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]]
-// CHECK: %[[S2:.*]] = arith.index_cast %[[T3]] : index to i32
-// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[T0]], %[[S2]]
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[T4]], %[[T1]], %[[T2]] :
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
@@ -573,13 +538,12 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
// CHECK-LABEL: func.func @sparse_nop_convert(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] :
-// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
+// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
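The updated CHECK signatures above spell out the flattened storage of a compressed sparse vector as five buffers (dimension sizes, memory sizes, pointers, indices, values); the separate dimension-cursor buffer no longer appears. Each sparse_tensor.push_back in these tests appends one element to a buffer and records the new used size in the slot of the 3-entry memory-sizes array selected by its idx attribute, returning the (possibly reallocated) buffer. A minimal standalone C++ analogue of that behavior, under this reading of the op; the helper name and the doubling growth policy are assumptions, and std::vector hides the reallocation that the returned memref makes explicit:

#include <cstdint>
#include <iostream>
#include <vector>

// memSizes[idx] tracks how many entries of `buffer` are in use, mirroring the
// memref<3xindex> operand of sparse_tensor.push_back in the tests above.
static void pushBack(std::vector<uint64_t> &memSizes,
                     std::vector<double> &buffer, double value, unsigned idx) {
  uint64_t used = memSizes[idx];
  if (used == buffer.size())                                // out of capacity:
    buffer.resize(buffer.empty() ? 1 : 2 * buffer.size());  // assumed doubling
  buffer[used] = value;     // append the new element
  memSizes[idx] = used + 1; // bump the used size
}

int main() {
  std::vector<uint64_t> memSizes(3, 0); // sizes of pointers/indices/values
  std::vector<double> values;
  pushBack(memSizes, values, 1.5, /*idx=*/2);
  pushBack(memSizes, values, 2.5, /*idx=*/2);
  std::cout << "used values: " << memSizes[2] << "\n"; // used values: 2
}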
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index 207e46b3d45ae..0c3c9bb0d27c0 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -3,24 +3,22 @@
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
// CHECK-LABEL: func @for(
// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
-// CHECK-SAME: %[[LB:.*6]]: index,
-// CHECK-SAME: %[[UB:.*7]]: index,
-// CHECK-SAME: %[[STEP:.*8]]: index)
-// CHECK: %[[OUT:.*]]:6 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(
+// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
+// CHECK-SAME: %[[LB:.*5]]: index,
+// CHECK-SAME: %[[UB:.*6]]: index,
+// CHECK-SAME: %[[STEP:.*7]]: index)
+// CHECK: %[[OUT:.*]]:5 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(
// CHECK-SAME: %[[SIZE:.*]] = %[[DIM_SIZE]],
-// CHECK-SAME: %[[CUR:.*]] = %[[DIM_CURSOR]],
// CHECK-SAME: %[[MEM:.*]] = %[[MEM_SIZE]],
// CHECK-SAME: %[[PTR:.*]] = %[[POINTER]],
// CHECK-SAME: %[[IDX:.*]] = %[[INDICES]],
// CHECK-SAME: %[[VAL:.*]] = %[[VALUE]])
-// CHECK: scf.yield %[[SIZE]], %[[CUR]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: scf.yield %[[SIZE]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
// CHECK: }
-// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4, %[[OUT]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
func.func @for(%in: tensor<1024xf32, #SparseVector>,
%lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
%1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
@@ -33,25 +31,23 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
// CHECK-LABEL: func @if(
// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
-// CHECK-SAME: %[[DIM_SIZE_1:.*6]]: memref<1xindex>,
-// CHECK-SAME: %[[DIM_CURSOR_1:.*7]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE_1:.*8]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER_1:.*9]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES_1:.*10]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE_1:.*11]]: memref<?xf32>,
-// CHECK-SAME: %[[TMP_arg12:.*12]]: i1) ->
-// CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
-// CHECK: %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
-// CHECK: scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
+// CHECK-SAME: %[[DIM_SIZE_1:.*5]]: memref<1xindex>,
+// CHECK-SAME: %[[MEM_SIZE_1:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER_1:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES_1:.*8]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE_1:.*9]]: memref<?xf32>,
+// CHECK-SAME: %[[I1:.*10]]: i1) ->
+// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK: %[[SV:.*]]:5 = scf.if %[[I1]] -> (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK: scf.yield %[[DIM_SIZE]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
// CHECK: } else {
-// CHECK: scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: scf.yield %[[DIM_SIZE_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
// CHECK: }
-// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
func.func @if(%t: tensor<1024xf32, #SparseVector>,
%f: tensor<1024xf32, #SparseVector>,
%c: i1) -> tensor<1024xf32, #SparseVector> {
@@ -60,38 +56,35 @@ func.func @if(%t: tensor<1024xf32, #SparseVector>,
} else {
scf.yield %f : tensor<1024xf32, #SparseVector>
}
-
return %1 : tensor<1024xf32, #SparseVector>
}
// CHECK-LABEL: func @while(
// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
-// CHECK-SAME: %[[TMP_arg6:.*6]]: i1) ->
-// CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
-// CHECK: %[[SV:.*]]:6 = scf.while (
-// CHECK-SAME: %[[TMP_arg7:.*]] = %[[DIM_SIZE]],
-// CHECK-SAME: %[[TMP_arg8:.*]] = %[[DIM_CURSOR]],
-// CHECK-SAME: %[[TMP_arg9:.*]] = %[[MEM_SIZE]],
-// CHECK-SAME: %[[TMP_arg10:.*]] = %[[POINTER]],
-// CHECK-SAME: %[[TMP_arg11:.*]] = %[[INDICES]],
-// CHECK-SAME: %[[TMP_arg12:.*]] = %[[VALUE]])
-// CHECK: scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
+// CHECK-SAME: %[[I1:.*5]]: i1) ->
+// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK: %[[SV:.*]]:5 = scf.while (
+// CHECK-SAME: %[[TMP_DIM:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME: %[[TMP_MEM:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME: %[[TMP_PTR:.*]] = %[[POINTER]],
+// CHECK-SAME: %[[TMP_IND:.*]] = %[[INDICES]],
+// CHECK-SAME: %[[TMP_VAL:.*]] = %[[VALUE]])
+// CHECK: scf.condition(%[[I1]]) %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
// CHECK: } do {
-// CHECK: ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>):
-// CHECK: scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: ^bb0(%[[TMP_DIM]]: memref<1xindex>, %[[TMP_MEM]]: memref<3xindex>, %[[TMP_PTR]]: memref<?xindex>, %[[TMP_IND]]: memref<?xindex>, %[[TMP_VAL]]: memref<?xf32>):
+// CHECK: scf.yield %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
// CHECK: }
-// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
- %0 = scf.while (%arg4 = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
- scf.condition(%c) %arg4 : tensor<1024xf32, #SparseVector>
+ %0 = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
+ scf.condition(%c) %in : tensor<1024xf32, #SparseVector>
} do {
- ^bb0(%arg7: tensor<1024xf32, #SparseVector>):
- scf.yield %arg7 : tensor<1024xf32, #SparseVector>
+ ^bb0(%arg1: tensor<1024xf32, #SparseVector>):
+ scf.yield %arg1 : tensor<1024xf32, #SparseVector>
}
return %0: tensor<1024xf32, #SparseVector>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_1d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_1d.mlir
index 0f28ab4925f74..d8422180b4025 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_1d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_1d.mlir
@@ -21,18 +21,17 @@ module {
// Dumps pointers, indices, values for verification.
func.func @dump(%argx: tensor<1024xf32, #SparseVector>) {
%c0 = arith.constant 0 : index
- %cu = arith.constant 99 : index
- %fu = arith.constant 99.0 : f32
+ %f0 = arith.constant 0.0 : f32
%p = sparse_tensor.pointers %argx { dimension = 0 : index }
: tensor<1024xf32, #SparseVector> to memref<?xindex>
%i = sparse_tensor.indices %argx { dimension = 0 : index }
: tensor<1024xf32, #SparseVector> to memref<?xindex>
%v = sparse_tensor.values %argx
: tensor<1024xf32, #SparseVector> to memref<?xf32>
- %vp = vector.transfer_read %p[%c0], %cu: memref<?xindex>, vector<8xindex>
- %vi = vector.transfer_read %i[%c0], %cu: memref<?xindex>, vector<8xindex>
- %vv = vector.transfer_read %v[%c0], %fu: memref<?xf32>, vector<8xf32>
- vector.print %vp : vector<8xindex>
+ %vp = vector.transfer_read %p[%c0], %c0: memref<?xindex>, vector<2xindex>
+ %vi = vector.transfer_read %i[%c0], %c0: memref<?xindex>, vector<8xindex>
+ %vv = vector.transfer_read %v[%c0], %f0: memref<?xf32>, vector<8xf32>
+ vector.print %vp : vector<2xindex>
vector.print %vi : vector<8xindex>
vector.print %vv : vector<8xf32>
return
@@ -41,6 +40,8 @@ module {
func.func @entry() {
%f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32
+ %f3 = arith.constant 3.0 : f32
+ %f4 = arith.constant 4.0 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
@@ -51,13 +52,13 @@ module {
%0 = bufferization.alloc_tensor() : tensor<1024xf32, #SparseVector>
%1 = sparse_tensor.insert %f1 into %0[%c0] : tensor<1024xf32, #SparseVector>
%2 = sparse_tensor.insert %f2 into %1[%c1] : tensor<1024xf32, #SparseVector>
- %3 = sparse_tensor.insert %f1 into %2[%c3] : tensor<1024xf32, #SparseVector>
- %4 = sparse_tensor.insert %f2 into %3[%c1023] : tensor<1024xf32, #SparseVector>
+ %3 = sparse_tensor.insert %f3 into %2[%c3] : tensor<1024xf32, #SparseVector>
+ %4 = sparse_tensor.insert %f4 into %3[%c1023] : tensor<1024xf32, #SparseVector>
%5 = sparse_tensor.load %4 hasInserts : tensor<1024xf32, #SparseVector>
- // CHECK: ( 0, 4, 99, 99, 99, 99, 99, 99 )
- // CHECK-NEXT: ( 0, 1, 3, 1023, 99, 99, 99, 99 )
- // CHECK-NEXT: ( 1, 2, 1, 2, 99, 99, 99, 99 )
+ // CHECK: ( 0, 4 )
+ // CHECK-NEXT: ( 0, 1, 3, 1023
+ // CHECK-NEXT: ( 1, 2, 3, 4
call @dump(%5) : (tensor<1024xf32, #SparseVector>) -> ()
// Build another sparse vector in a loop.
@@ -69,7 +70,7 @@ module {
}
%8 = sparse_tensor.load %7 hasInserts : tensor<1024xf32, #SparseVector>
- // CHECK: ( 0, 8, 99, 99, 99, 99, 99, 99 )
+ // CHECK-NEXT: ( 0, 8 )
// CHECK-NEXT: ( 0, 3, 6, 9, 12, 15, 18, 21 )
// CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1 )
//
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_2d.mlir
new file mode 100644
index 0000000000000..7ad07969cdf0b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_2d.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#Dense = #sparse_tensor.encoding<{
+ dimLevelType = ["dense", "dense"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+#CSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ]
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#Row = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "dense" ]
+}>
+
+module {
+
+ func.func @dump_dense(%arg0: tensor<4x3xf64, #Dense>) {
+ %c0 = arith.constant 0 : index
+ %fu = arith.constant 99.0 : f64
+ %v = sparse_tensor.values %arg0 : tensor<4x3xf64, #Dense> to memref<?xf64>
+ %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<12xf64>
+ vector.print %vv : vector<12xf64>
+ return
+ }
+
+ func.func @dump_coo(%arg0: tensor<4x3xf64, #SortedCOO>) {
+ %c0 = arith.constant 0 : index
+ %cu = arith.constant -1 : index
+ %fu = arith.constant 99.0 : f64
+ %p0 = sparse_tensor.pointers %arg0 { dimension = 0 : index } : tensor<4x3xf64, #SortedCOO> to memref<?xindex>
+ %i0 = sparse_tensor.indices %arg0 { dimension = 0 : index } : tensor<4x3xf64, #SortedCOO> to memref<?xindex>
+ %i1 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<4x3xf64, #SortedCOO> to memref<?xindex>
+ %v = sparse_tensor.values %arg0 : tensor<4x3xf64, #SortedCOO> to memref<?xf64>
+ %vp0 = vector.transfer_read %p0[%c0], %cu: memref<?xindex>, vector<2xindex>
+ vector.print %vp0 : vector<2xindex>
+ %vi0 = vector.transfer_read %i0[%c0], %cu: memref<?xindex>, vector<4xindex>
+ vector.print %vi0 : vector<4xindex>
+ %vi1 = vector.transfer_read %i1[%c0], %cu: memref<?xindex>, vector<4xindex>
+ vector.print %vi1 : vector<4xindex>
+ %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<4xf64>
+ vector.print %vv : vector<4xf64>
+ return
+ }
+
+ func.func @dump_csr(%arg0: tensor<4x3xf64, #CSR>) {
+ %c0 = arith.constant 0 : index
+ %cu = arith.constant -1 : index
+ %fu = arith.constant 99.0 : f64
+ %p1 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor<4x3xf64, #CSR> to memref<?xindex>
+ %i1 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<4x3xf64, #CSR> to memref<?xindex>
+ %v = sparse_tensor.values %arg0 : tensor<4x3xf64, #CSR> to memref<?xf64>
+ %vp1 = vector.transfer_read %p1[%c0], %cu: memref<?xindex>, vector<5xindex>
+ vector.print %vp1 : vector<5xindex>
+ %vi1 = vector.transfer_read %i1[%c0], %cu: memref<?xindex>, vector<4xindex>
+ vector.print %vi1 : vector<4xindex>
+ %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<4xf64>
+ vector.print %vv : vector<4xf64>
+ return
+ }
+
+ func.func @dump_dcsr(%arg0: tensor<4x3xf64, #DCSR>) {
+ %c0 = arith.constant 0 : index
+ %cu = arith.constant -1 : index
+ %fu = arith.constant 99.0 : f64
+ %p0 = sparse_tensor.pointers %arg0 { dimension = 0 : index } : tensor<4x3xf64, #DCSR> to memref<?xindex>
+ %i0 = sparse_tensor.indices %arg0 { dimension = 0 : index } : tensor<4x3xf64, #DCSR> to memref<?xindex>
+ %p1 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor<4x3xf64, #DCSR> to memref<?xindex>
+ %i1 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<4x3xf64, #DCSR> to memref<?xindex>
+ %v = sparse_tensor.values %arg0 : tensor<4x3xf64, #DCSR> to memref<?xf64>
+ %vp0 = vector.transfer_read %p0[%c0], %cu: memref<?xindex>, vector<2xindex>
+ vector.print %vp0 : vector<2xindex>
+ %vi0 = vector.transfer_read %i0[%c0], %cu: memref<?xindex>, vector<3xindex>
+ vector.print %vi0 : vector<3xindex>
+ %vp1 = vector.transfer_read %p1[%c0], %cu: memref<?xindex>, vector<4xindex>
+ vector.print %vp1 : vector<4xindex>
+ %vi1 = vector.transfer_read %i1[%c0], %cu: memref<?xindex>, vector<4xindex>
+ vector.print %vi1 : vector<4xindex>
+ %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<4xf64>
+ vector.print %vv : vector<4xf64>
+ return
+ }
+
+ func.func @dump_row(%arg0: tensor<4x3xf64, #Row>) {
+ %c0 = arith.constant 0 : index
+ %cu = arith.constant -1 : index
+ %fu = arith.constant 99.0 : f64
+ %p0 = sparse_tensor.pointers %arg0 { dimension = 0 : index } : tensor<4x3xf64, #Row> to memref<?xindex>
+ %i0 = sparse_tensor.indices %arg0 { dimension = 0 : index } : tensor<4x3xf64, #Row> to memref<?xindex>
+ %v = sparse_tensor.values %arg0 : tensor<4x3xf64, #Row> to memref<?xf64>
+ %vp0 = vector.transfer_read %p0[%c0], %cu: memref<?xindex>, vector<2xindex>
+ vector.print %vp0 : vector<2xindex>
+ %vi0 = vector.transfer_read %i0[%c0], %cu: memref<?xindex>, vector<3xindex>
+ vector.print %vi0 : vector<3xindex>
+ %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<9xf64>
+ vector.print %vv : vector<9xf64>
+ return
+ }
+
+ //
+ // Main driver. We test the contents of various sparse tensor
+ // schemes when they are still empty and after a few insertions.
+ //
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %f1 = arith.constant 1.0 : f64
+ %f2 = arith.constant 2.0 : f64
+ %f3 = arith.constant 3.0 : f64
+ %f4 = arith.constant 4.0 : f64
+
+ //
+ // Dense case.
+ //
+ // CHECK: ( 1, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 4 )
+ //
+ %densea = bufferization.alloc_tensor() : tensor<4x3xf64, #Dense>
+ %dense1 = sparse_tensor.insert %f1 into %densea[%c0, %c0] : tensor<4x3xf64, #Dense>
+ %dense2 = sparse_tensor.insert %f2 into %dense1[%c2, %c2] : tensor<4x3xf64, #Dense>
+ %dense3 = sparse_tensor.insert %f3 into %dense2[%c3, %c0] : tensor<4x3xf64, #Dense>
+ %dense4 = sparse_tensor.insert %f4 into %dense3[%c3, %c2] : tensor<4x3xf64, #Dense>
+ %densem = sparse_tensor.load %dense4 hasInserts : tensor<4x3xf64, #Dense>
+ call @dump_dense(%densem) : (tensor<4x3xf64, #Dense>) -> ()
+
+ //
+ // COO case.
+ //
+ // CHECK-NEXT: ( 0, 4 )
+ // CHECK-NEXT: ( 0, 2, 3, 3 )
+ // CHECK-NEXT: ( 0, 2, 0, 2 )
+ // CHECK-NEXT: ( 1, 2, 3, 4 )
+ //
+ %cooa = bufferization.alloc_tensor() : tensor<4x3xf64, #SortedCOO>
+ %coo1 = sparse_tensor.insert %f1 into %cooa[%c0, %c0] : tensor<4x3xf64, #SortedCOO>
+ %coo2 = sparse_tensor.insert %f2 into %coo1[%c2, %c2] : tensor<4x3xf64, #SortedCOO>
+ %coo3 = sparse_tensor.insert %f3 into %coo2[%c3, %c0] : tensor<4x3xf64, #SortedCOO>
+ %coo4 = sparse_tensor.insert %f4 into %coo3[%c3, %c2] : tensor<4x3xf64, #SortedCOO>
+ %coom = sparse_tensor.load %coo4 hasInserts : tensor<4x3xf64, #SortedCOO>
+ call @dump_coo(%coom) : (tensor<4x3xf64, #SortedCOO>) -> ()
+
+ //
+ // CSR case.
+ //
+ // CHECK-NEXT: ( 0, 1, 1, 2, 4 )
+ // CHECK-NEXT: ( 0, 2, 0, 2 )
+ // CHECK-NEXT: ( 1, 2, 3, 4 )
+ //
+ %csra = bufferization.alloc_tensor() : tensor<4x3xf64, #CSR>
+ %csr1 = sparse_tensor.insert %f1 into %csra[%c0, %c0] : tensor<4x3xf64, #CSR>
+ %csr2 = sparse_tensor.insert %f2 into %csr1[%c2, %c2] : tensor<4x3xf64, #CSR>
+ %csr3 = sparse_tensor.insert %f3 into %csr2[%c3, %c0] : tensor<4x3xf64, #CSR>
+ %csr4 = sparse_tensor.insert %f4 into %csr3[%c3, %c2] : tensor<4x3xf64, #CSR>
+ %csrm = sparse_tensor.load %csr4 hasInserts : tensor<4x3xf64, #CSR>
+ call @dump_csr(%csrm) : (tensor<4x3xf64, #CSR>) -> ()
+
+ //
+ // DCSR case.
+ //
+ // CHECK-NEXT: ( 0, 3 )
+ // CHECK-NEXT: ( 0, 2, 3 )
+ // CHECK-NEXT: ( 0, 1, 2, 4 )
+ // CHECK-NEXT: ( 0, 2, 0, 2 )
+ // CHECK-NEXT: ( 1, 2, 3, 4 )
+ //
+ %dcsra = bufferization.alloc_tensor() : tensor<4x3xf64, #DCSR>
+ %dcsr1 = sparse_tensor.insert %f1 into %dcsra[%c0, %c0] : tensor<4x3xf64, #DCSR>
+ %dcsr2 = sparse_tensor.insert %f2 into %dcsr1[%c2, %c2] : tensor<4x3xf64, #DCSR>
+ %dcsr3 = sparse_tensor.insert %f3 into %dcsr2[%c3, %c0] : tensor<4x3xf64, #DCSR>
+ %dcsr4 = sparse_tensor.insert %f4 into %dcsr3[%c3, %c2] : tensor<4x3xf64, #DCSR>
+ %dcsrm = sparse_tensor.load %dcsr4 hasInserts : tensor<4x3xf64, #DCSR>
+ call @dump_dcsr(%dcsrm) : (tensor<4x3xf64, #DCSR>) -> ()
+
+ //
+ // Row case.
+ //
+ // CHECK-NEXT: ( 0, 3 )
+ // CHECK-NEXT: ( 0, 2, 3 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 2, 3, 0, 4 )
+ //
+ %rowa = bufferization.alloc_tensor() : tensor<4x3xf64, #Row>
+ %row1 = sparse_tensor.insert %f1 into %rowa[%c0, %c0] : tensor<4x3xf64, #Row>
+ %row2 = sparse_tensor.insert %f2 into %row1[%c2, %c2] : tensor<4x3xf64, #Row>
+ %row3 = sparse_tensor.insert %f3 into %row2[%c3, %c0] : tensor<4x3xf64, #Row>
+ %row4 = sparse_tensor.insert %f4 into %row3[%c3, %c2] : tensor<4x3xf64, #Row>
+ %rowm = sparse_tensor.load %row4 hasInserts : tensor<4x3xf64, #Row>
+ call @dump_row(%rowm) : (tensor<4x3xf64, #Row>) -> ()
+
+ //
+ // NOE sanity check.
+ //
+ // CHECK-NEXT: 12
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 9
+ //
+ %noe1 = sparse_tensor.number_of_entries %densem : tensor<4x3xf64, #Dense>
+ %noe2 = sparse_tensor.number_of_entries %coom : tensor<4x3xf64, #SortedCOO>
+ %noe3 = sparse_tensor.number_of_entries %csrm : tensor<4x3xf64, #CSR>
+ %noe4 = sparse_tensor.number_of_entries %dcsrm : tensor<4x3xf64, #DCSR>
+ %noe5 = sparse_tensor.number_of_entries %rowm : tensor<4x3xf64, #Row>
+ vector.print %noe1 : index
+ vector.print %noe2 : index
+ vector.print %noe3 : index
+ vector.print %noe4 : index
+ vector.print %noe5 : index
+
+ // Release resources.
+ bufferization.dealloc_tensor %densem : tensor<4x3xf64, #Dense>
+ bufferization.dealloc_tensor %coom : tensor<4x3xf64, #SortedCOO>
+ bufferization.dealloc_tensor %csrm : tensor<4x3xf64, #CSR>
+ bufferization.dealloc_tensor %dcsrm : tensor<4x3xf64, #DCSR>
+ bufferization.dealloc_tensor %rowm : tensor<4x3xf64, #Row>
+
+ return
+ }
+}
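To make the CSR expectations in the new 2D test concrete: inserting (0,0)=1, (2,2)=2, (3,0)=3 and (3,2)=4 into an empty 4x3 CSR tensor, in that order, yields pointers ( 0, 1, 1, 2, 4 ), indices ( 0, 2, 0, 2 ) and values ( 1, 2, 3, 4 ), exactly as the CHECK lines state; the empty second row simply inherits the end position of the first. A small standalone C++ sketch, not part of the commit, that rebuilds those three arrays from the same insertion sequence:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

int main() {
  const uint64_t rows = 4;
  // (row, col, value) triples in insertion order, already sorted by (row, col).
  const std::vector<std::tuple<uint64_t, uint64_t, double>> inserts = {
      {0, 0, 1.0}, {2, 2, 2.0}, {3, 0, 3.0}, {3, 2, 4.0}};

  std::vector<uint64_t> pointers(rows + 1, 0), indices;
  std::vector<double> values;
  for (auto [r, c, v] : inserts) {
    indices.push_back(c);             // column index of the stored entry
    values.push_back(v);              // stored value
    pointers[r + 1] = indices.size(); // end of row r so far
  }
  // Rows without entries inherit the previous row's end position.
  for (uint64_t r = 1; r <= rows; ++r)
    pointers[r] = std::max(pointers[r], pointers[r - 1]);

  auto dump = [](const auto &vec) {
    for (auto x : vec) std::cout << x << " ";
    std::cout << "\n";
  };
  dump(pointers); // 0 1 1 2 4
  dump(indices);  // 0 2 0 2
  dump(values);   // 1 2 3 4
}

The other schemes in the test differ only in which dimensions keep explicit pointer/index arrays, which is why the COO, DCSR and Row dumps print a different set of buffers for the same four entries.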