[Mlir-commits] [mlir] 988733c - [mlir][sparse] use sparse_tensor::StorageSpecifier to store dim/memSizes
Peiming Liu
llvmlistbot at llvm.org
Thu Dec 22 16:47:42 PST 2022
Author: Peiming Liu
Date: 2022-12-23T00:47:36Z
New Revision: 988733c60037c61ca49233c356c0f928a5ac14bb
URL: https://github.com/llvm/llvm-project/commit/988733c60037c61ca49233c356c0f928a5ac14bb
DIFF: https://github.com/llvm/llvm-project/commit/988733c60037c61ca49233c356c0f928a5ac14bb.diff
LOG: [mlir][sparse] use sparse_tensor::StorageSpecifier to store dim/memSizes
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140130
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 13c6e033c13a..771e97fb006f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -343,18 +343,20 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
"inBuffer", "value",
"$_self.cast<ShapedType>().getElementType()">,
AllTypesMatch<["inBuffer", "outBuffer"]>]>,
- Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
+ Arguments<(ins Index:$curSize,
StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
- AnyType:$value, IndexAttr:$idx, Optional<Index>:$n,
+ AnyType:$value, Optional<Index>:$n,
UnitAttr:$inbounds)>,
- Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> {
+ Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer,
+ Index:$newSize)> {
string summary = "Pushes a value to the back of a given buffer";
string description = [{
- Push `value` to the end of the given sparse tensor storage buffer
- `inBuffer` and update the size of the buffer in `bufferSizes[idx]`. The
- capacity of the buffer is recorded in the memref type of `inBuffer `. If the
- current buffer is full, then `inBuffer.realloc` is called before pushing the
- data to the buffer. This is similar to std::vector push_back.
+ Pushes `value` to the end of the given sparse tensor storage buffer
+ `inBuffer` as indicated by the value of `curSize` and returns the
+ new size of the buffer in `newSize` (`newSize = curSize + n`).
+ The capacity of the buffer is recorded in the memref type of `inBuffer`.
+ If the current buffer is full, then `inBuffer.realloc` is called before
+ pushing the data to the buffer. This is similar to std::vector push_back.
The optional input `n` specifies the number of times to repeatedly push
the value to the back of the tensor. When `n` is a compile-time constant,
@@ -376,29 +378,28 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
Example:
```mlir
- %r = sparse_tensor.push_back %bufferSizes, %buffer, %val
- {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
+ %buf, %newSize = sparse_tensor.push_back %curSize, %buffer, %val
+ : index, memref<?xf64>, f64
```
```mlir
- %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
- {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
+ %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val
+ : index, memref<?xf64>, f64
```
```mlir
- %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val, %n
- {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
+ %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val, %n
+ : index, memref<?xf64>, f64
```
}];
- let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
+ let assemblyFormat = "(`inbounds` $inbounds^)? $curSize `,` $inBuffer"
" `,` $value (`,` $n^ )? attr-dict `:`"
- " type($bufferSizes) `,` type($inBuffer) `,`"
- " type($value) (`,` type($n)^ )?";
+ " type($curSize) `,` type($inBuffer) `,`"
+ " type($value) (`,` type($n)^ )?";
let builders = [
- //Build an op without input `n`.
- OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer,
- "Value":$value, "APInt":$idx)>
+ // Build an op (reusing type from curSize and inBuffer) without input `n`
+ OpBuilder<(ins "Value":$curSize, "Value":$inBuffer, "Value":$value)>
];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1e9aab8deb17..f28abee047fd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -694,10 +694,8 @@ LogicalResult InsertOp::verify() {
}
void PushBackOp::build(OpBuilder &builder, OperationState &result,
- Type outBuffer, Value bufferSizes, Value inBuffer,
- Value value, APInt idx) {
- build(builder, result, outBuffer, bufferSizes, inBuffer, value,
- std::move(idx), Value());
+ Value curSize, Value inBuffer, Value value) {
+ build(builder, result, curSize, inBuffer, value, Value());
}
LogicalResult PushBackOp::verify() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 9fd74f7d3001..e3ab5ce1d040 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -90,116 +90,6 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
return val;
}
-void sparse_tensor::foreachFieldInSparseTensor(
- const SparseTensorEncodingAttr enc,
- llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
- DimLevelType)>
- callback) {
- assert(enc);
-
-#define RETURN_ON_FALSE(idx, kind, dim, dlt) \
- if (!(callback(idx, kind, dim, dlt))) \
- return;
-
- RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u,
- DimLevelType::Undef);
- RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u,
- DimLevelType::Undef);
-
- static_assert(dataFieldIdx == memSizesIdx + 1);
- unsigned fieldIdx = dataFieldIdx;
- // Per-dimension storage.
- for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
- // Dimension level types apply in order to the reordered dimension.
- // As a result, the compound type can be constructed directly in the given
- // order.
- auto dlt = getDimLevelType(enc, r);
- if (isCompressedDLT(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
- } else if (isSingletonDLT(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
- } else {
- assert(isDenseDLT(dlt)); // no fields
- }
- }
- // The values array.
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
- DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
-}
-
-void sparse_tensor::foreachFieldAndTypeInSparseTensor(
- RankedTensorType rType,
- llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
- DimLevelType)>
- callback) {
- auto enc = getSparseTensorEncoding(rType);
- assert(enc);
- // Construct the basic types.
- Type indexType = IndexType::get(enc.getContext());
- Type idxType = enc.getIndexType();
- Type ptrType = enc.getPointerType();
- Type eltType = rType.getElementType();
- unsigned rank = rType.getShape().size();
- // memref<rank x index> dimSizes
- Type dimSizeType = MemRefType::get({rank}, indexType);
- // memref<n x index> memSizes
- Type memSizeType =
- MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType);
- // memref<? x ptr> pointers
- Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
- // memref<? x idx> indices
- Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
- // memref<? x eltType> values
- Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
-
- foreachFieldInSparseTensor(
- enc,
- [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType,
- callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
- unsigned dim, DimLevelType dlt) -> bool {
- switch (fieldKind) {
- case SparseTensorFieldKind::DimSizes:
- return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::MemSizes:
- return callback(memSizeType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::PtrMemRef:
- return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::IdxMemRef:
- return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::ValMemRef:
- return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
- };
- llvm_unreachable("unrecognized field kind");
- });
-}
-
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- unsigned numFields = 0;
- foreachFieldInSparseTensor(enc,
- [&numFields](unsigned, SparseTensorFieldKind,
- unsigned, DimLevelType) -> bool {
- numFields++;
- return true;
- });
- return numFields;
-}
-
-unsigned
-sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- unsigned numFields = 0; // one value memref
- foreachFieldInSparseTensor(enc,
- [&numFields](unsigned fidx, SparseTensorFieldKind,
- unsigned, DimLevelType) -> bool {
- if (fidx >= dataFieldIdx)
- numFields++;
- return true;
- });
- assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx);
- return numFields;
-}
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 6e4ea83e1ba5..46f02141bb82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -313,220 +313,6 @@ inline bool isZeroRankedTensorOrScalar(Type type) {
return !rtp || rtp.getRank() == 0;
}
-//===----------------------------------------------------------------------===//
-// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout
-// scheme.
-//
-// Sparse tensor storage scheme for rank-dimensional tensor is organized
-// as a single compound type with the following fields. Note that every
-// memref with ? size actually behaves as a "vector", i.e. the stored
-// size is the capacity and the used size resides in the memSizes array.
-//
-// struct {
-// memref<rank x index> dimSizes ; size in each dimension
-// memref<n x index> memSizes ; sizes of ptrs/inds/values
-// ; per-dimension d:
-// ; if dense:
-// <nothing>
-// ; if compresed:
-// memref<? x ptr> pointers-d ; pointers for sparse dim d
-// memref<? x idx> indices-d ; indices for sparse dim d
-// ; if singleton:
-// memref<? x idx> indices-d ; indices for singleton dim d
-// memref<? x eltType> values ; values
-// };
-//
-//===----------------------------------------------------------------------===//
-enum class SparseTensorFieldKind {
- DimSizes,
- MemSizes,
- PtrMemRef,
- IdxMemRef,
- ValMemRef
-};
-
-constexpr uint64_t dimSizesIdx = 0;
-constexpr uint64_t memSizesIdx = dimSizesIdx + 1;
-constexpr uint64_t dataFieldIdx = memSizesIdx + 1;
-
-/// For each field that will be allocated for the given sparse tensor encoding,
-/// calls the callback with the corresponding field index, field kind, dimension
-/// (for sparse tensor level memrefs) and dimlevelType.
-/// The field index always starts with zero and increments by one between two
-/// callback invocations.
-/// Ideally, all other methods should rely on this function to query a sparse
-/// tensor fields instead of relying on ad-hoc index computation.
-void foreachFieldInSparseTensor(
- SparseTensorEncodingAttr,
- llvm::function_ref<bool(unsigned /*fieldIdx*/,
- SparseTensorFieldKind /*fieldKind*/,
- unsigned /*dim (if applicable)*/,
- DimLevelType /*DLT (if applicable)*/)>);
-
-/// Same as above, except that it also builds the Type for the corresponding
-/// field.
-void foreachFieldAndTypeInSparseTensor(
- RankedTensorType,
- llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
- SparseTensorFieldKind /*fieldKind*/,
- unsigned /*dim (if applicable)*/,
- DimLevelType /*DLT (if applicable)*/)>);
-
-/// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the total number of data fields (index arrays, pointer arrays, and a
-/// value array) for the given sparse tensor encoding.
-unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Get the index of the field in memSizes (only valid for data fields).
-inline unsigned getFieldMemSizesIndex(unsigned fid) {
- assert(fid >= dataFieldIdx);
- return fid - dataFieldIdx;
-}
-
-template <bool>
-struct SparseTensorValueArrayRef;
-
-// Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
-// for mutable descriptors.
-template <>
-struct SparseTensorValueArrayRef<false> {
- using ValueArray = ValueRange;
-};
-
-// Using SmallVector for mutable descriptor allows users to reuse it as a tmp
-// buffers to append value for some special cases, though users should be
-// responsible to restore the buffer to legal states after their use. It is
-// probably not a clean way, but it is the most efficient way to avoid copying
-// the fields into another SmallVector. If a more clear way is wanted, we
-// should change it to MutableArrayRef instead.
-template <>
-struct SparseTensorValueArrayRef<true> {
- using ValueArray = SmallVectorImpl<Value> &;
-};
-
-/// A helper class around an array of values that corresponding to a sparse
-/// tensor, provides a set of meaningful APIs to query and update a particular
-/// field in a consistent way.
-/// Users should not make assumption on how a sparse tensor is laid out but
-/// instead relies on this class to access the right value for the right field.
-template <bool mut>
-class SparseTensorDescriptorImpl {
-private:
- using Storage = typename SparseTensorValueArrayRef<mut>::ValueArray;
-
-public:
- SparseTensorDescriptorImpl(Type tp, Storage fields)
- : rType(tp.cast<RankedTensorType>()), fields(fields) {
- assert(getSparseTensorEncoding(tp) &&
- getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
- fields.size());
- // We should make sure the class is trivially copyable (and should be small
- // enough) such that we can pass it by value.
- static_assert(
- std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
- }
-
- // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
- // SparseTensorDescriptor.
- template <typename T = SparseTensorDescriptorImpl<true>>
- /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
- : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
-
- ///
- /// Getters: get the field index for required field.
- ///
-
- unsigned getPtrMemRefIndex(unsigned ptrDim) const {
- return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef);
- }
-
- unsigned getIdxMemRefIndex(unsigned idxDim) const {
- return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef);
- }
-
- unsigned getValMemRefIndex() const { return fields.size() - 1; }
-
- unsigned getPtrMemSizesIndex(unsigned dim) const {
- return getPtrMemRefIndex(dim) - dataFieldIdx;
- }
-
- unsigned getIdxMemSizesIndex(unsigned dim) const {
- return getIdxMemRefIndex(dim) - dataFieldIdx;
- }
-
- unsigned getValMemSizesIndex() const {
- return getValMemRefIndex() - dataFieldIdx;
- }
-
- unsigned getNumFields() const { return fields.size(); }
-
- ///
- /// Getters: get the value for required field.
- ///
-
- Value getDimSizesMemRef() const { return fields[dimSizesIdx]; }
- Value getMemSizesMemRef() const { return fields[memSizesIdx]; }
-
- Value getPtrMemRef(unsigned ptrDim) const {
- return fields[getPtrMemRefIndex(ptrDim)];
- }
-
- Value getIdxMemRef(unsigned idxDim) const {
- return fields[getIdxMemRefIndex(idxDim)];
- }
-
- Value getValMemRef() const { return fields[getValMemRefIndex()]; }
-
- Value getField(unsigned fid) const {
- assert(fid < fields.size());
- return fields[fid];
- }
-
- ///
- /// Setters: update the value for required field (only enabled for
- /// MutSparseTensorDescriptor).
- ///
-
- template <typename T = Value>
- void setField(unsigned fid, std::enable_if_t<mut, T> v) {
- assert(fid < fields.size());
- fields[fid] = v;
- }
-
- RankedTensorType getTensorType() const { return rType; }
- Storage getFields() const { return fields; }
-
- Type getElementType(unsigned fidx) const {
- return fields[fidx].getType().template cast<MemRefType>().getElementType();
- }
-
-private:
- unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
- unsigned fieldIdx = -1u;
- foreachFieldInSparseTensor(
- getSparseTensorEncoding(rType),
- [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
- unsigned fDim, DimLevelType dlt) -> bool {
- if (fDim == dim && kind == fKind) {
- fieldIdx = fIdx;
- // Returns false to break the iteration.
- return false;
- }
- return true;
- });
- assert(fieldIdx != -1u);
- return fieldIdx;
- }
-
- RankedTensorType rType;
- Storage fields;
-};
-
-using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
-
//===----------------------------------------------------------------------===//
// SparseTensorLoopEmiter class, manages sparse tensors and helps to
// generate loop structure to (co)-iterate sparse tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index ded1e653fb5f..fc9476cd2b65 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -331,7 +331,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value p = args[hiIdx];
- SmallVector<Type, 2> types(2, p.getType()); // only two
+ SmallVector<Type, 2> types(2, p.getType()); // only two
scf::WhileOp whileOp = builder.create<scf::WhileOp>(
loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
@@ -490,7 +490,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
- SmallVector<Value, 3> operands{i, j, p}; // exactly three
+ SmallVector<Value, 3> operands{i, j, p}; // exactly three
SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
@@ -770,9 +770,7 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
Value c0 = constantIndex(rewriter, loc, 0);
Value buffer = op.getInBuffer();
Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
- Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
- Value bufferSizes = op.getBufferSizes();
- Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+ Value size = op.getCurSize();
Value value = op.getValue();
Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
@@ -852,8 +850,7 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
}
// Update the buffer size.
- rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
- rewriter.replaceOp(op, buffer);
+ rewriter.replaceOp(op, {buffer, newSize});
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index d6a9007baad0..5843aef0b7a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -118,7 +118,7 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
op.getDim().value().getZExtValue());
} else {
auto enc = op.getSpecifier().getType().getEncoding();
- builder::StorageLayout layout(enc);
+ StorageLayout layout(enc);
Optional<unsigned> dim = std::nullopt;
if (op.getDim())
dim = op.getDim().value().getZExtValue();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6b406d8241bc..710d6cec31dd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -16,6 +16,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -40,38 +41,6 @@ static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Returns the "tuple" value of the adapted tensor.
-static UnrealizedConversionCastOp getTuple(Value tensor) {
- return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
-}
-
-static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
- auto tuple = getTuple(tensor);
- return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
-}
-
-static MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
- auto tuple = getTuple(tensor);
- fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
- return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
-}
-
-/// Packs the given values as a "tuple" value.
-static Value genTuple(OpBuilder &builder, Location loc, Type tp,
- ValueRange values) {
- return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
- .getResult(0);
-}
-
-static Value genTuple(OpBuilder &builder, Location loc,
- SparseTensorDescriptor desc) {
- return builder
- .create<UnrealizedConversionCastOp>(loc, desc.getTensorType(),
- desc.getFields())
- .getResult(0);
-}
-
/// Flatten a list of operands that may contain sparse tensors.
static void flattenOperands(ValueRange operands,
SmallVectorImpl<Value> &flattened) {
@@ -146,9 +115,7 @@ static std::optional<Value> sizeFromTensorAtDim(OpBuilder &builder,
// Any other query can consult the dimSizes array at field DimSizesIdx,
// accounting for the reordering applied to the sparse storage.
- Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim));
- return builder.create<memref::LoadOp>(loc, desc.getDimSizesMemRef(), idx)
- .getResult();
+ return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
}
// Gets the dimension size at the given stored dimension 'd', either as a
@@ -161,40 +128,24 @@ Value sizeAtStoredDim(OpBuilder &builder, Location loc,
if (!ShapedType::isDynamic(shape[dim]))
return constantIndex(builder, loc, shape[dim]);
- return genLoad(builder, loc, desc.getDimSizesMemRef(),
- constantIndex(builder, loc, d));
+ return desc.getDimSize(builder, loc, d);
}
static void createPushback(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc, unsigned fidx,
+ MutSparseTensorDescriptor desc,
+ SparseTensorFieldKind kind, Optional<unsigned> dim,
Value value, Value repeat = Value()) {
- Type etp = desc.getElementType(fidx);
- Value field = desc.getField(fidx);
- Value newField = builder.create<PushBackOp>(
- loc, field.getType(), desc.getMemSizesMemRef(), field,
- toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)),
- repeat);
- desc.setField(fidx, newField);
-}
+ Type etp = desc.getMemRefElementType(kind, dim);
+ Value field = desc.getMemRefField(kind, dim);
+ StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
-/// Maps a sparse tensor type to the appropriate compounded buffers.
-static std::optional<LogicalResult>
-convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
- auto enc = getSparseTensorEncoding(type);
- if (!enc)
- return std::nullopt;
+ auto pushBackOp = builder.create<PushBackOp>(
+ loc, desc.getSpecifierField(builder, loc, specFieldKind, dim), field,
+ toType(builder, loc, value, etp), repeat);
- RankedTensorType rType = type.cast<RankedTensorType>();
- foreachFieldAndTypeInSparseTensor(
- rType,
- [&fields](Type fieldType, unsigned fieldIdx,
- SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
- DimLevelType /*dlt*/) -> bool {
- assert(fieldIdx == fields.size());
- fields.push_back(fieldType);
- return true;
- });
- return success();
+ desc.setMemRefField(kind, dim, pushBackOp.getOutBuffer());
+ desc.setSpecifierField(builder, loc, specFieldKind, dim,
+ pushBackOp.getNewSize());
}
/// Generates code that allocates a sparse storage scheme for given rank.
@@ -210,8 +161,8 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// the desired "linear + 1" length property at all times.
Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
Value ptrZero = constantZero(builder, loc, ptrType);
- createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero,
- linear);
+ createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r,
+ ptrZero, linear);
return;
}
if (isSingletonDim(rtp, r)) {
@@ -226,7 +177,8 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
}
// Reached values array so prepare for an insertion.
Value valZero = constantZero(builder, loc, rtp.getElementType());
- createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear);
+ createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
+ std::nullopt, valZero, linear);
}
/// Creates allocation operation.
@@ -257,22 +209,20 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
foreachFieldAndTypeInSparseTensor(
rtp,
- [&builder, &fields, loc, heuristic,
+ [&builder, &fields, rtp, loc, heuristic,
enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
assert(fields.size() == fIdx);
- auto memRefTp = fType.cast<MemRefType>();
Value field;
switch (fKind) {
- case SparseTensorFieldKind::DimSizes:
- case SparseTensorFieldKind::MemSizes:
- field = builder.create<memref::AllocOp>(loc, memRefTp);
+ case SparseTensorFieldKind::StorageSpec:
+ field = SparseTensorSpecifier::getInitValue(builder, loc, rtp);
break;
case SparseTensorFieldKind::PtrMemRef:
case SparseTensorFieldKind::IdxMemRef:
case SparseTensorFieldKind::ValMemRef:
- field =
- createAllocation(builder, loc, memRefTp, heuristic, enableInit);
+ field = createAllocation(builder, loc, fType.cast<MemRefType>(),
+ heuristic, enableInit);
break;
}
assert(field);
@@ -297,21 +247,18 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
// to all zeros, sets the dimSizes to known values and gives all pointer
// fields an initial zero entry, so that it is easier to maintain the
// "linear + 1" length property.
- builder.create<linalg::FillOp>(
- loc, constantZero(builder, loc, builder.getIndexType()),
- desc.getMemSizesMemRef()); // zero memSizes
-
Value ptrZero =
constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType());
for (unsigned r = 0; r < rank; r++) {
unsigned ro = toOrigDim(rtp, r);
// Fills dim sizes array.
- genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(),
- constantIndex(builder, loc, r));
+ desc.setDimSize(builder, loc, r, sizes[ro]);
// Pushes a leading zero to pointers memref.
- if (isCompressedDim(rtp, r))
- createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero);
+ if (isCompressedDim(rtp, r)) {
+ createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r,
+ ptrZero);
+ }
}
allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}
@@ -349,10 +296,11 @@ static Value genCompressed(OpBuilder &builder, Location loc,
unsigned ptrIndex = desc.getPtrMemRefIndex(d);
Value one = constantIndex(builder, loc, 1);
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
- Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos);
- Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1);
- Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex));
- Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz);
+ Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
+ Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
+ Value msz = desc.getIdxMemSize(builder, loc, d);
+ // Value msz = desc.getMemSize(builder, loc, getFieldMemSizesIndex(idxIndex));
+
Value phim1 = builder.create<arith::SubIOp>(
loc, toType(builder, loc, phi, indexType), one);
// Conditional expression.
@@ -362,14 +310,14 @@ static Value genCompressed(OpBuilder &builder, Location loc,
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
types.pop_back();
builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
- Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1);
+ Value crd = genLoad(builder, loc, desc.getMemRefField(idxIndex), phim1);
Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
toType(builder, loc, crd, indexType),
indices[d]);
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (d > 0)
- genStore(builder, loc, msz, desc.getField(ptrIndex), pos);
+ genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos);
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
builder.setInsertionPointAfter(ifOp1);
Value p = ifOp1.getResult(0);
@@ -396,8 +344,9 @@ static Value genCompressed(OpBuilder &builder, Location loc,
// If !present (changes fields, update next).
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
- genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1);
- createPushback(builder, loc, desc, idxIndex, indices[d]);
+ genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1);
+ createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
+ indices[d]);
// Prepare the next dimension "as needed".
if ((d + 1) < rank)
allocSchemeForRank(builder, loc, desc, d + 1);
@@ -459,7 +408,8 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
// indices[d].push_back(i[d])
// pos[d] = pos[d-1]
// <insert @ pos[d] at next dimension d + 1>
- createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]);
+ createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
+ indices[d]);
} else {
assert(isDenseDim(rtp, d));
// Construct the new position as:
@@ -472,7 +422,8 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
}
// Reached the actual value append/insert.
if (!isDenseDim(rtp, rank - 1))
- createPushback(builder, loc, desc, desc.getValMemRefIndex(), value);
+ createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
+ std::nullopt, value);
else
genStore(builder, loc, value, desc.getValMemRef(), pos);
builder.create<func::ReturnOp>(loc, fields);
@@ -565,8 +516,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
if (d > 0) {
Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
Value ptrMemRef = desc.getPtrMemRef(d);
- Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d));
- Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz);
+ Value hi = desc.getPtrMemSize(builder, loc, d);
Value zero = constantIndex(builder, loc, 0);
Value one = constantIndex(builder, loc, 1);
// Vector of only one, but needed by createFor's prototype.
@@ -723,6 +673,7 @@ class SparseTensorAllocConverter
bool enableInit)
: OpConversionPattern(typeConverter, context),
enableBufferInitialization(enableInit) {}
+
LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -761,8 +712,8 @@ class SparseTensorDeallocConverter
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
- auto tuple = getTuple(adaptor.getTensor());
- for (auto input : tuple.getInputs())
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
@@ -1018,36 +969,13 @@ class SparseNumberOfEntriesConverter
ConversionPatternRewriter &rewriter) const override {
// Query memSizes for the actually stored values size.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- Value field =
- constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex());
- rewriter.replaceOpWithNewOp<memref::LoadOp>(op, desc.getMemSizesMemRef(),
- field);
+ rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
return success();
}
};
} // namespace
-//===----------------------------------------------------------------------===//
-// Sparse tensor type conversion into an actual buffer.
-//===----------------------------------------------------------------------===//
-
-mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
- addConversion([](Type type) { return type; });
- addConversion(convertSparseTensorType);
-
- // Required by scf.for 1:N type conversion.
- addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
- if (!getSparseTensorEncoding(tp))
- // Not a sparse tensor.
- return std::nullopt;
- // Sparse compiler knows how to cancel out these casts.
- return genTuple(builder, loc, tp, inputs);
- });
-}
-
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index 4dcd0345ed9d..fd0126bd555a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -22,15 +22,51 @@ static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
return value;
}
-static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional<unsigned> dim) {
+static IntegerAttr fromOptionalInt(MLIRContext *ctx,
+ std::optional<unsigned> dim) {
if (!dim)
return nullptr;
return IntegerAttr::get(IndexType::get(ctx), dim.value());
}
-unsigned
-builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
- Optional<unsigned> dim) const {
+static std::optional<LogicalResult>
+convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
+ auto enc = getSparseTensorEncoding(rtp);
+ if (!enc)
+ return std::nullopt;
+
+ foreachFieldAndTypeInSparseTensor(
+ rtp,
+ [&fields](Type fieldType, unsigned fieldIdx,
+ SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
+ DimLevelType /*dlt*/) -> bool {
+ assert(fieldIdx == fields.size());
+ fields.push_back(fieldType);
+ return true;
+ });
+ return success();
+}
+
+SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
+ addConversion([](Type type) { return type; });
+ addConversion([&](RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
+ return convertSparseTensorType(rtp, fields);
+ });
+
+ // Required by scf.for 1:N type conversion.
+ addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (!getSparseTensorEncoding(tp))
+ // Not a sparse tensor.
+ return std::nullopt;
+ // Sparse compiler knows how to cancel out these casts.
+ return genTuple(builder, loc, tp, inputs);
+ });
+}
+
+unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
+ std::optional<unsigned> dim) const {
unsigned fieldIdx = -1u;
foreachFieldInSparseTensor(
enc,
@@ -48,22 +84,20 @@ builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
return fieldIdx;
}
-unsigned
-builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
- Optional<unsigned> dim) const {
+unsigned StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
+ std::optional<unsigned> dim) const {
return getMemRefFieldIndex(toFieldKind(kind), dim);
}
-Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder,
- Location loc,
- RankedTensorType rtp) {
+Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
+ RankedTensorType rtp) {
return builder.create<StorageSpecifierInitOp>(
loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp)));
}
-Value builder::SparseTensorSpecifier::getSpecifierField(
- OpBuilder &builder, Location loc, StorageSpecifierKind kind,
- Optional<unsigned> dim) {
+Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
+ StorageSpecifierKind kind,
+ std::optional<unsigned> dim) {
return createIndexCast(builder, loc,
builder.create<GetStorageSpecifierOp>(
loc, getFieldType(kind, dim), specifier, kind,
@@ -71,9 +105,10 @@ Value builder::SparseTensorSpecifier::getSpecifierField(
builder.getIndexType());
}
-void builder::SparseTensorSpecifier::setSpecifierField(
- OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind,
- Optional<unsigned> dim) {
+void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
+ Value v,
+ StorageSpecifierKind kind,
+ std::optional<unsigned> dim) {
specifier = builder.create<SetStorageSpecifierOp>(
loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
createIndexCast(builder, loc, v, getFieldType(kind, dim)));
@@ -81,7 +116,7 @@ void builder::SparseTensorSpecifier::setSpecifierField(
constexpr uint64_t kDataFieldStartingIdx = 0;
-void sparse_tensor::builder::foreachFieldInSparseTensor(
+void sparse_tensor::foreachFieldInSparseTensor(
const SparseTensorEncodingAttr enc,
llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
DimLevelType)>
@@ -120,7 +155,7 @@ void sparse_tensor::builder::foreachFieldInSparseTensor(
#undef RETURN_ON_FALSE
}
-void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
+void sparse_tensor::foreachFieldAndTypeInSparseTensor(
RankedTensorType rType,
llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
DimLevelType)>
@@ -159,8 +194,7 @@ void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
});
}
-unsigned
-sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
unsigned numFields = 0;
foreachFieldInSparseTensor(enc,
[&numFields](unsigned, SparseTensorFieldKind,
@@ -171,8 +205,8 @@ sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
return numFields;
}
-unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
- SparseTensorEncodingAttr enc) {
+unsigned
+sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
unsigned numFields = 0; // one value memref
foreachFieldInSparseTensor(enc,
[&numFields](unsigned fidx, SparseTensorFieldKind,
@@ -183,6 +217,6 @@ unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
});
numFields -= 1; // the last field is MetaData field
assert(numFields ==
- builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
+ getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
return numFields;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 9b4e2352b8f3..d94aa1f098b6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -22,8 +22,6 @@
namespace mlir {
namespace sparse_tensor {
-// FIXME: this is a tmp namespace
-namespace builder {
//===----------------------------------------------------------------------===//
// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout
// scheme.
@@ -171,7 +169,7 @@ class SparseTensorDescriptorImpl {
SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
: rType(tp.cast<RankedTensorType>()), fields(fields) {
assert(getSparseTensorEncoding(tp) &&
- builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+ getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
fields.size());
// We should make sure the class is trivially copyable (and should be small
// enough) such that we can pass it by value.
@@ -355,7 +353,6 @@ getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
}
-} // namespace builder
} // namespace sparse_tensor
} // namespace mlir
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 04c850f5ea83..7e10ae17a1c4 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,15 +1,14 @@
// RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
// CHECK-LABEL: func @sparse_push_back(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref<?xf64>, index) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index
+// CHECK: %[[S2:.*]] = arith.addi %[[A]], %[[C1]] : index
// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
// CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]]
@@ -18,25 +17,23 @@
// CHECK: } else {
// CHECK: scf.yield %[[B]] : memref<?xf64>
// CHECK: }
-// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]]
-// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[M]] : memref<?xf64>
-func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
- return %0 : memref<?xf64>
+// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[A]]]
+// CHECK: return %[[M]], %[[S2]]
+func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (memref<?xf64>, index) {
+ %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f64
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
// CHECK-LABEL: func @sparse_push_back_n(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[S1:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64,
-// CHECK-SAME: %[[D:.*]]: index) -> memref<?xf64> {
+// CHECK-SAME: %[[D:.*]]: index) -> (memref<?xf64>, index) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index
// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
@@ -55,29 +52,25 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
// CHECK: }
// CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1]
// CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]]
-// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[M]] : memref<?xf64>
-func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
- return %0 : memref<?xf64>
+// CHECK: return %[[M]], %[[S2]] : memref<?xf64>, index
+func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> (memref<?xf64>, index) {
+ %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref<?xf64>, f64, index
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
// CHECK-LABEL: func @sparse_push_back_inbound(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[S1:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref<?xf64>, index) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]]
// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]]
-// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[B]] : memref<?xf64>
-func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
- return %0 : memref<?xf64>
+// CHECK: return %[[B]], %[[S2]] : memref<?xf64>, index
+func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (memref<?xf64>, index) {
+ %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref<?xf64>, f64
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 0e3eb03bda78..b60935c2913c 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -47,31 +47,28 @@
}>
// CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] :
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
return %arg0 : tensor<?xf64, #SparseVector>
}
// CHECK-LABEL: func @sparse_nop_multi_ret(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
-// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
-// CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
-// CHECK-SAME: %[[A9:.*9]]: memref<?xf64>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] :
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+// CHECK-SAME: %[[A7:.*7]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]] :
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
%arg1: tensor<?xf64, #SparseVector>) ->
(tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -79,20 +76,18 @@ func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
}
// CHECK-LABEL: func @sparse_nop_call(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
-// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
-// CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
-// CHECK-SAME: %[[A9:.*9]]: memref<?xf64>)
-// CHECK: %[[T:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]])
-// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9 :
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+// CHECK-SAME: %[[A7:.*7]]: !sparse_tensor.storage_specifier
+// CHECK: %[[T:.*]]:8 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]])
+// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7 :
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
%arg1: tensor<?xf64, #SparseVector>) ->
(tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -103,68 +98,61 @@ func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
}
// CHECK-LABEL: func @sparse_nop_cast(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] :
func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
}
// CHECK-LABEL: func @sparse_nop_cast_3d(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]] :
-// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf32>
+// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A0]], %[[A1]] :
+// CHECK-SAME: memref<?xf32>, !sparse_tensor.storage_specifier
func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
%0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
return %0 : tensor<?x?x?xf32, #Dense3D>
}
// CHECK-LABEL: func @sparse_dense_2d(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
+// CHECK-SAME: %[[A0:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier
// CHECK: return
func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
return
}
// CHECK-LABEL: func @sparse_row(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
// CHECK: return
func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
return
}
// CHECK-LABEL: func @sparse_csr(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
// CHECK: return
func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
}
// CHECK-LABEL: func @sparse_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return
func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
return
@@ -175,9 +163,8 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
// fold using the original static dimension sizes.
//
// CHECK-LABEL: func @sparse_dense_3d(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
+// CHECK-SAME: %[[A0:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier
// CHECK: %[[C:.*]] = arith.constant 20 : index
// CHECK: return %[[C]] : index
func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -192,12 +179,11 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
// since the latter honors the dimOrdering.
//
// CHECK-LABEL: func @sparse_dense_3d_dyn(
-// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
-// CHECK: %[[C:.*]] = arith.constant 2 : index
-// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
-// CHECK: return %[[L]] : index
+// CHECK-SAME: %[[A0:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier
+// CHECK: %[[A2:.*]] = sparse_tensor.storage_specifier.get %[[A1]] dim_sz at 2
+// CHECK: %[[A3:.*]] = arith.index_cast %[[A2]] : i64 to index
+// CHECK: return %[[A3]] : index
func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
@@ -205,55 +191,51 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
}
// CHECK-LABEL: func @sparse_pointers_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
-// CHECK: return %[[A4]] : memref<?xi32>
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A2]] : memref<?xi32>
func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
%0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
return %0 : memref<?xi32>
}
// CHECK-LABEL: func @sparse_indices_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
-// CHECK: return %[[A5]] : memref<?xi64>
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A3]] : memref<?xi64>
func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
%0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
return %0 : memref<?xi64>
}
// CHECK-LABEL: func @sparse_values_dcsr(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
-// CHECK: return %[[A6]] : memref<?xf64>
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A4]] : memref<?xf64>
func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
return %0 : memref<?xf64>
}
// CHECK-LABEL: func @sparse_noe(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
+// CHECK: %[[A4:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz
+// CHECK: %[[NOE:.*]] = arith.index_cast %[[A4]] : i64 to index
// CHECK: return %[[NOE]] : index
func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
%0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
@@ -261,70 +243,66 @@ func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
}
// CHECK-LABEL: func @sparse_dealloc_csr(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
-// CHECK: memref.dealloc %[[A0]] : memref<2xindex>
-// CHECK: memref.dealloc %[[A1]] : memref<3xindex>
-// CHECK: memref.dealloc %[[A2]] : memref<?xi32>
-// CHECK: memref.dealloc %[[A3]] : memref<?xi64>
-// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
+// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
+// CHECK: memref.dealloc %[[A0]] : memref<?xi32>
+// CHECK: memref.dealloc %[[A1]] : memref<?xi64>
+// CHECK: memref.dealloc %[[A2]] : memref<?xf64>
// CHECK: return
// Test input: deallocating a CSR tensor. Per the CHECK lines above, this
// lowers to a memref.dealloc for each underlying buffer (pointers, indices,
// values); the storage specifier argument receives no dealloc.
func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
return
}
-// CHECK-LABEL: func @sparse_alloc_csc(
-// CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref<?xindex>
-// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref<?xindex>
-// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
-// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
-// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
-// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
-// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]]
-// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] :
-// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-LABEL: func.func @sparse_alloc_csc(
+// CHECK-SAME: %[[A0:.*]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A1:.*]] = arith.constant 10 : i64
+// CHECK: %[[A2:.*]] = arith.constant 0 : index
+// CHECK: %[[A3:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[A4:.*]] = memref.cast %[[A3]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
+// CHECK: %[[A10:.*]] = arith.index_cast %[[A0]] : index to i64
+// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 0 with %[[A10]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]] dim_sz at 1 with %[[A1]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A13:.*]] = sparse_tensor.storage_specifier.get %[[A12]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[A14:.*]] = arith.index_cast %[[A13]] : i64 to index
+// CHECK: %[[A15:.*]], %[[A16:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref<?xindex>, index
+// CHECK: %[[A17:.*]] = arith.index_cast %[[A16]] : index to i64
+// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]] ptr_mem_sz at 1 with %[[A17]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A23:.*]], %[[A24:.*]] = sparse_tensor.push_back %[[A16]], %[[A15]], %[[A2]], %[[A0]] : index, memref<?xindex>, index, index
+// CHECK: %[[A25:.*]] = arith.index_cast %[[A24]] : index to i64
+// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] ptr_mem_sz at 1 with %[[A25]] : i64, !sparse_tensor.storage_specifier
+// CHECK: return %[[A23]], %[[A6]], %[[A8]], %[[A26]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// Test input: allocates a CSC tensor with one dynamic dimension (%arg0) and
// materializes it with sparse_tensor.load. Per the CHECK lines above, codegen
// now records dim/mem sizes in a !sparse_tensor.storage_specifier (set/get
// ops) instead of the former memref<2xindex>/memref<3xindex> side buffers.
func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
%1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
return %1 : tensor<10x?xf64, #CSC>
}
-// CHECK-LABEL: func @sparse_alloc_3d() ->
-// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf64>
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
-// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
-// CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index
-// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex>
-// CHECK: %[[AV:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[A2:.*]] = memref.cast %[[AV]] : memref<16xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[A1]] : memref<1xindex>)
-// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
-// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
-// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
-// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[F0]], %[[C6000]]
-// CHECK: return %[[A0]], %[[A1]], %[[P]] :
-// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf64>
+// CHECK-LABEL: func.func @sparse_alloc_3d() -> (memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A0:.*]] = arith.constant 6000 : index
+// CHECK: %[[A1:.*]] = arith.constant 20 : i64
+// CHECK: %[[A2:.*]] = arith.constant 10 : i64
+// CHECK: %[[A3:.*]] = arith.constant 30 : i64
+// CHECK: %[[A4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
+// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 2 with %[[A1]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.get %[[A10]] val_mem_sz : !sparse_tensor.storage_specifier
+// CHECK: %[[A12:.*]] = arith.index_cast %[[A11]] : i64 to index
+// CHECK: %[[A13:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref<?xf64>, f64, index
+// CHECK: %[[A15:.*]] = arith.index_cast %[[A14]] : index to i64
+// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A15]] : i64, !sparse_tensor.storage_specifier
+// CHECK: return %[[A13]], %[[A16]] : memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
%0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
@@ -364,13 +342,9 @@ func.func @sparse_expansion2() -> memref<?xindex> {
// CHECK-LABEL: func.func @sparse_expansion3(
// CHECK-SAME: %[[D0:.*]]: index,
// CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
-// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
-// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
-// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
-// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
+// CHECK: %[[V:.*]] = memref.alloc(%[[D0]]) : memref<?xf64>
+// CHECK: %[[B:.*]] = memref.alloc(%[[D0]]) : memref<?xi1>
+// CHECK: %[[D:.*]] = memref.alloc(%[[D0]]) : memref<?xindex>
// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
// CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
// CHECK: return %[[D]] : memref<?xindex>
@@ -382,45 +356,39 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
}
// CHECK-LABEL: func.func private @_insert_C_100_f64_0_0(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: f64)
//
-// CHECK-LABEL: func @sparse_compression_1d(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-LABEL: func.func @sparse_compression_1d(
+// CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
-// CHECK-SAME: %[[A8:.*8]]: index)
-// CHECK-DAG: %[[B0:.*]] = arith.constant false
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
-// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]])
-// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi1>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xindex>,
+// CHECK-SAME: %[[A7:.*7]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK-DAG: %[[A8:.*]] = arith.constant false
+// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
+// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
+// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
+// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
+// CHECK: %[[A20:.*]]:4 = func.call @_insert_C_100_f64_0_0(%[[A14]], %[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A19]])
+// CHECK: memref.store %[[A9]], %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
+// CHECK: memref.store %[[A8]], %[[A5]]{{\[}}%[[A18]]] : memref<?xi1>
+// CHECK: scf.yield %[[A20]]#0, %[[A20]]#1, %[[A20]]#2, %[[A20]]#3
// CHECK: }
-// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
-// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
-// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A5]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A6]] : memref<?xindex>
+// CHECK: return %[[A21:.*]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
@@ -433,47 +401,54 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
}
// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: index,
-// CHECK-SAME: %[[A7:.*7]]: f64)
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+// CHECK-SAME: %[[A1:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: index,
+// CHECK-SAME: %[[A7:.*6]]: f64)
//
-// CHECK-LABEL: func @sparse_compression(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
-// CHECK-SAME: %[[A8:.*8]]: index,
-// CHECK-SAME: %[[A9:.*9]]: index)
-// CHECK-DAG: %[[B0:.*]] = arith.constant false
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
-// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>) {
-// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
-// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
-// CHECK: }
-// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
-// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
-// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
-// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-LABEL: func.func @sparse_compression(
+// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi1>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xindex>,
+// CHECK-SAME: %[[A7:.*7]]: index,
+// CHECK-SAME: %[[A8:.*8]]: index) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A9:.*]] = arith.constant 0 : i32
+// CHECK: %[[A10:.*]] = arith.constant false
+// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[A12:.*]] = arith.constant 1 : index
+// CHECK: %[[A13:.*]] = arith.constant 0 : index
+// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
+// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
+// CHECK: %[[A22:.*]]:4 = func.call @_insert_D_C_8_8_f64_64_32(%[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A8]], %[[A20]], %[[A21]]) : (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: memref.store %[[A11]], %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
+// CHECK: memref.store %[[A10]], %[[A5]]{{\[}}%[[A20]]] : memref<?xi1>
+// CHECK: scf.yield %[[A22]]#0, %[[A22]]#1, %[[A22]]#2, %[[A22]]#3 : memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: }
+// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A5]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A6]] : memref<?xindex>
+// CHECK: %[[A23:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index
+// CHECK: %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref<?xi32>
+// CHECK: %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) {
+// CHECK: %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref<?xi32>
+// CHECK: %[[A31:.*]] = arith.cmpi eq, %[[A30]], %[[A9]] : i32
+// CHECK: %[[A32:.*]] = arith.select %[[A31]], %[[A29]], %[[A30]] : i32
+// CHECK: scf.if %[[A31]] {
+// CHECK: memref.store %[[A29]], %[[A24]]#0{{\[}}%[[A28]]] : memref<?xi32>
+// CHECK: }
+// CHECK: scf.yield %[[A32]] : i32
+// CHECK: }
+// CHECK: return %[[A24]]#0, %[[A24]]#1, %[[A24]]#2, %[[A24]]#3 : memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
@@ -487,47 +462,52 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
}
// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: index,
-// CHECK-SAME: %[[A7:.*7]]: f64)
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: index,
+// CHECK-SAME: %[[A7:.*6]]: f64)
//
-// CHECK-LABEL: func @sparse_compression_unordered(
-// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
-// CHECK-SAME: %[[A8:.*8]]: index,
-// CHECK-SAME: %[[A9:.*9]]: index)
-// CHECK-DAG: %[[B0:.*]] = arith.constant false
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NOT: sparse_tensor.sort
-// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
-// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
-// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
-// CHECK: }
-// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
-// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
-// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
-// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-LABEL: func.func @sparse_compression_unordered(
+// CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi1>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xindex>,
+// CHECK-SAME: %[[A7:.*7]]: index,
+// CHECK-SAME: %[[A8:.*8]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A9:.*]] = arith.constant false
+// CHECK: %[[A10:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[A11:.*]] = arith.constant 0 : index
+// CHECK: %[[A12:.*]] = arith.constant 1 : index
+// CHECK: %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
+// CHECK: %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
+// CHECK: %[[A21:.*]]:4 = func.call @_insert_D_C_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
+// CHECK: memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
+// CHECK: scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: }
+// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A5]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A6]] : memref<?xindex>
+// CHECK: %[[A22:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index
+// CHECK: %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref<?xindex>
+// CHECK: %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) {
+// CHECK: %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref<?xindex>
+// CHECK: %[[A30:.*]] = arith.cmpi eq, %[[A29]], %[[A11]] : index
+// CHECK: %[[A31:.*]] = arith.select %[[A30]], %[[A28]], %[[A29]] : index
+// CHECK: scf.if %[[A30]] {
+// CHECK: memref.store %[[A28]], %[[A23]]#0{{\[}}%[[A27]]] : memref<?xindex>
+// CHECK: }
+// CHECK: scf.yield %[[A31]] : index
+// CHECK: }
+// CHECK: return %[[A23]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
@@ -541,26 +521,22 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
}
// CHECK-LABEL: func.func private @_insert_C_128_f64_0_0(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
-// CHECK: func @sparse_insert(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: f64)
+//
+// CHECK-LABEL: func @sparse_insert(
+// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*4]]: index,
+// CHECK-SAME: %[[A6:.*5]]: f64)
+// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
@@ -568,26 +544,22 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
}
// CHECK-LABEL: func.func private @_insert_C_128_f64_64_32(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
-// CHECK: func @sparse_insert_typed(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: index,
-// CHECK-SAME: %[[A6:.*6]]: f64)
-// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
-// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*]]: index,
+// CHECK-SAME: %[[A6:.*]]: f64)
+//
+// CHECK-LABEL: func @sparse_insert_typed(
+// CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A5:.*]]: index,
+// CHECK-SAME: %[[A6:.*]]: f64)
+// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
@@ -595,14 +567,13 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
}
// CHECK-LABEL: func.func @sparse_nop_convert(
-// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-// CHECK-SAME: %[[A4:.*4]]: memref<?xf32>)
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] :
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xf32>, !sparse_tensor.storage_specifier
func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
index 4599b2359936..33bbe6a71ad0 100644
--- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
@@ -2,28 +2,32 @@
#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
-// CHECK-LABEL: func @sparse_alloc_sparse_vector(
-// CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[T0:.*]] = memref.alloc() : memref<1xindex>
-// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref<?xindex>
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T2]] : memref<16xindex>)
-// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref<?xindex>
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
-// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>)
-// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
-// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex>
-// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
-// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]]
-// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] :
+// CHECK-LABEL: func.func @sparse_alloc_sparse_vector(
+// CHECK-SAME: %[[VAL_0:.*]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_5:.*]] = memref.cast %[[VAL_4]] : memref<16xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[VAL_3]] : index) outs(%[[VAL_4]] : memref<16xindex>)
+// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<16xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[VAL_3]] : index) outs(%[[VAL_6]] : memref<16xindex>)
+// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<16xf64> to memref<?xf64>
+// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_8]] : memref<16xf64>)
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_0]] : index to i64
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_11]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]] ptr_mem_sz at 0 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i64 to index
+// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref<?xindex>, index
+// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_16]] : index to i64
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] ptr_mem_sz at 0 with %[[VAL_17]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = sparse_tensor.push_back %[[VAL_16]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref<?xindex>, index, index
+// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_20]] : index to i64
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] ptr_mem_sz at 0 with %[[VAL_21]] : i64, !sparse_tensor.storage_specifier
+// CHECK: return %[[VAL_19]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<?xf64, #SV>
%1 = sparse_tensor.load %0 : tensor<?xf64, #SV>
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index e94be7e8782a..4482cf2f86e1 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -192,19 +192,19 @@ func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: ind
// -----
-func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f32) -> memref<?xf64> {
+func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f32) -> (memref<?xf64>, index) {
// expected-error at +1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f32
- return %0 : memref<?xf64>
+ %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f32
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
-func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf32>, %arg2: f32) -> memref<?xf32> {
+func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf32>, %arg2: f32) -> (memref<?xf32>, index) {
%c0 = arith.constant 0: index
// expected-error at +1 {{'sparse_tensor.push_back' op n must be not less than 1}}
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 {idx = 2 : index} : memref<?xindex>, memref<?xf32>, f32, index
- return %0 : memref<?xf32>
+ %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 : index, memref<?xf32>, f32, index
+ return %0#0, %0#1 : memref<?xf32>, index
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 48b4509d6ac1..67fefa282a4a 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -201,41 +201,41 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %a
// -----
// CHECK-LABEL: func @sparse_push_back(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref<?xf64>, index) {
+// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] : index, memref<?xf64>, f64
// CHECK: return %[[D]]
-func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
- return %0 : memref<?xf64>
+func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (memref<?xf64>, index) {
+ %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f64
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
// CHECK-LABEL: func @sparse_push_back_inbound(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
+// CHECK-SAME: %[[C:.*]]: f64) -> (memref<?xf64>, index) {
+// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] : index, memref<?xf64>, f64
// CHECK: return %[[D]]
-func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
- return %0 : memref<?xf64>
+func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (memref<?xf64>, index) {
+ %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref<?xf64>, f64
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
// CHECK-LABEL: func @sparse_push_back_n(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64,
-// CHECK-SAME: %[[D:.*]]: index) -> memref<?xf64> {
-// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+// CHECK-SAME: %[[D:.*]]: index) -> (memref<?xf64>, index) {
+// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] : index, memref<?xf64>, f64, index
// CHECK: return %[[E]]
-func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
- return %0 : memref<?xf64>
+func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> (memref<?xf64>, index) {
+ %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref<?xf64>, f64, index
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index 0c3c9bb0d27c..6922201e2bbc 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -1,24 +1,23 @@
// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
-// CHECK-LABEL: func @for(
-// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
-// CHECK-SAME: %[[LB:.*5]]: index,
-// CHECK-SAME: %[[UB:.*6]]: index,
-// CHECK-SAME: %[[STEP:.*7]]: index)
-// CHECK: %[[OUT:.*]]:5 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(
-// CHECK-SAME: %[[SIZE:.*]] = %[[DIM_SIZE]],
-// CHECK-SAME: %[[MEM:.*]] = %[[MEM_SIZE]],
-// CHECK-SAME: %[[PTR:.*]] = %[[POINTER]],
-// CHECK-SAME: %[[IDX:.*]] = %[[INDICES]],
-// CHECK-SAME: %[[VAL:.*]] = %[[VALUE]])
-// CHECK: scf.yield %[[SIZE]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
-// CHECK: }
-// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+
+// CHECK-LABEL: func.func @for(
+// CHECK-SAME: %[[VAL_1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*2]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_5:.*4]]: index,
+// CHECK-SAME: %[[VAL_6:.*5]]: index,
+// CHECK-SAME: %[[VAL_7:.*6]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_8:.*]]:4 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args(
+// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_1]],
+// CHECK-SAME: %[[VAL_12:.*]] = %[[VAL_2]],
+// CHECK-SAME: %[[VAL_13:.*]] = %[[VAL_3]],
+// CHECK-SAME: %[[VAL_14:.*]] = %[[VAL_4]])
+// CHECK: scf.yield %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] :
+// CHECK: }
+// CHECK: return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3
func.func @for(%in: tensor<1024xf32, #SparseVector>,
%lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
%1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
@@ -28,26 +27,23 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
return %1 : tensor<1024xf32, #SparseVector>
}
-
-// CHECK-LABEL: func @if(
-// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
-// CHECK-SAME: %[[DIM_SIZE_1:.*5]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE_1:.*6]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER_1:.*7]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES_1:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE_1:.*9]]: memref<?xf32>,
-// CHECK-SAME: %[[I1:.*10]]: i1) ->
-// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
-// CHECK: %[[SV:.*]]:5 = scf.if %[[I1]] -> (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
-// CHECK: scf.yield %[[DIM_SIZE]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
-// CHECK: } else {
-// CHECK: scf.yield %[[DIM_SIZE_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
-// CHECK: }
-// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK-LABEL: func.func @if(
+// CHECK-SAME: %[[VAL_1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*2]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_6:.*4]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_7:.*5]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_8:.*6]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_9:.*7]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_10:.*]]: i1)
+// CHECK: %[[VAL_11:.*]]:4 = scf.if %[[VAL_10]]
+// CHECK: scf.yield %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]]
+// CHECK: }
+// CHECK: return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3 :
+// CHECK-SAME: memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
func.func @if(%t: tensor<1024xf32, #SparseVector>,
%f: tensor<1024xf32, #SparseVector>,
%c: i1) -> tensor<1024xf32, #SparseVector> {
@@ -59,26 +55,28 @@ func.func @if(%t: tensor<1024xf32, #SparseVector>,
return %1 : tensor<1024xf32, #SparseVector>
}
-// CHECK-LABEL: func @while(
-// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
-// CHECK-SAME: %[[I1:.*5]]: i1) ->
-// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
-// CHECK: %[[SV:.*]]:5 = scf.while (
-// CHECK-SAME: %[[TMP_DIM:.*]] = %[[DIM_SIZE]],
-// CHECK-SAME: %[[TMP_MEM:.*]] = %[[MEM_SIZE]],
-// CHECK-SAME: %[[TMP_PTR:.*]] = %[[POINTER]],
-// CHECK-SAME: %[[TMP_IND:.*]] = %[[INDICES]],
-// CHECK-SAME: %[[TMP_VAL:.*]] = %[[VALUE]])
-// CHECK: scf.condition(%[[I1]]) %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
-// CHECK: } do {
-// CHECK: ^bb0(%[[TMP_DIM]]: memref<1xindex>, %[[TMP_MEM]]: memref<3xindex>, %[[TMP_PTR]]: memref<?xindex>, %[[TMP_IND]]: memref<?xindex>, %[[TMP_VAL]]: memref<?xf32>):
-// CHECK: scf.yield %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
-// CHECK: }
-// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+
+// CHECK-LABEL: func.func @while(
+// CHECK-SAME: %[[VAL_1:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*2]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_5:.*4]]: i1)
+// CHECK: %[[VAL_6:.*]]:4 = scf.while (
+// CHECK-SAME: %[[VAL_8:.*]] = %[[VAL_1]],
+// CHECK-SAME: %[[VAL_9:.*]] = %[[VAL_2]],
+// CHECK-SAME: %[[VAL_10:.*]] = %[[VAL_3]],
+// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_4]])
+// CHECK: scf.condition(%[[VAL_5]]) %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_13:.*5]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_14:.*6]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_15:.*7]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_16:.*8]]: !sparse_tensor.storage_specifier
+// CHECK: scf.yield %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]]
+// CHECK: }
+// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3 :
+// CHECK-SAME: memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
%0 = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
scf.condition(%c) %in : tensor<1024xf32, #SparseVector>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 3808a5b20074..90fa3a550996 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -13,134 +13,145 @@
// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
//
// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_3:.*]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_4:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_5:[^ ]+]]: index,
-// CHECK-SAME: %[[VAL_6:.*]]: index,
-// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
-// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
-// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_4:.*4]]: index,
+// CHECK-SAME: %[[VAL_5:.*5]]: index,
+// CHECK-SAME: %[[VAL_6:.*6]]: f64) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_7:.*]] = arith.constant false
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] idx_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index
+// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
+// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index
// CHECK: scf.yield %[[VAL_18]] : i1
// CHECK: } else {
-// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_8]] : i1
+// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_7]] : i1
// CHECK: }
-// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
-// CHECK: scf.yield %[[VAL_3]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref<?xindex>, !sparse_tensor.storage_specifier
+// CHECK: scf.yield %[[VAL_1]], %[[VAL_3]] : memref<?xindex>, !sparse_tensor.storage_specifier
// CHECK: } else {
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
-// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: scf.yield %[[VAL_22]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
+// CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
+// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64
+// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] idx_mem_sz at 1 with %[[VAL_24]] : i64, !sparse_tensor.storage_specifier
+// CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref<?xindex>, !sparse_tensor.storage_specifier
// CHECK: }
-// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1 val_mem_sz : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
+// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref<?xf64>, f64
+// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1 val_mem_sz with %[[VAL_31]] : i64, !sparse_tensor.storage_specifier
+// CHECK: return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: }
// CHECK-LABEL: func.func @matmul(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
-// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
-// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
-// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
-// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
-// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
-// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
-// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
-// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index, index
-// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64>
-// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1>
-// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex>
-// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
-// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
-// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
-// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
-// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
-// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
-// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
-// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
-// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
-// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
-// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
-// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
-// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
-// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
-// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
-// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
-// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
-// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
-// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
-// CHECK: scf.yield %[[VAL_59]] : index
+// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_5:.*5]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_6:.*6]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_7:.*7]]: !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_8:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant false
+// CHECK: %[[VAL_14:.*]] = arith.constant true
+// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_20:.*]] = memref.cast %[[VAL_19]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] dim_sz at 0 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] dim_sz at 1 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_24:.*]] = sparse_tensor.storage_specifier.get %[[VAL_23]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index
+// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_16]], %[[VAL_11]] : index, memref<?xindex>, index
+// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_27]] : index to i64
+// CHECK: %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]] ptr_mem_sz at 1 with %[[VAL_28]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_27]], %[[VAL_26]], %[[VAL_11]], %[[VAL_8]] : index, memref<?xindex>, index, index
+// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64
+// CHECK: %[[VAL_35:.*]] = sparse_tensor.storage_specifier.set %[[VAL_29]] ptr_mem_sz at 1 with %[[VAL_34]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_36:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[VAL_37:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[VAL_38:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[VAL_10]] : f64) outs(%[[VAL_36]] : memref<4xf64>)
+// CHECK: linalg.fill ins(%[[VAL_13]] : i1) outs(%[[VAL_37]] : memref<4xi1>)
+// CHECK: %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_18]], %[[VAL_44:.*]] = %[[VAL_20]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_12]] : index
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// CHECK: %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_12]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]]) -> (index) {
+// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref<?xindex>
+// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref<?xf64>
+// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref<?xindex>
+// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_12]] : index
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_12]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) {
+// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
+// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref<?xf64>
+// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64
+// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64
+// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
+// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_13]] : i1
+// CHECK: %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) {
+// CHECK: memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
+// CHECK: memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex>
+// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_12]] : index
+// CHECK: scf.yield %[[VAL_68]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_50]] : index
+// CHECK: scf.yield %[[VAL_59]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
-// CHECK: scf.yield %[[VAL_60:.*]] : index
+// CHECK: memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
+// CHECK: scf.yield %[[VAL_69:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
-// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
-// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
-// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
-// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
-// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
-// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: scf.yield %[[VAL_70:.*]] : index
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
+// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
+// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
+// CHECK: %[[VAL_80:.*]]:4 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: memref.store %[[VAL_10]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
+// CHECK: memref.store %[[VAL_13]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1>
+// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: }
-// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
-// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
-// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
-// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
-// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
-// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
-// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
-// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
-// CHECK: scf.if %[[VAL_81]] {
-// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
+// CHECK: memref.dealloc %[[VAL_36]] : memref<4xf64>
+// CHECK: memref.dealloc %[[VAL_37]] : memref<4xi1>
+// CHECK: memref.dealloc %[[VAL_38]] : memref<4xindex>
+// CHECK: %[[VAL_82:.*]] = sparse_tensor.storage_specifier.get %[[VAL_83:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index
+// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_12]] to %[[VAL_84]] step %[[VAL_12]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) {
+// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
+// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_11]] : index
+// CHECK: %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index
+// CHECK: scf.if %[[VAL_90]] {
+// CHECK: memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
// CHECK: }
-// CHECK: scf.yield %[[VAL_82]] : index
+// CHECK: scf.yield %[[VAL_91]] : index
// CHECK: }
-// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
-// CHECK: }
+// CHECK: return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
%B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
%C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
index 626c6f8f520c..7f385e5973d0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
@@ -24,16 +24,15 @@ module {
%buffer = memref.alloc(%c1) : memref<?xf32>
memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
- %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
- %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32, index
+ %buffer2, %s0 = sparse_tensor.push_back %c0, %buffer, %d2 : index, memref<?xf32>, f32
+ %buffer3, %s1 = sparse_tensor.push_back %s0, %buffer2, %d1, %c10 : index, memref<?xf32>, f32, index
// CHECK: 16
%capacity = memref.dim %buffer3, %c0 : memref<?xf32>
vector.print %capacity : index
- // CHECK: ( 11 )
- %size = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
- vector.print %size : vector<1xindex>
+ // CHECK: 11
+ vector.print %s1 : index
// CHECK: ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
%values = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<11xf32>
More information about the Mlir-commits
mailing list