[Mlir-commits] [mlir] 10033a1 - Revert "[mlir][sparse] Refactoring: abstract sparse tensor memory scheme into a SparseTensorDescriptor class."
Stella Stamenova
llvmlistbot at llvm.org
Mon Dec 5 17:21:12 PST 2022
Author: Stella Stamenova
Date: 2022-12-05T17:20:01-08:00
New Revision: 10033a179f0c73f28f051ac70b058a0c61882e3a
URL: https://github.com/llvm/llvm-project/commit/10033a179f0c73f28f051ac70b058a0c61882e3a
DIFF: https://github.com/llvm/llvm-project/commit/10033a179f0c73f28f051ac70b058a0c61882e3a.diff
LOG: Revert "[mlir][sparse] Refactoring: abstract sparse tensor memory scheme into a SparseTensorDescriptor class."
This reverts commit 8a7e69d145ff72e7e4fc10ce6b81c3aa4794201c.
This broke the windows mlir buildbot: https://lab.llvm.org/buildbot/#/builders/13/builds/29257
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index e9b63a6f6da1b..52f9fef7041cc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -75,21 +75,6 @@ inline bool isSingletonDim(RankedTensorType type, uint64_t d) {
return isSingletonDLT(getDimLevelType(type, d));
}
-/// Convenience function to test for dense dimension (0 <= d < rank).
-inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) {
- return isDenseDLT(getDimLevelType(enc, d));
-}
-
-/// Convenience function to test for compressed dimension (0 <= d < rank).
-inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) {
- return isCompressedDLT(getDimLevelType(enc, d));
-}
-
-/// Convenience function to test for singleton dimension (0 <= d < rank).
-inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) {
- return isSingletonDLT(getDimLevelType(enc, d));
-}
-
//
// Dimension level properties.
//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index cd2fb58e9d7d4..2ac3f3bb07298 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -90,115 +90,6 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
return val;
}
-void sparse_tensor::foreachFieldInSparseTensor(
- const SparseTensorEncodingAttr enc,
- llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
- DimLevelType)>
- callback) {
- assert(enc);
-
-#define RETURN_ON_FALSE(idx, kind, dim, dlt) \
- if (!(callback(idx, kind, dim, dlt))) \
- return;
-
- RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u,
- DimLevelType::Undef);
- RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u,
- DimLevelType::Undef);
-
- static_assert(dataFieldIdx == memSizesIdx + 1);
- unsigned fieldIdx = dataFieldIdx;
- // Per-dimension storage.
- for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
- // Dimension level types apply in order to the reordered dimension.
- // As a result, the compound type can be constructed directly in the given
- // order.
- auto dlt = getDimLevelType(enc, r);
- if (isCompressedDLT(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
- } else if (isSingletonDLT(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
- } else {
- assert(isDenseDLT(dlt)); // no fields
- }
- }
- // The values array.
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
- DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
-}
-
-void sparse_tensor::foreachFieldAndTypeInSparseTensor(
- RankedTensorType rType,
- llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
- DimLevelType)>
- callback) {
- auto enc = getSparseTensorEncoding(rType);
- assert(enc);
- // Construct the basic types.
- Type indexType = IndexType::get(enc.getContext());
- Type idxType = enc.getIndexType();
- Type ptrType = enc.getPointerType();
- Type eltType = rType.getElementType();
- unsigned rank = rType.getShape().size();
- // memref<rank x index> dimSizes
- Type dimSizeType = MemRefType::get({rank}, indexType);
- // memref<n x index> memSizes
- Type memSizeType =
- MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType);
- // memref<? x ptr> pointers
- Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
- // memref<? x idx> indices
- Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
- // memref<? x eltType> values
- Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
-
- foreachFieldInSparseTensor(
- enc,
- [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType,
- callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
- unsigned dim, DimLevelType dlt) -> bool {
- switch (fieldKind) {
- case SparseTensorFieldKind::DimSizes:
- return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::MemSizes:
- return callback(memSizeType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::PtrMemRef:
- return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::IdxMemRef:
- return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
- case SparseTensorFieldKind::ValMemRef:
- return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
- };
- });
-}
-
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- unsigned numFields = 0;
- foreachFieldInSparseTensor(enc,
- [&numFields](unsigned, SparseTensorFieldKind,
- unsigned, DimLevelType) -> bool {
- numFields++;
- return true;
- });
- return numFields;
-}
-
-unsigned
-sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- unsigned numFields = 0; // one value memref
- foreachFieldInSparseTensor(enc,
- [&numFields](unsigned fidx, SparseTensorFieldKind,
- unsigned, DimLevelType) -> bool {
- if (fidx >= dataFieldIdx)
- numFields++;
- return true;
- });
- assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx);
- return numFields;
-}
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
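For reference, the removed foreachFieldInSparseTensor derives the whole storage layout from the per-dimension level types alone. Below is a minimal standalone C++ sketch of that enumeration; it is illustration only, not the MLIR API, and every name in it is a stand-in.

#include <cstdio>
#include <functional>
#include <vector>

enum class LevelType { Dense, Compressed, Singleton };
enum class FieldKind { DimSizes, MemSizes, PtrMemRef, IdxMemRef, ValMemRef };

constexpr unsigned kDimSizesIdx = 0;  // memref<rank x index> dimSizes
constexpr unsigned kMemSizesIdx = 1;  // memref<n x index> memSizes
constexpr unsigned kDataFieldIdx = 2; // first per-dimension data field

// Invokes cb(fieldIdx, kind, dim) for every storage field in order and stops
// early when cb returns false, mirroring the removed helper's control flow.
void forEachField(const std::vector<LevelType> &levels,
                  const std::function<bool(unsigned, FieldKind, int)> &cb) {
  if (!cb(kDimSizesIdx, FieldKind::DimSizes, -1))
    return;
  if (!cb(kMemSizesIdx, FieldKind::MemSizes, -1))
    return;
  unsigned fieldIdx = kDataFieldIdx;
  for (unsigned d = 0; d < levels.size(); ++d) {
    switch (levels[d]) {
    case LevelType::Compressed: // pointers-d and indices-d
      if (!cb(fieldIdx++, FieldKind::PtrMemRef, static_cast<int>(d)))
        return;
      if (!cb(fieldIdx++, FieldKind::IdxMemRef, static_cast<int>(d)))
        return;
      break;
    case LevelType::Singleton: // indices-d only
      if (!cb(fieldIdx++, FieldKind::IdxMemRef, static_cast<int>(d)))
        return;
      break;
    case LevelType::Dense: // no per-dimension buffer
      break;
    }
  }
  cb(fieldIdx, FieldKind::ValMemRef, -1); // trailing values array
}

int main() {
  // CSR-like 2-D tensor: dense outer dimension, compressed inner dimension.
  std::vector<LevelType> csr = {LevelType::Dense, LevelType::Compressed};
  unsigned numFields = 0;
  forEachField(csr, [&](unsigned, FieldKind, int) {
    ++numFields;
    return true;
  });
  // dimSizes, memSizes, pointers-1, indices-1, values: 5 fields in total.
  std::printf("CSR storage uses %u fields\n", numFields);
  return 0;
}

Counting callback invocations in this way is also how the removed getNumFieldsFromEncoding obtained the field count.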
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 4c861358893f1..bafe752b03d55 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -311,222 +311,8 @@ inline bool isZeroRankedTensorOrScalar(Type type) {
}
//===----------------------------------------------------------------------===//
-// SparseTensorDescriptor and helpers that manage the sparse tensor memory
-// layout scheme.
-//
-// Sparse tensor storage scheme for a rank-dimensional tensor is organized
-// as a single compound type with the following fields. Note that every
-// memref with ? size actually behaves as a "vector", i.e. the stored
-// size is the capacity and the used size resides in the memSizes array.
-//
-// struct {
-// memref<rank x index> dimSizes ; size in each dimension
-// memref<n x index> memSizes ; sizes of ptrs/inds/values
-// ; per-dimension d:
-// ; if dense:
-// <nothing>
-// ; if compressed:
-// memref<? x ptr> pointers-d ; pointers for sparse dim d
-// memref<? x idx> indices-d ; indices for sparse dim d
-// ; if singleton:
-// memref<? x idx> indices-d ; indices for singleton dim d
-// memref<? x eltType> values ; values
-// };
-//
-//===----------------------------------------------------------------------===//
-enum class SparseTensorFieldKind {
- DimSizes,
- MemSizes,
- PtrMemRef,
- IdxMemRef,
- ValMemRef
-};
-
-constexpr uint64_t dimSizesIdx = 0;
-constexpr uint64_t memSizesIdx = dimSizesIdx + 1;
-constexpr uint64_t dataFieldIdx = memSizesIdx + 1;
-
-/// For each field that will be allocated for the given sparse tensor encoding,
-/// calls the callback with the corresponding field index, field kind, dimension
-/// (for sparse tensor level memrefs), and DimLevelType.
-/// The field index always starts at zero and increments by one between two
-/// callback invocations.
-/// Ideally, all other methods should rely on this function to query sparse
-/// tensor fields instead of relying on ad-hoc index computation.
-void foreachFieldInSparseTensor(
- SparseTensorEncodingAttr,
- llvm::function_ref<bool(unsigned /*fieldIdx*/,
- SparseTensorFieldKind /*fieldKind*/,
- unsigned /*dim (if applicable)*/,
- DimLevelType /*DLT (if applicable)*/)>);
-
-/// Same as above, except that it also builds the Type for the corresponding
-/// field.
-void foreachFieldAndTypeInSparseTensor(
- RankedTensorType,
- llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
- SparseTensorFieldKind /*fieldKind*/,
- unsigned /*dim (if applicable)*/,
- DimLevelType /*DLT (if applicable)*/)>);
-
-/// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the total number of data fields (index arrays, pointer arrays, and a
-/// value array) for the given sparse tensor encoding.
-unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Get the index of the field in memSizes (only valid for data fields).
-inline unsigned getFieldMemSizesIndex(unsigned fid) {
- assert(fid >= dataFieldIdx);
- return fid - dataFieldIdx;
-}
-
-/// A helper class around an array of values corresponding to a sparse tensor;
-/// it provides a set of meaningful APIs to query and update a particular field
-/// in a consistent way.
-/// Users should not make assumptions about how a sparse tensor is laid out,
-/// but should instead rely on this class to access the right value for each field.
-template <bool mut>
-class SparseTensorDescriptorImpl {
-private:
- template <bool>
- struct ArrayStorage;
-
- template <>
- struct ArrayStorage<false> {
- using ValueArray = ValueRange;
- };
-
- template <>
- struct ArrayStorage<true> {
- using ValueArray = SmallVectorImpl<Value> &;
- };
-
- // Uses ValueRange for immutable descriptors and SmallVectorImpl<Value> &
- // for mutable descriptors.
- // Using SmallVector for the mutable descriptor allows users to reuse it as a
- // temporary buffer to append values in some special cases, though users are
- // responsible for restoring the buffer to a legal state after such use. This
- // is not the cleanest approach, but it is the most efficient way to avoid
- // copying the fields into another SmallVector. If a cleaner way is wanted,
- // this should be changed to MutableArrayRef instead.
- using Storage = typename ArrayStorage<mut>::ValueArray;
-
-public:
- SparseTensorDescriptorImpl(Type tp, Storage fields)
- : rType(tp.cast<RankedTensorType>()), fields(fields) {
- assert(getSparseTensorEncoding(tp) &&
- getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
- fields.size());
- // We should make sure the class is trivially copyable (and should be small
- // enough) such that we can pass it by value.
- static_assert(
- std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
- }
-
- // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
- // SparseTensorDescriptor.
- template <typename T = SparseTensorDescriptorImpl<true>>
- /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
- : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
-
- ///
- /// Getters: get the field index for required field.
- ///
-
- unsigned getPtrMemRefIndex(unsigned ptrDim) const {
- return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef);
- }
-
- unsigned getIdxMemRefIndex(unsigned idxDim) const {
- return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef);
- }
-
- unsigned getValMemRefIndex() const { return fields.size() - 1; }
-
- unsigned getPtrMemSizesIndex(unsigned dim) const {
- return getPtrMemRefIndex(dim) - dataFieldIdx;
- }
-
- unsigned getIdxMemSizesIndex(unsigned dim) const {
- return getIdxMemRefIndex(dim) - dataFieldIdx;
- }
-
- unsigned getValMemSizesIndex() const {
- return getValMemRefIndex() - dataFieldIdx;
- }
-
- unsigned getNumFields() const { return fields.size(); }
-
- ///
- /// Getters: get the value for required field.
- ///
-
- Value getDimSizesMemRef() const { return fields[dimSizesIdx]; }
- Value getMemSizesMemRef() const { return fields[memSizesIdx]; }
-
- Value getPtrMemRef(unsigned ptrDim) const {
- return fields[getPtrMemRefIndex(ptrDim)];
- }
-
- Value getIdxMemRef(unsigned idxDim) const {
- return fields[getIdxMemRefIndex(idxDim)];
- }
-
- Value getValMemRef() const { return fields[getValMemRefIndex()]; }
-
- Value getField(unsigned fid) const {
- assert(fid < fields.size());
- return fields[fid];
- }
-
- ///
- /// Setters: update the value for required field (only enabled for
- /// MutSparseTensorDescriptor).
- ///
-
- template <typename T = Value>
- void setField(unsigned fid, std::enable_if_t<mut, T> v) {
- assert(fid < fields.size());
- fields[fid] = v;
- }
-
- RankedTensorType getTensorType() const { return rType; }
- Storage getFields() const { return fields; }
-
- Type getElementType(unsigned fidx) const {
- return fields[fidx].getType().template cast<MemRefType>().getElementType();
- }
-
-private:
- unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
- unsigned fieldIdx = -1u;
- foreachFieldInSparseTensor(
- getSparseTensorEncoding(rType),
- [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
- unsigned fDim, DimLevelType dlt) -> bool {
- if (fDim == dim && kind == fKind) {
- fieldIdx = fIdx;
- // Returns false to break the iteration.
- return false;
- }
- return true;
- });
- assert(fieldIdx != -1u);
- return fieldIdx;
- }
-
- RankedTensorType rType;
- Storage fields;
-};
-
-using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
-
-//===----------------------------------------------------------------------===//
-// SparseTensorLoopEmitter class, manages sparse tensors and helps to
-// generate loop structure to (co)-iterate sparse tensors.
+// SparseTensorLoopEmitter class, manages sparse tensors and helps to generate
+// loop structure to (co)-iterate sparse tensors.
//
// An example usage:
// To generate the following loops over T1<?x?> and T2<?x?>
@@ -559,15 +345,15 @@ class SparseTensorLoopEmitter {
using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
Value memref, Value tensor)>;
- /// Constructor: take an array of tensor inputs, on which the generated
- /// loops will iterate. The index of the tensor in the array is also the
+ /// Constructor: take an array of tensor inputs, on which the generated loops
+ /// will iterate. The index of the tensor in the array is also the
/// tensor id (tid) used in related functions.
/// If isSparseOut is set, the loop emitter assumes that the sparse output
/// tensor is empty, and will always generate loops on it based on the dim sizes.
/// An optional array could be provided (by sparsification) to indicate the
/// loop id sequence that will be generated. It is used to establish the
- /// mapping from affineDimExpr to the corresponding loop index in the
- /// loop stack that is maintained by the loop emitter.
+ /// mapping from affineDimExpr to the corresponding loop index in the loop
+ /// stack that is maintained by the loop emitter.
explicit SparseTensorLoopEmitter(ValueRange tensors,
StringAttr loopTag = nullptr,
bool hasOutput = false,
@@ -582,8 +368,8 @@ class SparseTensorLoopEmitter {
/// Generates a list of operations to compute the affine expression.
Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
- /// Enters a new loop sequence; the loops within the same sequence start
- /// from the break points of the previous loop instead of starting over from 0.
+ /// Enters a new loop sequence; the loops within the same sequence start from
+ /// the break points of the previous loop instead of starting over from 0.
/// e.g.,
/// {
/// // loop sequence start.
@@ -738,10 +524,10 @@ class SparseTensorLoopEmitter {
/// scf.reduce.return %val
/// }
/// }
- /// NOTE: only one instruction will be moved into the reduce block; the
- /// transformation will fail if multiple instructions are used to compute
- /// the reduction value. Returns %ret to the user, while %val is provided
- /// by the user (`reduc`).
+ /// NOTE: only one instruction will be moved into the reduce block; the
+ /// transformation will fail if multiple instructions are used to compute the
+ /// reduction value.
+ /// Returns %ret to the user, while %val is provided by the user (`reduc`).
void exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc);
@@ -749,9 +535,9 @@ class SparseTensorLoopEmitter {
void exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc);
- /// An optional string attribute that should be attached to the loop
- /// generated by the loop emitter; it might help subsequent passes identify
- /// loops that operate on sparse tensors more easily.
+ /// An optional string attribute that should be attached to the loop generated
+ /// by the loop emitter; it might help subsequent passes identify loops that
+ /// operate on sparse tensors more easily.
StringAttr loopTag;
/// Whether the loop emitter needs to treat the last tensor as the output
/// tensor.
@@ -770,8 +556,7 @@ class SparseTensorLoopEmitter {
std::vector<std::vector<Value>> idxBuffer; // to_indices
std::vector<Value> valBuffer; // to_value
- // Loop Stack, stores the information of all the nested loops that are
- // alive.
+ // Loop Stack, stores the information of all the nested loops that are alive.
std::vector<LoopLevelInfo> loopStack;
// Loop Sequence Stack, stores the universal index for the current loop
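The SparseTensorDescriptorImpl removed above picks its field storage from a bool template parameter: an immutable ValueRange view for queries versus a SmallVectorImpl<Value> reference for updates. The following standalone C++ sketch shows the same selection idea using std::conditional_t instead of the nested member-template specializations in the reverted code; it is illustration only, and none of the names below are the MLIR classes.

#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

using Field = std::string; // stand-in for mlir::Value

template <bool Mut>
class DescriptorImpl {
  // Immutable descriptors hold a cheap read-only view; mutable descriptors
  // alias the caller's vector so updates are visible without copying fields.
  using Storage = std::conditional_t<Mut, std::vector<Field> &,
                                     const std::vector<Field> &>;

public:
  explicit DescriptorImpl(Storage fields) : fields(fields) {}

  const Field &getField(unsigned fid) const {
    assert(fid < fields.size());
    return fields[fid];
  }

  // The setter only participates in overload resolution for mutable
  // descriptors, like the enable_if-guarded setField in the reverted class.
  template <bool M = Mut, typename = std::enable_if_t<M>>
  void setField(unsigned fid, Field v) {
    assert(fid < fields.size());
    fields[fid] = std::move(v);
  }

  unsigned getNumFields() const {
    return static_cast<unsigned>(fields.size());
  }

private:
  Storage fields;
};

using Descriptor = DescriptorImpl<false>;   // read-only queries
using MutDescriptor = DescriptorImpl<true>; // in-place updates

int main() {
  std::vector<Field> fields = {"dimSizes", "memSizes", "pointers-1",
                               "indices-1", "values"};
  MutDescriptor mut(fields);
  mut.setField(4, "values'"); // updates the caller-owned vector in place
  Descriptor ro(fields);      // read-only view over the same fields
  assert(ro.getField(4) == "values'");
  return ro.getNumFields() == 5 ? 0 : 1;
}

The trade-off called out in the reverted comment still applies: because a mutable descriptor aliases the caller's vector, temporary push_back/pop_back tricks (as in genCompressed) stay visible to the caller and must be undone after use.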
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e059bd36dc02c..113347cdc323a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -36,6 +36,10 @@ using FuncGeneratorType =
static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
+static constexpr uint64_t dimSizesIdx = 0;
+static constexpr uint64_t memSizesIdx = 1;
+static constexpr uint64_t fieldsIdx = 2;
+
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
@@ -45,18 +49,6 @@ static UnrealizedConversionCastOp getTuple(Value tensor) {
return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
}
-static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
- auto tuple = getTuple(tensor);
- return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
-}
-
-static MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
- auto tuple = getTuple(tensor);
- fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
- return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
-}
-
/// Packs the given values as a "tuple" value.
static Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
@@ -64,14 +56,6 @@ static Value genTuple(OpBuilder &builder, Location loc, Type tp,
.getResult(0);
}
-static Value genTuple(OpBuilder &builder, Location loc,
- SparseTensorDescriptor desc) {
- return builder
- .create<UnrealizedConversionCastOp>(loc, desc.getTensorType(),
- desc.getFields())
- .getResult(0);
-}
-
/// Flatten a list of operands that may contain sparse tensors.
static void flattenOperands(ValueRange operands,
SmallVectorImpl<Value> &flattened) {
@@ -117,7 +101,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
/// Creates a straightforward counting for-loop.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
- MutableArrayRef<Value> fields,
+ SmallVectorImpl<Value> &fields,
Value lower = Value()) {
Type indexType = builder.getIndexType();
if (!lower)
@@ -134,46 +118,81 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
/// attached to the given tensor type.
static Optional<Value> sizeFromTensorAtDim(OpBuilder &builder, Location loc,
- SparseTensorDescriptor desc,
- unsigned dim) {
- RankedTensorType rtp = desc.getTensorType();
+ RankedTensorType tensorTp,
+ Value adaptedValue, unsigned dim) {
+ auto enc = getSparseTensorEncoding(tensorTp);
+ if (!enc)
+ return std::nullopt;
+
// Access into static dimension can query original type directly.
// Note that this is typically already done by DimOp's folding.
- auto shape = rtp.getShape();
+ auto shape = tensorTp.getShape();
if (!ShapedType::isDynamic(shape[dim]))
return constantIndex(builder, loc, shape[dim]);
// Any other query can consult the dimSizes array at field DimSizesIdx,
// accounting for the reordering applied to the sparse storage.
- Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim));
- return builder.create<memref::LoadOp>(loc, desc.getDimSizesMemRef(), idx)
+ auto tuple = getTuple(adaptedValue);
+ Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim));
+ return builder
+ .create<memref::LoadOp>(loc, tuple.getInputs()[dimSizesIdx], idx)
.getResult();
}
// Gets the dimension size at the given stored dimension 'd', either as a
// constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
- SparseTensorDescriptor desc, unsigned d) {
- RankedTensorType rtp = desc.getTensorType();
+Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields, unsigned d) {
unsigned dim = toOrigDim(rtp, d);
auto shape = rtp.getShape();
if (!ShapedType::isDynamic(shape[dim]))
return constantIndex(builder, loc, shape[dim]);
-
- return genLoad(builder, loc, desc.getDimSizesMemRef(),
+ return genLoad(builder, loc, fields[dimSizesIdx],
constantIndex(builder, loc, d));
}
+/// Translates field index to memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+ assert(fieldsIdx <= field);
+ return field - fieldsIdx;
+}
+
+/// Creates a pushback op for given field and updates the fields array
+/// accordingly. This operation also updates the memSizes contents.
static void createPushback(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc, unsigned fidx,
+ SmallVectorImpl<Value> &fields, unsigned field,
Value value, Value repeat = Value()) {
- Type etp = desc.getElementType(fidx);
- Value field = desc.getField(fidx);
- Value newField = builder.create<PushBackOp>(
- loc, field.getType(), desc.getMemSizesMemRef(), field,
- toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)),
+ assert(fieldsIdx <= field && field < fields.size());
+ Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+ fields[field] = builder.create<PushBackOp>(
+ loc, fields[field].getType(), fields[memSizesIdx], fields[field],
+ toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)),
repeat);
- desc.setField(fidx, newField);
+}
+
+/// Returns field index of sparse tensor type for pointers/indices, when set.
+static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
+ assert(getSparseTensorEncoding(type));
+ RankedTensorType rType = type.cast<RankedTensorType>();
+ unsigned field = fieldsIdx; // start past header
+ for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+ if (isCompressedDim(rType, r)) {
+ if (r == ptrDim)
+ return field;
+ field++;
+ if (r == idxDim)
+ return field;
+ field++;
+ } else if (isSingletonDim(rType, r)) {
+ if (r == idxDim)
+ return field;
+ field++;
+ } else {
+ assert(isDenseDim(rType, r)); // no fields
+ }
+ }
+ assert(ptrDim == -1u && idxDim == -1u);
+ return field + 1; // return values field index
}
/// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -182,24 +201,64 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
auto enc = getSparseTensorEncoding(type);
if (!enc)
return std::nullopt;
-
+ // Construct the basic types.
+ auto *context = type.getContext();
RankedTensorType rType = type.cast<RankedTensorType>();
- foreachFieldAndTypeInSparseTensor(
- rType,
- [&fields](Type fieldType, unsigned fieldIdx,
- SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
- DimLevelType /*dlt*/) -> bool {
- assert(fieldIdx == fields.size());
- fields.push_back(fieldType);
- return true;
- });
+ Type indexType = IndexType::get(context);
+ Type idxType = enc.getIndexType();
+ Type ptrType = enc.getPointerType();
+ Type eltType = rType.getElementType();
+ //
+ // Sparse tensor storage scheme for a rank-dimensional tensor is organized
+ // as a single compound type with the following fields. Note that every
+ // memref with ? size actually behaves as a "vector", i.e. the stored
+ // size is the capacity and the used size resides in the memSizes array.
+ //
+ // struct {
+ // memref<rank x index> dimSizes ; size in each dimension
+ // memref<n x index> memSizes ; sizes of ptrs/inds/values
+ // ; per-dimension d:
+ // ; if dense:
+ // <nothing>
+ // ; if compressed:
+ // memref<? x ptr> pointers-d ; pointers for sparse dim d
+ // memref<? x idx> indices-d ; indices for sparse dim d
+ // ; if singleton:
+ // memref<? x idx> indices-d ; indices for singleton dim d
+ // memref<? x eltType> values ; values
+ // };
+ //
+ unsigned rank = rType.getShape().size();
+ unsigned lastField = getFieldIndex(type, -1u, -1u);
+ // The dimSizes array and memSizes array.
+ fields.push_back(MemRefType::get({rank}, indexType));
+ fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
+ // Per-dimension storage.
+ for (unsigned r = 0; r < rank; r++) {
+ // Dimension level types apply in order to the reordered dimension.
+ // As a result, the compound type can be constructed directly in the given
+ // order. Clients of this type know what field is what from the sparse
+ // tensor type.
+ if (isCompressedDim(rType, r)) {
+ fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType));
+ fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
+ } else if (isSingletonDim(rType, r)) {
+ fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
+ } else {
+ assert(isDenseDim(rType, r)); // no fields
+ }
+ }
+ // The values array.
+ fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType));
+ assert(fields.size() == lastField);
return success();
}
/// Generates code that allocates a sparse storage scheme for given rank.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc, unsigned r0) {
- RankedTensorType rtp = desc.getTensorType();
+ RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields, unsigned field,
+ unsigned r0) {
unsigned rank = rtp.getShape().size();
Value linear = constantIndex(builder, loc, 1);
for (unsigned r = r0; r < rank; r++) {
@@ -209,8 +268,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// the desired "linear + 1" length property at all times.
Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
Value ptrZero = constantZero(builder, loc, ptrType);
- createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero,
- linear);
+ createPushback(builder, loc, fields, field, ptrZero, linear);
return;
}
if (isSingletonDim(rtp, r)) {
@@ -220,23 +278,23 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// at this level. We will eventually reach a compressed level or
// otherwise the values array for the from-here "all-dense" case.
assert(isDenseDim(rtp, r));
- Value size = sizeAtStoredDim(builder, loc, desc, r);
+ Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached values array so prepare for an insertion.
Value valZero = constantZero(builder, loc, rtp.getElementType());
- createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear);
+ createPushback(builder, loc, fields, field, valZero, linear);
+ assert(fields.size() == ++field);
}
/// Creates allocation operation.
-static Value createAllocation(OpBuilder &builder, Location loc,
- MemRefType memRefType, Value sz,
- bool enableInit) {
- Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
- Type elemType = memRefType.getElementType();
+static Value createAllocation(OpBuilder &builder, Location loc, Type type,
+ Value sz, bool enableInit) {
+ auto memType = MemRefType::get({ShapedType::kDynamic}, type);
+ Value buffer = builder.create<memref::AllocOp>(loc, memType, sz);
if (enableInit) {
- Value fillValue = builder.create<arith::ConstantOp>(
- loc, elemType, builder.getZeroAttr(elemType));
+ Value fillValue =
+ builder.create<arith::ConstantOp>(loc, type, builder.getZeroAttr(type));
builder.create<linalg::FillOp>(loc, fillValue, buffer);
}
return buffer;
@@ -252,68 +310,69 @@ static Value createAllocation(OpBuilder &builder, Location loc,
static void createAllocFields(OpBuilder &builder, Location loc, Type type,
ValueRange dynSizes, bool enableInit,
SmallVectorImpl<Value> &fields) {
+ auto enc = getSparseTensorEncoding(type);
+ assert(enc);
RankedTensorType rtp = type.cast<RankedTensorType>();
+ Type indexType = builder.getIndexType();
+ Type idxType = enc.getIndexType();
+ Type ptrType = enc.getPointerType();
+ Type eltType = rtp.getElementType();
+ auto shape = rtp.getShape();
+ unsigned rank = shape.size();
Value heuristic = constantIndex(builder, loc, 16);
-
- foreachFieldAndTypeInSparseTensor(
- rtp,
- [&builder, &fields, loc, heuristic,
- enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
- unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
- assert(fields.size() == fIdx);
- auto memRefTp = fType.cast<MemRefType>();
- Value field;
- switch (fKind) {
- case SparseTensorFieldKind::DimSizes:
- case SparseTensorFieldKind::MemSizes:
- field = builder.create<memref::AllocOp>(loc, memRefTp);
- break;
- case SparseTensorFieldKind::PtrMemRef:
- case SparseTensorFieldKind::IdxMemRef:
- case SparseTensorFieldKind::ValMemRef:
- field =
- createAllocation(builder, loc, memRefTp, heuristic, enableInit);
- break;
- }
- assert(field);
- fields.push_back(field);
- // Returns true to continue the iteration.
- return true;
- });
-
- MutSparseTensorDescriptor desc(rtp, fields);
-
// Build original sizes.
SmallVector<Value> sizes;
- auto shape = rtp.getShape();
- unsigned rank = shape.size();
for (unsigned r = 0, o = 0; r < rank; r++) {
if (ShapedType::isDynamic(shape[r]))
sizes.push_back(dynSizes[o++]);
else
sizes.push_back(constantIndex(builder, loc, shape[r]));
}
+ // The dimSizes array and memSizes array.
+ unsigned lastField = getFieldIndex(type, -1u, -1u);
+ Value dimSizes =
+ builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
+ Value memSizes = builder.create<memref::AllocOp>(
+ loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
+ fields.push_back(dimSizes);
+ fields.push_back(memSizes);
+ // Per-dimension storage.
+ for (unsigned r = 0; r < rank; r++) {
+ if (isCompressedDim(rtp, r)) {
+ fields.push_back(
+ createAllocation(builder, loc, ptrType, heuristic, enableInit));
+ fields.push_back(
+ createAllocation(builder, loc, idxType, heuristic, enableInit));
+ } else if (isSingletonDim(rtp, r)) {
+ fields.push_back(
+ createAllocation(builder, loc, idxType, heuristic, enableInit));
+ } else {
+ assert(isDenseDim(rtp, r)); // no fields
+ }
+ }
+ // The values array.
+ fields.push_back(
+ createAllocation(builder, loc, eltType, heuristic, enableInit));
+ assert(fields.size() == lastField);
// Initialize the storage scheme to an empty tensor. Initialized memSizes
// to all zeros, sets the dimSizes to known values and gives all pointer
// fields an initial zero entry, so that it is easier to maintain the
// "linear + 1" length property.
builder.create<linalg::FillOp>(
- loc, constantZero(builder, loc, builder.getIndexType()),
- desc.getMemSizesMemRef()); // zero memSizes
-
- Value ptrZero =
- constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType());
- for (unsigned r = 0; r < rank; r++) {
+ loc, ValueRange{constantZero(builder, loc, indexType)},
+ ValueRange{memSizes}); // zero memSizes
+ Value ptrZero = constantZero(builder, loc, ptrType);
+ for (unsigned r = 0, field = fieldsIdx; r < rank; r++) {
unsigned ro = toOrigDim(rtp, r);
- // Fills dim sizes array.
- genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(),
- constantIndex(builder, loc, r));
-
- // Pushes a leading zero to pointers memref.
- if (isCompressedDim(rtp, r))
- createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero);
+ genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r));
+ if (isCompressedDim(rtp, r)) {
+ createPushback(builder, loc, fields, field, ptrZero);
+ field += 2;
+ } else if (isSingletonDim(rtp, r)) {
+ field += 1;
+ }
}
- allocSchemeForRank(builder, loc, desc, /*rank=*/0);
+ allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0);
}
/// Helper method that generates block specific to compressed case:
@@ -337,22 +396,19 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
/// }
/// pos[d] = next
static Value genCompressed(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc,
+ RankedTensorType rtp, SmallVectorImpl<Value> &fields,
SmallVectorImpl<Value> &indices, Value value,
- Value pos, unsigned d) {
- RankedTensorType rtp = desc.getTensorType();
+ Value pos, unsigned field, unsigned d) {
unsigned rank = rtp.getShape().size();
SmallVector<Type> types;
Type indexType = builder.getIndexType();
Type boolType = builder.getIntegerType(1);
- unsigned idxIndex = desc.getIdxMemRefIndex(d);
- unsigned ptrIndex = desc.getPtrMemRefIndex(d);
Value one = constantIndex(builder, loc, 1);
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
- Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos);
- Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1);
- Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex));
- Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz);
+ Value plo = genLoad(builder, loc, fields[field], pos);
+ Value phi = genLoad(builder, loc, fields[field], pp1);
+ Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1));
+ Value msz = genLoad(builder, loc, fields[memSizesIdx], psz);
Value phim1 = builder.create<arith::SubIOp>(
loc, toType(builder, loc, phi, indexType), one);
// Conditional expression.
@@ -362,55 +418,49 @@ static Value genCompressed(OpBuilder &builder, Location loc,
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
types.pop_back();
builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
- Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1);
+ Value crd = genLoad(builder, loc, fields[field + 1], phim1);
Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
toType(builder, loc, crd, indexType),
indices[d]);
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (d > 0)
- genStore(builder, loc, msz, desc.getField(ptrIndex), pos);
+ genStore(builder, loc, msz, fields[field], pos);
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
builder.setInsertionPointAfter(ifOp1);
Value p = ifOp1.getResult(0);
- // If present construct. Note that for a non-unique dimension level, we
- // simply set the condition to false and rely on CSE/DCE to clean up the IR.
+ // If present construct. Note that for a non-unique dimension level, we simply
+ // set the condition to false and rely on CSE/DCE to clean up the IR.
//
// TODO: generate less temporary IR?
//
- for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
- types.push_back(desc.getField(i).getType());
+ for (unsigned i = 0, e = fields.size(); i < e; i++)
+ types.push_back(fields[i].getType());
types.push_back(indexType);
if (!isUniqueDim(rtp, d))
p = constantI1(builder, loc, false);
scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
// If present (fields unaffected, update next to phim1).
builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
-
- // FIXME: This does not looks like a clean way, but probably the most
- // efficient way.
- desc.getFields().push_back(phim1);
- builder.create<scf::YieldOp>(loc, desc.getFields());
- desc.getFields().pop_back();
-
+ fields.push_back(phim1);
+ builder.create<scf::YieldOp>(loc, fields);
+ fields.pop_back();
// If !present (changes fields, update next).
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
- genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1);
- createPushback(builder, loc, desc, idxIndex, indices[d]);
+ genStore(builder, loc, mszp1, fields[field], pp1);
+ createPushback(builder, loc, fields, field + 1, indices[d]);
// Prepare the next dimension "as needed".
if ((d + 1) < rank)
- allocSchemeForRank(builder, loc, desc, d + 1);
-
- desc.getFields().push_back(msz);
- builder.create<scf::YieldOp>(loc, desc.getFields());
- desc.getFields().pop_back();
-
+ allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1);
+ fields.push_back(msz);
+ builder.create<scf::YieldOp>(loc, fields);
+ fields.pop_back();
// Update fields and return next pos.
builder.setInsertionPointAfter(ifOp2);
unsigned o = 0;
- for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
- desc.setField(i, ifOp2.getResult(o++));
+ for (unsigned i = 0, e = fields.size(); i < e; i++)
+ fields[i] = ifOp2.getResult(o++);
return ifOp2.getResult(o);
}
@@ -438,10 +488,11 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
// Construct fields and indices arrays from parameters.
ValueRange tmp = args.drop_back(rank + 1);
SmallVector<Value> fields(tmp.begin(), tmp.end());
- MutSparseTensorDescriptor desc(rtp, fields);
tmp = args.take_back(rank + 1).drop_back();
SmallVector<Value> indices(tmp.begin(), tmp.end());
Value value = args.back();
+
+ unsigned field = fieldsIdx; // Start past header.
Value pos = constantZero(builder, loc, builder.getIndexType());
// Generate code for every dimension.
for (unsigned d = 0; d < rank; d++) {
@@ -453,35 +504,39 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
// }
// pos[d] = indices.size() - 1
// <insert @ pos[d] at next dimension d + 1>
- pos = genCompressed(builder, loc, desc, indices, value, pos, d);
+ pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field,
+ d);
+ field += 2;
} else if (isSingletonDim(rtp, d)) {
// Create:
// indices[d].push_back(i[d])
// pos[d] = pos[d-1]
// <insert @ pos[d] at next dimension d + 1>
- createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]);
+ createPushback(builder, loc, fields, field, indices[d]);
+ field += 1;
} else {
assert(isDenseDim(rtp, d));
// Construct the new position as:
// pos[d] = size * pos[d-1] + i[d]
// <insert @ pos[d] at next dimension d + 1>
- Value size = sizeAtStoredDim(builder, loc, desc, d);
+ Value size = sizeAtStoredDim(builder, loc, rtp, fields, d);
Value mult = builder.create<arith::MulIOp>(loc, size, pos);
pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
}
}
// Reached the actual value append/insert.
if (!isDenseDim(rtp, rank - 1))
- createPushback(builder, loc, desc, desc.getValMemRefIndex(), value);
+ createPushback(builder, loc, fields, field++, value);
else
- genStore(builder, loc, value, desc.getValMemRef(), pos);
+ genStore(builder, loc, value, fields[field++], pos);
+ assert(fields.size() == field);
builder.create<func::ReturnOp>(loc, fields);
}
/// Generates a call to a function to perform an insertion operation. If the
/// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder,
- MutSparseTensorDescriptor desc,
+static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields,
SmallVectorImpl<Value> &indices, Value value,
func::FuncOp insertPoint,
StringRef namePrefix,
@@ -489,7 +544,6 @@ static void genInsertionCallHelper(OpBuilder &builder,
// The mangled name of the function has this format:
// <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
// _<indexBitWidth>_<pointerBitWidth>
- RankedTensorType rtp = desc.getTensorType();
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
nameOstream << namePrefix;
@@ -523,7 +577,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
// Construct parameters for fields and indices.
- SmallVector<Value> operands(desc.getFields().begin(), desc.getFields().end());
+ SmallVector<Value> operands(fields.begin(), fields.end());
operands.append(indices.begin(), indices.end());
operands.push_back(value);
Location loc = insertPoint.getLoc();
@@ -536,7 +590,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
func = builder.create<func::FuncOp>(
loc, nameOstream.str(),
FunctionType::get(context, ValueRange(operands).getTypes(),
- ValueRange(desc.getFields()).getTypes()));
+ ValueRange(fields).getTypes()));
func.setPrivate();
createFunc(builder, module, func, rtp);
}
@@ -544,44 +598,42 @@ static void genInsertionCallHelper(OpBuilder &builder,
// Generate a call to perform the insertion and update `fields` with values
// returned from the call.
func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
- for (size_t i = 0, e = desc.getNumFields(); i < e; i++) {
- desc.getFields()[i] = call.getResult(i);
+ for (size_t i = 0; i < fields.size(); i++) {
+ fields[i] = call.getResult(i);
}
}
/// Generates insertion finalization code.
-static void genEndInsert(OpBuilder &builder, Location loc,
- MutSparseTensorDescriptor desc) {
- RankedTensorType rtp = desc.getTensorType();
+static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields) {
unsigned rank = rtp.getShape().size();
+ unsigned field = fieldsIdx; // start past header
for (unsigned d = 0; d < rank; d++) {
if (isCompressedDim(rtp, d)) {
// Compressed dimensions need a pointer cleanup for all entries
// that were not visited during the insertion pass.
//
- // TODO: avoid cleanup and keep compressed scheme consistent at all
- // times?
+ // TODO: avoid cleanup and keep compressed scheme consistent at all times?
//
if (d > 0) {
Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
- Value ptrMemRef = desc.getPtrMemRef(d);
- Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d));
- Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz);
+ Value mz = constantIndex(builder, loc, getMemSizesIndex(field));
+ Value hi = genLoad(builder, loc, fields[memSizesIdx], mz);
Value zero = constantIndex(builder, loc, 0);
Value one = constantIndex(builder, loc, 1);
// Vector of only one, but needed by createFor's prototype.
- SmallVector<Value, 1> inits{genLoad(builder, loc, ptrMemRef, zero)};
+ SmallVector<Value, 1> inits{genLoad(builder, loc, fields[field], zero)};
scf::ForOp loop = createFor(builder, loc, hi, inits, one);
Value i = loop.getInductionVar();
Value oldv = loop.getRegionIterArg(0);
- Value newv = genLoad(builder, loc, ptrMemRef, i);
+ Value newv = genLoad(builder, loc, fields[field], i);
Value ptrZero = constantZero(builder, loc, ptrType);
Value cond = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newv, ptrZero);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(ptrType),
cond, /*else*/ true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- genStore(builder, loc, oldv, ptrMemRef, i);
+ genStore(builder, loc, oldv, fields[field], i);
builder.create<scf::YieldOp>(loc, oldv);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, newv);
@@ -589,10 +641,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
builder.setInsertionPointAfter(loop);
}
+ field += 2;
+ } else if (isSingletonDim(rtp, d)) {
+ field++;
} else {
- assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d));
+ assert(isDenseDim(rtp, d));
}
}
+ assert(fields.size() == ++field);
}
//===----------------------------------------------------------------------===//
@@ -683,12 +739,12 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Optional<int64_t> index = op.getConstantIndex();
- if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
+ if (!index)
return failure();
-
- auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
- auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
-
+ auto sz =
+ sizeFromTensorAtDim(rewriter, op.getLoc(),
+ op.getSource().getType().cast<RankedTensorType>(),
+ adaptor.getSource(), *index);
if (!sz)
return failure();
@@ -778,14 +834,16 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
LogicalResult
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Prepare descriptor.
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ RankedTensorType srcType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ auto tuple = getTuple(adaptor.getTensor());
+ // Prepare fields.
+ SmallVector<Value> fields(tuple.getInputs());
// Generate optional insertion finalization code.
if (op.getHasInserts())
- genEndInsert(rewriter, op.getLoc(), desc);
+ genEndInsert(rewriter, op.getLoc(), srcType, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+ rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
return success();
}
};
@@ -797,10 +855,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
LogicalResult
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!getSparseTensorEncoding(op.getTensor().getType()))
- return failure();
Location loc = op->getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
RankedTensorType srcType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = srcType.getElementType();
@@ -812,7 +867,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
// dimension size, translated back to original dimension). Note that we
// recursively rewrite the new DimOp on the **original** tensor.
unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
- auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
+ auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(),
+ innerDim);
assert(sz); // This for sure is a sparse tensor
// Generate a memref for `sz` elements of type `t`.
auto genAlloc = [&](Type t) {
@@ -852,15 +908,16 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ RankedTensorType dstType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ Type eltType = dstType.getElementType();
+ auto tuple = getTuple(adaptor.getTensor());
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
- RankedTensorType dstType = desc.getTensorType();
- Type eltType = dstType.getElementType();
- // Prepare indices.
+ // Prepare fields and indices.
+ SmallVector<Value> fields(tuple.getInputs());
SmallVector<Value> indices(adaptor.getIndices());
// If the innermost dimension is ordered, we need to sort the indices
// in the "added" array prior to applying the compression.
@@ -882,19 +939,19 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
// filled[index] = false;
// yield new_memrefs
// }
- scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
+ scf::ForOp loop = createFor(rewriter, loc, count, fields);
Value i = loop.getInductionVar();
Value index = genLoad(rewriter, loc, added, i);
Value value = genLoad(rewriter, loc, values, index);
indices.push_back(index);
// TODO: faster for subsequent insertions?
auto insertPoint = op->template getParentOfType<func::FuncOp>();
- genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
- kInsertFuncNamePrefix, genInsertBody);
+ genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+ insertPoint, kInsertFuncNamePrefix, genInsertBody);
genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
index);
genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
- rewriter.create<scf::YieldOp>(loc, desc.getFields());
+ rewriter.create<scf::YieldOp>(loc, fields);
rewriter.setInsertionPointAfter(loop);
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
@@ -916,18 +973,20 @@ class SparseInsertConverter : public OpConversionPattern<InsertOp> {
LogicalResult
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
- // Prepare and indices.
+ RankedTensorType dstType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ auto tuple = getTuple(adaptor.getTensor());
+ // Prepare fields and indices.
+ SmallVector<Value> fields(tuple.getInputs());
SmallVector<Value> indices(adaptor.getIndices());
// Generate insertion.
Value value = adaptor.getValue();
auto insertPoint = op->template getParentOfType<func::FuncOp>();
- genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
- kInsertFuncNamePrefix, genInsertBody);
+ genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+ insertPoint, kInsertFuncNamePrefix, genInsertBody);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+ rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
return success();
}
};
@@ -944,9 +1003,11 @@ class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- Value field = Base::getFieldForOp(desc, op);
- rewriter.replaceOp(op, field);
+ auto tuple = getTuple(adaptor.getTensor());
+ unsigned idx = Base::getIndexForOp(tuple, op);
+ auto fields = tuple.getInputs();
+ assert(idx < fields.size());
+ rewriter.replaceOp(op, fields[idx]);
return success();
}
};
@@ -957,10 +1018,10 @@ class SparseToPointersConverter
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
- static Value getFieldForOp(const SparseTensorDescriptor &desc,
- ToPointersOp op) {
+ static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+ ToPointersOp op) {
uint64_t dim = op.getDimension().getZExtValue();
- return desc.getPtrMemRef(dim);
+ return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
}
};
@@ -970,10 +1031,10 @@ class SparseToIndicesConverter
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
- static Value getFieldForOp(const SparseTensorDescriptor &desc,
- ToIndicesOp op) {
+ static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+ ToIndicesOp op) {
uint64_t dim = op.getDimension().getZExtValue();
- return desc.getIdxMemRef(dim);
+ return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
}
};
@@ -983,9 +1044,10 @@ class SparseToValuesConverter
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
- static Value getFieldForOp(const SparseTensorDescriptor &desc,
- ToValuesOp /*op*/) {
- return desc.getValMemRef();
+ static unsigned getIndexForOp(UnrealizedConversionCastOp tuple,
+ ToValuesOp /*op*/) {
+ // The last field holds the value buffer.
+ return tuple.getInputs().size() - 1;
}
};
@@ -1017,11 +1079,12 @@ class SparseNumberOfEntriesConverter
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Query memSizes for the actually stored values size.
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto tuple = getTuple(adaptor.getTensor());
+ auto fields = tuple.getInputs();
+ unsigned lastField = fields.size() - 1;
Value field =
- constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex());
- rewriter.replaceOpWithNewOp<memref::LoadOp>(op, desc.getMemSizesMemRef(),
- field);
+ constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[memSizesIdx], field);
return success();
}
};
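As the restored storage-scheme comment notes, every dynamically sized buffer behaves like a vector: the memref itself only provides capacity, while the used size lives in the shared memSizes array, and createPushback keeps both in sync through PushBackOp. Below is a minimal standalone C++ sketch of that bookkeeping; all names are illustrative stand-ins, not the generated IR.

#include <cassert>
#include <cstdio>
#include <vector>

constexpr unsigned kFieldsIdx = 2; // data buffers start at field 2

struct Storage {
  std::vector<std::size_t> memSizes;     // used size per data buffer
  std::vector<std::vector<int>> buffers; // data buffers, indexed past header
};

// Translates a field index into its slot in memSizes (data fields only),
// mirroring the restored getMemSizesIndex.
unsigned getMemSizesIndex(unsigned field) {
  assert(field >= kFieldsIdx);
  return field - kFieldsIdx;
}

// Appends value (repeat times) to the given data buffer and bumps its entry
// in memSizes, which is the invariant createPushback maintains.
void pushBack(Storage &s, unsigned field, int value, std::size_t repeat = 1) {
  std::vector<int> &buf = s.buffers[getMemSizesIndex(field)];
  buf.insert(buf.end(), repeat, value);
  s.memSizes[getMemSizesIndex(field)] += repeat;
}

int main() {
  // CSR-like layout: field 2 = pointers-1, field 3 = indices-1, field 4 = values.
  Storage s;
  s.memSizes.assign(3, 0);
  s.buffers.assign(3, {});
  pushBack(s, /*field=*/2, /*value=*/0);  // leading zero pointer entry
  pushBack(s, /*field=*/3, /*value=*/7);  // one stored index
  pushBack(s, /*field=*/4, /*value=*/42); // one stored value
  for (unsigned f = kFieldsIdx; f <= 4; ++f)
    std::printf("field %u: used size %zu\n", f,
                s.memSizes[getMemSizesIndex(f)]);
  return 0;
}

This is also why createAllocFields zero-initializes memSizes and pushes one leading zero into every pointers buffer: the "linear + 1" pointer-length property is then maintained purely through such pushbacks.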