[Mlir-commits] [mlir] afe78db - [mlir][sparse] Make sparse_tensor::StorageLayout publicly available.
Peiming Liu
llvmlistbot at llvm.org
Thu May 18 13:29:51 PDT 2023
Author: Peiming Liu
Date: 2023-05-18T20:29:46Z
New Revision: afe78db7701d15f1d69b3fa50ee05fd42a6297cd
URL: https://github.com/llvm/llvm-project/commit/afe78db7701d15f1d69b3fa50ee05fd42a6297cd
DIFF: https://github.com/llvm/llvm-project/commit/afe78db7701d15f1d69b3fa50ee05fd42a6297cd.diff
LOG: [mlir][sparse] Make sparse_tensor::StorageLayout publicly available.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D150739
Added:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
new file mode 100644
index 0000000000000..82c618efa984f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -0,0 +1,183 @@
+//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for the sparse memory layout.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
+#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+///===----------------------------------------------------------------------===//
+/// The sparse tensor storage scheme for a tensor is organized as a single
+/// compound type with the following fields. Note that every memref with ? size
+/// actually behaves as a "vector", i.e. the stored size is the capacity and the
+/// used size resides in the storage_specifier struct.
+///
+/// struct {
+/// ; per-level l:
+/// ; if dense:
+/// <nothing>
+/// ; if compressed:
+/// memref<? x pos> positions-l ; positions for sparse level l
+/// memref<? x crd> coordinates-l ; coordinates for sparse level l
+/// ; if singleton:
+/// memref<? x crd> coordinates-l ; coordinates for singleton level l
+///
+/// memref<? x eltType> values ; values
+///
+/// struct sparse_tensor.storage_specifier {
+/// array<rank x int> lvlSizes ; sizes/cardinalities for each level
+/// array<n x int> memSizes; ; sizes/lengths for each data memref
+/// }
+/// };
+///
+/// In addition, for a "trailing COO region", defined as a compressed level
+/// followed by one or more singleton levels, the default SOA storage that
+/// is inherent to the TACO format is optimized into an AOS storage where
+/// all coordinates of a stored element appear consecutively. In such cases,
+/// a special operation (sparse_tensor.coordinates_buffer) must be used to
+/// access the AOS coordinates array. In the code below, the method
+/// `getCOOStart` is used to find the start of the "trailing COO region".
+///
+/// If the sparse tensor is a slice (produced by `tensor.extract_slice`
+/// operation), instead of allocating a new sparse tensor for it, it reuses the
+/// same sets of MemRefs but attaches an additional set of slicing-metadata for
+/// per-dimension slice offset and stride.
+///
+/// Examples.
+///
+/// #CSR storage of 2-dim matrix yields
+/// memref<?xindex> ; positions-1
+/// memref<?xindex> ; coordinates-1
+/// memref<?xf64> ; values
+/// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
+///
+/// #COO storage of 2-dim matrix yields
+/// memref<?xindex>, ; positions-0, essentially
+/// [0,sz] memref<?xindex> ; AOS coordinates storage
+/// memref<?xf64> ; values
+/// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
+///
+/// Slice on #COO storage of 2-dim matrix yields
+/// ;; Inherited from the original sparse tensors
+/// memref<?xindex>, ; positions-0, essentially
+/// [0,sz] memref<?xindex> ; AOS coordinates storage
+/// memref<?xf64> ; values
+/// struct<(array<2 x i64>, array<3 x i64>, ; lvl0, lvl1, 3xsizes
+/// ;; Extra slicing-metadata
+/// array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride.
+///
+///===----------------------------------------------------------------------===//
+
+enum class SparseTensorFieldKind : uint32_t {
+ StorageSpec = 0,
+ PosMemRef = static_cast<uint32_t>(StorageSpecifierKind::PosMemSize),
+ CrdMemRef = static_cast<uint32_t>(StorageSpecifierKind::CrdMemSize),
+ ValMemRef = static_cast<uint32_t>(StorageSpecifierKind::ValMemSize)
+};
+
+inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
+ assert(kind != SparseTensorFieldKind::StorageSpec);
+ return static_cast<StorageSpecifierKind>(kind);
+}
+
+inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
+ assert(kind != StorageSpecifierKind::LvlSize);
+ return static_cast<SparseTensorFieldKind>(kind);
+}
+
+/// The type of field indices. This alias is to help code be more
+/// self-documenting; unfortunately it is not type-checked, so it only
+/// provides documentation rather than doing anything to prevent mixups.
+using FieldIndex = unsigned;
+
+/// Provides methods to access fields of a sparse tensor with the given
+/// encoding.
+class StorageLayout {
+public:
+ // TODO: Functions/methods marked with [NUMFIELDS] should perhaps use
+ // `FieldIndex` for their return type, via the same reasoning for why
+ // `Dimension`/`Level` are used both for identifiers and ranks.
+ explicit StorageLayout(const SparseTensorType &stt)
+ : StorageLayout(stt.getEncoding()) {}
+ explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {
+ assert(enc);
+ }
+
+ /// For each field that will be allocated for the given sparse tensor
+ /// encoding, calls the callback with the corresponding field index,
+ /// field kind, level, and level-type (the last two are only for level
+ /// memrefs). The field index always starts with zero and increments
+ /// by one between each callback invocation. Ideally, all other methods
+ /// should rely on this function to query a sparse tensor fields instead
+ /// of relying on ad-hoc index computation.
+ void foreachField(
+ llvm::function_ref<bool(
+ FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
+ Level /*lvl (if applicable)*/, DimLevelType /*DLT (if applicable)*/)>)
+ const;
+
+ /// Gets the field index for required field.
+ FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
+ std::optional<Level> lvl) const {
+ return getFieldIndexAndStride(kind, lvl).first;
+ }
+
+ /// Gets the total number of fields for the given sparse tensor encoding.
+ unsigned getNumFields() const;
+
+ /// Gets the total number of data fields (coordinate arrays, position
+ /// arrays, and a value array) for the given sparse tensor encoding.
+ unsigned getNumDataFields() const;
+
+ std::pair<FieldIndex, unsigned>
+ getFieldIndexAndStride(SparseTensorFieldKind kind,
+ std::optional<Level> lvl) const;
+
+private:
+ const SparseTensorEncodingAttr enc;
+};
+
+//
+// Wrapper functions to invoke StorageLayout-related method.
+//
+
+// TODO: See note [NUMFIELDS].
+inline unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+ return StorageLayout(enc).getNumFields();
+}
+
+// TODO: See note [NUMFIELDS].
+inline unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+ return StorageLayout(enc).getNumDataFields();
+}
+
+inline void foreachFieldInSparseTensor(
+ SparseTensorEncodingAttr enc,
+ llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
+ DimLevelType)>
+ callback) {
+ return StorageLayout(enc).foreachField(callback);
+}
+
+void foreachFieldAndTypeInSparseTensor(
+ SparseTensorType,
+ llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
+ DimLevelType)>);
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 22d6304dcb415..f8e7e2df3a95d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -9,6 +9,7 @@
#include <utility>
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -41,6 +42,139 @@ static inline Dimension getDimRank(T t) {
return getRankedTensorType(t).getRank();
}
+//===----------------------------------------------------------------------===//
+// StorageLayout
+//===----------------------------------------------------------------------===//
+
+static constexpr Level kInvalidLevel = -1u;
+static constexpr Level kInvalidFieldIndex = -1u;
+static constexpr FieldIndex kDataFieldStartingIdx = 0;
+
+void StorageLayout::foreachField(
+ llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
+ DimLevelType)>
+ callback) const {
+#define RETURN_ON_FALSE(fidx, kind, lvl, dlt) \
+ if (!(callback(fidx, kind, lvl, dlt))) \
+ return;
+
+ const auto lvlTypes = enc.getLvlTypes();
+ const Level lvlRank = enc.getLvlRank();
+ const Level cooStart = getCOOStart(enc);
+ const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
+ FieldIndex fieldIdx = kDataFieldStartingIdx;
+ // Per-level storage.
+ for (Level l = 0; l < end; l++) {
+ const auto dlt = lvlTypes[l];
+ if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
+ } else if (isSingletonDLT(dlt)) {
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
+ } else {
+ assert(isDenseDLT(dlt)); // no fields
+ }
+ }
+ // The values array.
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
+ DimLevelType::Undef);
+
+ // Put metadata at the end.
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
+ DimLevelType::Undef);
+
+#undef RETURN_ON_FALSE
+}
+
+void sparse_tensor::foreachFieldAndTypeInSparseTensor(
+ SparseTensorType stt,
+ llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
+ DimLevelType)>
+ callback) {
+ assert(stt.hasEncoding());
+ // Construct the basic types.
+ const Type crdType = stt.getCrdType();
+ const Type posType = stt.getPosType();
+ const Type eltType = stt.getElementType();
+
+ const Type specType = StorageSpecifierType::get(stt.getEncoding());
+ // memref<? x pos> positions
+ const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
+ // memref<? x crd> coordinates
+ const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
+ // memref<? x eltType> values
+ const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+
+ StorageLayout(stt).foreachField(
+ [specType, posMemType, crdMemType, valMemType,
+ callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
+ Level lvl, DimLevelType dlt) -> bool {
+ switch (fieldKind) {
+ case SparseTensorFieldKind::StorageSpec:
+ return callback(specType, fieldIdx, fieldKind, lvl, dlt);
+ case SparseTensorFieldKind::PosMemRef:
+ return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
+ case SparseTensorFieldKind::CrdMemRef:
+ return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
+ case SparseTensorFieldKind::ValMemRef:
+ return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
+ };
+ llvm_unreachable("unrecognized field kind");
+ });
+}
+
+unsigned StorageLayout::getNumFields() const {
+ unsigned numFields = 0;
+ foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
+ DimLevelType) -> bool {
+ numFields++;
+ return true;
+ });
+ return numFields;
+}
+
+unsigned StorageLayout::getNumDataFields() const {
+ unsigned numFields = 0; // one value memref
+ foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
+ DimLevelType) -> bool {
+ if (fidx >= kDataFieldStartingIdx)
+ numFields++;
+ return true;
+ });
+ numFields -= 1; // the last field is StorageSpecifier
+ assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
+ return numFields;
+}
+
+std::pair<FieldIndex, unsigned>
+StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
+ std::optional<Level> lvl) const {
+ FieldIndex fieldIdx = kInvalidFieldIndex;
+ unsigned stride = 1;
+ if (kind == SparseTensorFieldKind::CrdMemRef) {
+ assert(lvl.has_value());
+ const Level cooStart = getCOOStart(enc);
+ const Level lvlRank = enc.getLvlRank();
+ if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
+ lvl = cooStart;
+ stride = lvlRank - cooStart;
+ }
+ }
+ foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
+ SparseTensorFieldKind fKind, Level fLvl,
+ DimLevelType dlt) -> bool {
+ if ((lvl && fLvl == lvl.value() && kind == fKind) ||
+ (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
+ fieldIdx = fIdx;
+ // Returns false to break the iteration.
+ return false;
+ }
+ return true;
+ });
+ assert(fieldIdx != kInvalidFieldIndex);
+ return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
+}
+
//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 6133c8b5174b4..5ef9d906f0e8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -10,7 +10,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
- SparseTensorStorageLayout.cpp
+ SparseTensorDescriptor.cpp
SparseVectorization.cpp
Sparsification.cpp
SparsificationAndBufferizationPass.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 131add5e7fc98..67d074e85de66 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
-#include "SparseTensorStorageLayout.h"
+#include "SparseTensorDescriptor.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index f34ed9779cfd3..a6f4dd3c2f718 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
-#include "SparseTensorStorageLayout.h"
+#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
#include <optional>
using namespace mlir;
@@ -262,7 +264,8 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
std::optional<unsigned> lvl;
if (op.getLevel())
lvl = (*op.getLevel());
- unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), lvl);
+ unsigned idx =
+ layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
Value v = Base::onMemSize(rewriter, op, spec, idx);
rewriter.replaceOp(op, v);
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0005c4c6a969b..6d747c00910fc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -16,7 +16,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
-#include "SparseTensorStorageLayout.h"
+#include "SparseTensorDescriptor.h"
#include "llvm/Support/FormatVariadic.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
similarity index 51%
rename from mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
index a47d26e1b9595..5c363b0c781d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "SparseTensorStorageLayout.h"
+#include "SparseTensorDescriptor.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -116,117 +116,3 @@ Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
/*size=*/ValueRange{size},
/*step=*/ValueRange{stride});
}
-
-//===----------------------------------------------------------------------===//
-// Public methods.
-//===----------------------------------------------------------------------===//
-
-constexpr FieldIndex kDataFieldStartingIdx = 0;
-
-void sparse_tensor::foreachFieldInSparseTensor(
- const SparseTensorEncodingAttr enc,
- llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
- DimLevelType)>
- callback) {
- assert(enc);
-
-#define RETURN_ON_FALSE(fidx, kind, dim, dlt) \
- if (!(callback(fidx, kind, dim, dlt))) \
- return;
-
- const auto lvlTypes = enc.getLvlTypes();
- const Level lvlRank = enc.getLvlRank();
- const Level cooStart = getCOOStart(enc);
- const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
- FieldIndex fieldIdx = kDataFieldStartingIdx;
- // Per-dimension storage.
- for (Level l = 0; l < end; l++) {
- // Dimension level types apply in order to the reordered dimension.
- // As a result, the compound type can be constructed directly in the given
- // order.
- const auto dlt = lvlTypes[l];
- if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
- } else if (isSingletonDLT(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
- } else {
- assert(isDenseDLT(dlt)); // no fields
- }
- }
- // The values array.
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
- DimLevelType::Undef);
-
- // Put metadata at the end.
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
- DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
-}
-
-void sparse_tensor::foreachFieldAndTypeInSparseTensor(
- SparseTensorType stt,
- llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
- DimLevelType)>
- callback) {
- assert(stt.hasEncoding());
- // Construct the basic types.
- const Type crdType = stt.getCrdType();
- const Type posType = stt.getPosType();
- const Type eltType = stt.getElementType();
-
- const Type metaDataType = StorageSpecifierType::get(stt.getEncoding());
- // memref<? x pos> positions
- const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
- // memref<? x crd> coordinates
- const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
- // memref<? x eltType> values
- const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
-
- foreachFieldInSparseTensor(
- stt.getEncoding(),
- [metaDataType, posMemType, crdMemType, valMemType,
- callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
- Level lvl, DimLevelType dlt) -> bool {
- switch (fieldKind) {
- case SparseTensorFieldKind::StorageSpec:
- return callback(metaDataType, fieldIdx, fieldKind, lvl, dlt);
- case SparseTensorFieldKind::PosMemRef:
- return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
- case SparseTensorFieldKind::CrdMemRef:
- return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
- case SparseTensorFieldKind::ValMemRef:
- return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
- };
- llvm_unreachable("unrecognized field kind");
- });
-}
-
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- unsigned numFields = 0;
- foreachFieldInSparseTensor(enc,
- [&numFields](FieldIndex, SparseTensorFieldKind,
- Level, DimLevelType) -> bool {
- numFields++;
- return true;
- });
- return numFields;
-}
-
-unsigned
-sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- unsigned numFields = 0; // one value memref
- foreachFieldInSparseTensor(enc,
- [&numFields](FieldIndex fidx,
- SparseTensorFieldKind, Level,
- DimLevelType) -> bool {
- if (fidx >= kDataFieldStartingIdx)
- numFields++;
- return true;
- });
- numFields -= 1; // the last field is MetaData field
- assert(numFields ==
- getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
- return numFields;
-}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
similarity index 53%
rename from mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index a51fcc598ea9f..a9ed5751ab67e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -1,4 +1,4 @@
-//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,11 +10,11 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
-#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -27,199 +27,8 @@ namespace sparse_tensor {
// layout scheme during "direct code generation" (i.e. when sparsification
// generates the buffers as part of actual IR, in constrast with the library
// approach where data structures are hidden behind opaque pointers).
-//
-// The sparse tensor storage scheme for a rank-dimensional tensor is organized
-// as a single compound type with the following fields. Note that every memref
-// with ? size actually behaves as a "vector", i.e. the stored size is the
-// capacity and the used size resides in the storage_specifier struct.
-//
-// struct {
-// ; per-level l:
-// ; if dense:
-// <nothing>
-// ; if compresed:
-// memref<? x pos> positions-l ; positions for sparse level l
-// memref<? x crd> coordinates-l ; coordinates for sparse level l
-// ; if singleton:
-// memref<? x crd> coordinates-l ; coordinates for singleton level l
-//
-// memref<? x eltType> values ; values
-//
-// struct sparse_tensor.storage_specifier {
-// array<rank x int> lvlSizes ; sizes/cardinalities for each level
-// array<n x int> memSizes; ; sizes/lengths for each data memref
-// }
-// };
-//
-// In addition, for a "trailing COO region", defined as a compressed level
-// followed by one or more singleton levels, the default SOA storage that
-// is inherent to the TACO format is optimized into an AOS storage where
-// all coordinates of a stored element appear consecutively. In such cases,
-// a special operation (sparse_tensor.coordinates_buffer) must be used to
-// access the AOS coordinates array. In the code below, the method `getCOOStart`
-// is used to find the start of the "trailing COO region".
-//
-// If the sparse tensor is a slice (produced by `tensor.extract_slice`
-// operation), instead of allocating a new sparse tensor for it, it reuses the
-// same sets of MemRefs but attaching a additional set of slicing-metadata for
-// per-dimension slice offset and stride.
-//
-// Examples.
-//
-// #CSR storage of 2-dim matrix yields
-// memref<?xindex> ; positions-1
-// memref<?xindex> ; coordinates-1
-// memref<?xf64> ; values
-// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
-//
-// #COO storage of 2-dim matrix yields
-// memref<?xindex>, ; positions-0, essentially [0,sz]
-// memref<?xindex> ; AOS coordinates storage
-// memref<?xf64> ; values
-// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
-//
-// Slice on #COO storage of 2-dim matrix yields
-// ;; Inherited from the original sparse tensors
-// memref<?xindex>, ; positions-0, essentially [0,sz]
-// memref<?xindex> ; AOS coordinates storage
-// memref<?xf64> ; values
-// struct<(array<2 x i64>, array<3 x i64>, ; lvl0, lvl1, 3xsizes
-// ;; Extra slicing-metadata
-// array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride.
-//
//===----------------------------------------------------------------------===//
-enum class SparseTensorFieldKind : uint32_t {
- StorageSpec = 0,
- PosMemRef = 1,
- CrdMemRef = 2,
- ValMemRef = 3
-};
-
-static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PosMemRef) ==
- static_cast<uint32_t>(StorageSpecifierKind::PosMemSize));
-static_assert(static_cast<uint32_t>(SparseTensorFieldKind::CrdMemRef) ==
- static_cast<uint32_t>(StorageSpecifierKind::CrdMemSize));
-static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
- static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
-
-/// The type of field indices. This alias is to help code be more
-/// self-documenting; unfortunately it is not type-checked, so it only
-/// provides documentation rather than doing anything to prevent mixups.
-using FieldIndex = unsigned;
-
-// TODO: Functions/methods marked with [NUMFIELDS] might should use
-// `FieldIndex` for their return type, via the same reasoning for why
-// `Dimension`/`Level` are used both for identifiers and ranks.
-
-/// For each field that will be allocated for the given sparse tensor
-/// encoding, calls the callback with the corresponding field index,
-/// field kind, level, and level-type (the last two are only for level
-/// memrefs). The field index always starts with zero and increments
-/// by one between each callback invocation. Ideally, all other methods
-/// should rely on this function to query a sparse tensor fields instead
-/// of relying on ad-hoc index computation.
-void foreachFieldInSparseTensor(
- SparseTensorEncodingAttr,
- llvm::function_ref<bool(
- FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
- Level /*lvl (if applicable)*/, DimLevelType /*DLT (if applicable)*/)>);
-
-/// Same as above, except that it also builds the Type for the corresponding
-/// field.
-void foreachFieldAndTypeInSparseTensor(
- SparseTensorType,
- llvm::function_ref<bool(Type /*fieldType*/, FieldIndex /*fieldIdx*/,
- SparseTensorFieldKind /*fieldKind*/,
- Level /*lvl (if applicable)*/,
- DimLevelType /*DLT (if applicable)*/)>);
-
-/// Gets the total number of fields for the given sparse tensor encoding.
-// TODO: See note [NUMFIELDS].
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the total number of data fields (coordinate arrays, position
-/// arrays, and a value array) for the given sparse tensor encoding.
-// TODO: See note [NUMFIELDS].
-unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
- assert(kind != SparseTensorFieldKind::StorageSpec);
- return static_cast<StorageSpecifierKind>(kind);
-}
-
-inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
- assert(kind != StorageSpecifierKind::LvlSize);
- return static_cast<SparseTensorFieldKind>(kind);
-}
-
-/// Provides methods to access fields of a sparse tensor with the given
-/// encoding.
-class StorageLayout {
-public:
- explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
-
- ///
- /// Getters: get the field index for required field.
- ///
-
- FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
- std::optional<Level> lvl) const {
- return getFieldIndexAndStride(kind, lvl).first;
- }
-
- FieldIndex getMemRefFieldIndex(StorageSpecifierKind kind,
- std::optional<Level> lvl) const {
- return getMemRefFieldIndex(toFieldKind(kind), lvl);
- }
-
- // TODO: See note [NUMFIELDS].
- static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- return sparse_tensor::getNumFieldsFromEncoding(enc);
- }
-
- static void foreachFieldInSparseTensor(
- const SparseTensorEncodingAttr enc,
- llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
- DimLevelType)>
- callback) {
- return sparse_tensor::foreachFieldInSparseTensor(enc, callback);
- }
-
- std::pair<FieldIndex, unsigned>
- getFieldIndexAndStride(SparseTensorFieldKind kind,
- std::optional<Level> lvl) const {
- FieldIndex fieldIdx = -1u;
- unsigned stride = 1;
- if (kind == SparseTensorFieldKind::CrdMemRef) {
- assert(lvl.has_value());
- const Level cooStart = getCOOStart(enc);
- const Level lvlRank = enc.getLvlRank();
- if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
- lvl = cooStart;
- stride = lvlRank - cooStart;
- }
- }
- foreachFieldInSparseTensor(
- enc,
- [lvl, kind, &fieldIdx](FieldIndex fIdx, SparseTensorFieldKind fKind,
- Level fLvl, DimLevelType dlt) -> bool {
- if ((lvl && fLvl == lvl.value() && kind == fKind) ||
- (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
- fieldIdx = fIdx;
- // Returns false to break the iteration.
- return false;
- }
- return true;
- });
- assert(fieldIdx != -1u);
- return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
- }
-
-private:
- SparseTensorEncodingAttr enc;
-};
-
class SparseTensorSpecifier {
public:
explicit SparseTensorSpecifier(Value specifier)
@@ -249,10 +58,12 @@ class SparseTensorSpecifier {
template <typename ValueArrayRef>
class SparseTensorDescriptorImpl {
protected:
+ // TODO: Functions/methods marked with [NUMFIELDS] should perhaps use
+ // `FieldIndex` for their return type, via the same reasoning for why
+ // `Dimension`/`Level` are used both for identifiers and ranks.
SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
- : rType(stt), fields(fields) {
- assert(stt.hasEncoding() &&
- getNumFieldsFromEncoding(stt.getEncoding()) == getNumFields());
+ : rType(stt), fields(fields), layout(stt) {
+ assert(layout.getNumFields() == getNumFields());
// We should make sure the class is trivially copyable (and should be small
// enough) such that we can pass it by value.
static_assert(std::is_trivially_copyable_v<
@@ -263,7 +74,6 @@ class SparseTensorDescriptorImpl {
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
std::optional<Level> lvl) const {
// Delegates to storage layout.
- StorageLayout layout(rType.getEncoding());
return layout.getMemRefFieldIndex(kind, lvl);
}
@@ -336,7 +146,6 @@ class SparseTensorDescriptorImpl {
}
std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
- StorageLayout layout(rType.getEncoding());
return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
}
@@ -352,6 +161,7 @@ class SparseTensorDescriptorImpl {
protected:
SparseTensorType rType;
ValueArrayRef fields;
+ StorageLayout layout;
};
/// Uses ValueRange for immutable descriptors.
@@ -465,4 +275,4 @@ getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
} // namespace sparse_tensor
} // namespace mlir
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1c0643380f2e7..e435b00cfa6dc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2403,6 +2403,7 @@ cc_library(
srcs = ["lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp"],
hdrs = [
"include/mlir/Dialect/SparseTensor/IR/SparseTensor.h",
+ "include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h",
],
includes = ["include"],
More information about the Mlir-commits
mailing list