[Mlir-commits] [mlir] afe78db - [mlir][sparse] Make sparse_tensor::StorageLayout publicly available.

Peiming Liu llvmlistbot at llvm.org
Thu May 18 13:29:51 PDT 2023


Author: Peiming Liu
Date: 2023-05-18T20:29:46Z
New Revision: afe78db7701d15f1d69b3fa50ee05fd42a6297cd

URL: https://github.com/llvm/llvm-project/commit/afe78db7701d15f1d69b3fa50ee05fd42a6297cd
DIFF: https://github.com/llvm/llvm-project/commit/afe78db7701d15f1d69b3fa50ee05fd42a6297cd.diff

LOG: [mlir][sparse] Make sparse_tensor::StorageLayout publicly available.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150739

Added: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
new file mode 100644
index 0000000000000..82c618efa984f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -0,0 +1,183 @@
+//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for the sparse memory layout.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
+#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+///===----------------------------------------------------------------------===//
+/// The sparse tensor storage scheme for a tensor is organized as a single
+/// compound type with the following fields. Note that every memref with ? size
+/// actually behaves as a "vector", i.e. the stored size is the capacity and the
+/// used size resides in the storage_specifier struct.
+///
+/// struct {
+///   ; per-level l:
+///   ;  if dense:
+///        <nothing>
+///   ;  if compressed:
+///        memref<? x pos>  positions-l   ; positions for sparse level l
+///        memref<? x crd>  coordinates-l ; coordinates for sparse level l
+///   ;  if singleton:
+///        memref<? x crd>  coordinates-l ; coordinates for singleton level l
+///
+///   memref<? x eltType> values        ; values
+///
+///   struct sparse_tensor.storage_specifier {
+///     array<rank x int> lvlSizes    ; sizes/cardinalities for each level
+///     array<n x int> memSizes       ; sizes/lengths for each data memref
+///   }
+/// };
+///
+/// In addition, for a "trailing COO region", defined as a compressed level
+/// followed by one or more singleton levels, the default SOA storage that
+/// is inherent to the TACO format is optimized into an AOS storage where
+/// all coordinates of a stored element appear consecutively.  In such cases,
+/// a special operation (sparse_tensor.coordinates_buffer) must be used to
+/// access the AOS coordinates array. The function `getCOOStart` is used to
+/// find the start of the "trailing COO region".
+///
+/// If the sparse tensor is a slice (produced by a `tensor.extract_slice`
+/// operation), then instead of allocating a new sparse tensor for it, it
+/// reuses the same set of memrefs and attaches an additional set of
+/// slicing metadata holding the per-dimension slice offsets and strides.
+///
+/// Examples.
+///
+/// #CSR storage of 2-dim matrix yields
+///   memref<?xindex>                           ; positions-1
+///   memref<?xindex>                           ; coordinates-1
+///   memref<?xf64>                             ; values
+///   struct<(array<2 x i64>, array<3 x i64>)>  ; lvl0, lvl1, 3xsizes
+///
+/// #COO storage of 2-dim matrix yields
+///   memref<?xindex>                           ; positions-0, essentially
+///                                             ; [0,sz]
+///   memref<?xindex>                           ; AOS coordinates storage
+///   memref<?xf64>                             ; values
+///   struct<(array<2 x i64>, array<3 x i64>)>  ; lvl0, lvl1, 3xsizes
+///
+/// Slice on #COO storage of 2-dim matrix yields
+///   ;; Inherited from the original sparse tensors
+///   memref<?xindex>                           ; positions-0, essentially
+///                                             ; [0,sz]
+///   memref<?xindex>                           ; AOS coordinates storage
+///   memref<?xf64>                             ; values
+///   struct<(array<2 x i64>, array<3 x i64>,   ; lvl0, lvl1, 3xsizes
+///   ;; Extra slicing-metadata
+///           array<2 x i64>, array<2 x i64>)>  ; dim offset, dim stride.
+///
+///===----------------------------------------------------------------------===//
+
+/// The kinds of fields that make up the storage of a sparse tensor. The
+/// memref kinds share their numeric values with the `StorageSpecifierKind`
+/// entries that record the sizes of those memrefs, which is what allows
+/// `toSpecifierKind` and `toFieldKind` to convert between them by a cast.
+enum class SparseTensorFieldKind : uint32_t {
+  // The sparse_tensor.storage_specifier metadata, stored as the last field.
+  StorageSpec = 0,
+  // The positions memref of a compressed (or compressed-with-hi) level.
+  PosMemRef = static_cast<uint32_t>(StorageSpecifierKind::PosMemSize),
+  // The coordinates memref of a compressed or singleton level.
+  CrdMemRef = static_cast<uint32_t>(StorageSpecifierKind::CrdMemSize),
+  // The values memref.
+  ValMemRef = static_cast<uint32_t>(StorageSpecifierKind::ValMemSize)
+};
+
+/// Maps a memref field kind to the `StorageSpecifierKind` entry that records
+/// that memref's size. The storage-specifier field itself has no such entry.
+inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
+  assert(kind != SparseTensorFieldKind::StorageSpec);
+  // The enumerator values coincide by construction of SparseTensorFieldKind.
+  const auto rawKind = static_cast<uint32_t>(kind);
+  return static_cast<StorageSpecifierKind>(rawKind);
+}
+
+/// Inverse of `toSpecifierKind`: maps a memref-size specifier kind back to
+/// the field kind of the memref it describes. Level sizes have no field.
+inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
+  assert(kind != StorageSpecifierKind::LvlSize);
+  const auto rawKind = static_cast<uint32_t>(kind);
+  return static_cast<SparseTensorFieldKind>(rawKind);
+}
+
+/// The type of field indices.  This alias is to help code be more
+/// self-documenting; unfortunately it is not type-checked, so it only
+/// provides documentation rather than doing anything to prevent mixups.
+/// Field indices are handed out consecutively, in storage order, by
+/// `StorageLayout::foreachField`.
+using FieldIndex = unsigned;
+
+/// Provides methods to access fields of a sparse tensor with the given
+/// encoding. The queries are answered by enumerating the fields through
+/// `foreachField`, so the field order is defined in a single place.
+class StorageLayout {
+public:
+  // TODO: Functions/methods marked with [NUMFIELDS] should perhaps use
+  // `FieldIndex` for their return type, via the same reasoning for why
+  // `Dimension`/`Level` are used both for identifiers and ranks.
+
+  /// Builds the layout for the encoding of `stt`; the encoding must be
+  /// non-null (asserted by the delegated-to constructor).
+  explicit StorageLayout(const SparseTensorType &stt)
+      : StorageLayout(stt.getEncoding()) {}
+  /// Builds the layout for `enc`, which must be non-null (asserted).
+  explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {
+    assert(enc);
+  }
+
+  /// For each field that will be allocated for the given sparse tensor
+  /// encoding, calls the callback with the corresponding field index,
+  /// field kind, level, and level-type (the last two are only for level
+  /// memrefs).  The field index always starts with zero and increments
+  /// by one between each callback invocation.  Returning `false` from the
+  /// callback stops the enumeration early.  Ideally, all other methods
+  /// should rely on this function to query a sparse tensor fields instead
+  /// of relying on ad-hoc index computation.
+  void foreachField(
+      llvm::function_ref<bool(
+          FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
+          Level /*lvl (if applicable)*/, DimLevelType /*DLT (if applicable)*/)>)
+      const;
+
+  /// Gets the index of the field of the given kind.  For level memrefs,
+  /// `lvl` selects the level; it is ignored for the values memref.
+  FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
+                                 std::optional<Level> lvl) const {
+    return getFieldIndexAndStride(kind, lvl).first;
+  }
+
+  /// Gets the total number of fields for the given sparse tensor encoding.
+  // TODO: See note [NUMFIELDS].
+  unsigned getNumFields() const;
+
+  /// Gets the total number of data fields (coordinate arrays, position
+  /// arrays, and a value array) for the given sparse tensor encoding.
+  // TODO: See note [NUMFIELDS].
+  unsigned getNumDataFields() const;
+
+  /// Same as `getMemRefFieldIndex`, but also returns a stride.  The stride
+  /// is 1 except for coordinate memrefs of levels in the trailing COO
+  /// region: those levels share the AOS coordinates memref of the COO start
+  /// level, and the stride is the number of levels in that region.
+  std::pair<FieldIndex, unsigned>
+  getFieldIndexAndStride(SparseTensorFieldKind kind,
+                         std::optional<Level> lvl) const;
+
+private:
+  // The encoding whose storage layout this object describes.
+  const SparseTensorEncodingAttr enc;
+};
+
+//
+// Convenience wrappers: each constructs a temporary `StorageLayout` for the
+// given encoding and forwards to the corresponding method.
+//
+
+/// Gets the total number of fields for the given sparse tensor encoding.
+// TODO: See note [NUMFIELDS].
+inline unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  const StorageLayout layout(enc);
+  return layout.getNumFields();
+}
+
+/// Gets the number of data fields (coordinate arrays, position arrays, and
+/// a value array) for the given sparse tensor encoding.
+// TODO: See note [NUMFIELDS].
+inline unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  const StorageLayout layout(enc);
+  return layout.getNumDataFields();
+}
+
+/// Invokes `callback` for every field of a sparse tensor with the given
+/// encoding; see `StorageLayout::foreachField` for the callback contract.
+inline void foreachFieldInSparseTensor(
+    SparseTensorEncodingAttr enc,
+    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>
+        callback) {
+  const StorageLayout layout(enc);
+  layout.foreachField(callback);
+}
+
+/// Same as `foreachFieldInSparseTensor`, except that the callback also
+/// receives the type of each field: a dynamically sized 1-D memref of the
+/// position, coordinate, or element type for the data fields, and the
+/// storage-specifier type for the metadata field.
+void foreachFieldAndTypeInSparseTensor(
+    SparseTensorType,
+    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>);
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 22d6304dcb415..f8e7e2df3a95d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -9,6 +9,7 @@
 #include <utility>
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -41,6 +42,139 @@ static inline Dimension getDimRank(T t) {
   return getRankedTensorType(t).getRank();
 }
 
+//===----------------------------------------------------------------------===//
+// StorageLayout
+//===----------------------------------------------------------------------===//
+
+// Sentinel level reported by `foreachField` for fields that do not belong
+// to any level (the values memref and the storage specifier).
+static constexpr Level kInvalidLevel = -1u;
+// Initial value of a field-index search in `getFieldIndexAndStride`; it is
+// asserted never to be returned. Typed as `FieldIndex` (it was `Level`) to
+// match the values it is compared against.
+static constexpr FieldIndex kInvalidFieldIndex = -1u;
+// Index assigned to the first field; later fields count up from here.
+static constexpr FieldIndex kDataFieldStartingIdx = 0;
+
+void StorageLayout::foreachField(
+    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>
+        callback) const {
+  // Field indices are handed out consecutively, starting at
+  // kDataFieldStartingIdx, in the order the fields are visited below.
+  FieldIndex nextFieldIdx = kDataFieldStartingIdx;
+  // Reports one field to the client; a `false` result means "stop".
+  auto visitField = [&](SparseTensorFieldKind kind, Level lvl,
+                        DimLevelType dlt) -> bool {
+    return callback(nextFieldIdx++, kind, lvl, dlt);
+  };
+
+  const auto lvlTypes = enc.getLvlTypes();
+  const Level lvlRank = enc.getLvlRank();
+  const Level cooStart = getCOOStart(enc);
+  // With a trailing COO region, only levels up to and including cooStart
+  // own fields; the remaining levels contribute none.
+  const Level lvlEnd = (cooStart == lvlRank) ? cooStart : cooStart + 1;
+
+  // Per-level storage.
+  for (Level l = 0; l < lvlEnd; ++l) {
+    const auto dlt = lvlTypes[l];
+    if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
+      if (!visitField(SparseTensorFieldKind::PosMemRef, l, dlt) ||
+          !visitField(SparseTensorFieldKind::CrdMemRef, l, dlt))
+        return;
+      continue;
+    }
+    if (isSingletonDLT(dlt)) {
+      if (!visitField(SparseTensorFieldKind::CrdMemRef, l, dlt))
+        return;
+      continue;
+    }
+    // Dense levels need no storage.
+    assert(isDenseDLT(dlt));
+  }
+
+  // The values memref, followed by the metadata at the very end.
+  if (!visitField(SparseTensorFieldKind::ValMemRef, kInvalidLevel,
+                  DimLevelType::Undef))
+    return;
+  (void)visitField(SparseTensorFieldKind::StorageSpec, kInvalidLevel,
+                   DimLevelType::Undef);
+}
+
+void sparse_tensor::foreachFieldAndTypeInSparseTensor(
+    SparseTensorType stt,
+    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>
+        callback) {
+  assert(stt.hasEncoding());
+  // Positions, coordinates and values all live in dynamically sized 1-D
+  // memrefs; only their element types differ.
+  const auto makeBufferType = [](Type elemType) -> Type {
+    return MemRefType::get({ShapedType::kDynamic}, elemType);
+  };
+  const Type posMemType = makeBufferType(stt.getPosType());
+  const Type crdMemType = makeBufferType(stt.getCrdType());
+  const Type valMemType = makeBufferType(stt.getElementType());
+  const Type specType = StorageSpecifierType::get(stt.getEncoding());
+
+  // Selects the type of a field from its kind.
+  const auto typeOfField = [&](SparseTensorFieldKind fieldKind) -> Type {
+    switch (fieldKind) {
+    case SparseTensorFieldKind::StorageSpec:
+      return specType;
+    case SparseTensorFieldKind::PosMemRef:
+      return posMemType;
+    case SparseTensorFieldKind::CrdMemRef:
+      return crdMemType;
+    case SparseTensorFieldKind::ValMemRef:
+      return valMemType;
+    }
+    llvm_unreachable("unrecognized field kind");
+  };
+
+  StorageLayout(stt).foreachField(
+      [&](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind, Level lvl,
+          DimLevelType dlt) -> bool {
+        return callback(typeOfField(fieldKind), fieldIdx, fieldKind, lvl,
+                        dlt);
+      });
+}
+
+unsigned StorageLayout::getNumFields() const {
+  // Visit every field once (never stopping early) and count the visits.
+  unsigned fieldCount = 0;
+  foreachField([&fieldCount](FieldIndex, SparseTensorFieldKind, Level,
+                             DimLevelType) {
+    ++fieldCount;
+    return true;
+  });
+  return fieldCount;
+}
+
+unsigned StorageLayout::getNumDataFields() const {
+  // Data fields are the position, coordinate and value memrefs, i.e. every
+  // field except the storage specifier. Count them by kind: every field
+  // index is >= kDataFieldStartingIdx, so filtering on the index was an
+  // always-true unsigned comparison that selected nothing out.
+  unsigned numDataFields = 0;
+  foreachField([&numDataFields](FieldIndex, SparseTensorFieldKind fKind,
+                                Level, DimLevelType) -> bool {
+    if (fKind != SparseTensorFieldKind::StorageSpec)
+      numDataFields++;
+    return true;
+  });
+  assert(numDataFields == getNumFields() - kDataFieldStartingIdx - 1);
+  return numDataFields;
+}
+
+std::pair<FieldIndex, unsigned>
+StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
+                                      std::optional<Level> lvl) const {
+  unsigned stride = 1;
+  // Levels of the trailing COO region share the AOS coordinates memref of
+  // the COO start level: redirect the query there and report the number of
+  // interleaved coordinates as the stride.
+  if (kind == SparseTensorFieldKind::CrdMemRef) {
+    assert(lvl.has_value());
+    const Level cooStart = getCOOStart(enc);
+    const Level lvlRank = enc.getLvlRank();
+    const Level queried = *lvl;
+    if (cooStart <= queried && queried < lvlRank) {
+      lvl = cooStart;
+      stride = lvlRank - cooStart;
+    }
+  }
+
+  FieldIndex foundIdx = kInvalidFieldIndex;
+  foreachField([&](FieldIndex fIdx, SparseTensorFieldKind fKind, Level fLvl,
+                   DimLevelType) -> bool {
+    if (fKind != kind)
+      return true;
+    // The values memref is unique and not tied to a level; every other
+    // field must also match the requested level.
+    const bool matches =
+        fKind == SparseTensorFieldKind::ValMemRef || (lvl && fLvl == *lvl);
+    if (!matches)
+      return true;
+    foundIdx = fIdx;
+    return false; // Found it: stop the enumeration.
+  });
+  assert(foundIdx != kInvalidFieldIndex);
+  return {foundIdx, stride};
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 6133c8b5174b4..5ef9d906f0e8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -10,7 +10,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
-  SparseTensorStorageLayout.cpp
+  SparseTensorDescriptor.cpp
   SparseVectorization.cpp
   Sparsification.cpp
   SparsificationAndBufferizationPass.cpp

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 131add5e7fc98..67d074e85de66 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
-#include "SparseTensorStorageLayout.h"
+#include "SparseTensorDescriptor.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index f34ed9779cfd3..a6f4dd3c2f718 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
-#include "SparseTensorStorageLayout.h"
 
+#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
 #include <optional>
 
 using namespace mlir;
@@ -262,7 +264,8 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
       std::optional<unsigned> lvl;
       if (op.getLevel())
         lvl = (*op.getLevel());
-      unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), lvl);
+      unsigned idx =
+          layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
       Value v = Base::onMemSize(rewriter, op, spec, idx);
       rewriter.replaceOp(op, v);
       return success();

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0005c4c6a969b..6d747c00910fc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -16,7 +16,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
-#include "SparseTensorStorageLayout.h"
+#include "SparseTensorDescriptor.h"
 
 #include "llvm/Support/FormatVariadic.h"
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
similarity index 51%
rename from mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
index a47d26e1b9595..5c363b0c781d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "SparseTensorStorageLayout.h"
+#include "SparseTensorDescriptor.h"
 #include "CodegenUtils.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -116,117 +116,3 @@ Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
       /*size=*/ValueRange{size},
       /*step=*/ValueRange{stride});
 }
-
-//===----------------------------------------------------------------------===//
-// Public methods.
-//===----------------------------------------------------------------------===//
-
-constexpr FieldIndex kDataFieldStartingIdx = 0;
-
-void sparse_tensor::foreachFieldInSparseTensor(
-    const SparseTensorEncodingAttr enc,
-    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
-                            DimLevelType)>
-        callback) {
-  assert(enc);
-
-#define RETURN_ON_FALSE(fidx, kind, dim, dlt)                                  \
-  if (!(callback(fidx, kind, dim, dlt)))                                       \
-    return;
-
-  const auto lvlTypes = enc.getLvlTypes();
-  const Level lvlRank = enc.getLvlRank();
-  const Level cooStart = getCOOStart(enc);
-  const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
-  FieldIndex fieldIdx = kDataFieldStartingIdx;
-  // Per-dimension storage.
-  for (Level l = 0; l < end; l++) {
-    // Dimension level types apply in order to the reordered dimension.
-    // As a result, the compound type can be constructed directly in the given
-    // order.
-    const auto dlt = lvlTypes[l];
-    if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
-    } else if (isSingletonDLT(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
-    } else {
-      assert(isDenseDLT(dlt)); // no fields
-    }
-  }
-  // The values array.
-  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
-                  DimLevelType::Undef);
-
-  // Put metadata at the end.
-  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
-                  DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
-}
-
-void sparse_tensor::foreachFieldAndTypeInSparseTensor(
-    SparseTensorType stt,
-    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
-                            DimLevelType)>
-        callback) {
-  assert(stt.hasEncoding());
-  // Construct the basic types.
-  const Type crdType = stt.getCrdType();
-  const Type posType = stt.getPosType();
-  const Type eltType = stt.getElementType();
-
-  const Type metaDataType = StorageSpecifierType::get(stt.getEncoding());
-  // memref<? x pos>  positions
-  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
-  // memref<? x crd>  coordinates
-  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
-  // memref<? x eltType> values
-  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
-
-  foreachFieldInSparseTensor(
-      stt.getEncoding(),
-      [metaDataType, posMemType, crdMemType, valMemType,
-       callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
-                 Level lvl, DimLevelType dlt) -> bool {
-        switch (fieldKind) {
-        case SparseTensorFieldKind::StorageSpec:
-          return callback(metaDataType, fieldIdx, fieldKind, lvl, dlt);
-        case SparseTensorFieldKind::PosMemRef:
-          return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
-        case SparseTensorFieldKind::CrdMemRef:
-          return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
-        case SparseTensorFieldKind::ValMemRef:
-          return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
-        };
-        llvm_unreachable("unrecognized field kind");
-      });
-}
-
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-  unsigned numFields = 0;
-  foreachFieldInSparseTensor(enc,
-                             [&numFields](FieldIndex, SparseTensorFieldKind,
-                                          Level, DimLevelType) -> bool {
-                               numFields++;
-                               return true;
-                             });
-  return numFields;
-}
-
-unsigned
-sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-  unsigned numFields = 0; // one value memref
-  foreachFieldInSparseTensor(enc,
-                             [&numFields](FieldIndex fidx,
-                                          SparseTensorFieldKind, Level,
-                                          DimLevelType) -> bool {
-                               if (fidx >= kDataFieldStartingIdx)
-                                 numFields++;
-                               return true;
-                             });
-  numFields -= 1; // the last field is MetaData field
-  assert(numFields ==
-         getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
-  return numFields;
-}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
similarity index 53%
rename from mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index a51fcc598ea9f..a9ed5751ab67e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -1,4 +1,4 @@
-//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,11 +10,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
 
-#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -27,199 +27,8 @@ namespace sparse_tensor {
 // layout scheme during "direct code generation" (i.e. when sparsification
 // generates the buffers as part of actual IR, in constrast with the library
 // approach where data structures are hidden behind opaque pointers).
-//
-// The sparse tensor storage scheme for a rank-dimensional tensor is organized
-// as a single compound type with the following fields. Note that every memref
-// with ? size actually behaves as a "vector", i.e. the stored size is the
-// capacity and the used size resides in the storage_specifier struct.
-//
-// struct {
-//   ; per-level l:
-//   ;  if dense:
-//        <nothing>
-//   ;  if compresed:
-//        memref<? x pos>  positions-l   ; positions for sparse level l
-//        memref<? x crd>  coordinates-l ; coordinates for sparse level l
-//   ;  if singleton:
-//        memref<? x crd>  coordinates-l ; coordinates for singleton level l
-//
-//   memref<? x eltType> values        ; values
-//
-//   struct sparse_tensor.storage_specifier {
-//     array<rank x int> lvlSizes    ; sizes/cardinalities for each level
-//     array<n x int> memSizes;      ; sizes/lengths for each data memref
-//   }
-// };
-//
-// In addition, for a "trailing COO region", defined as a compressed level
-// followed by one or more singleton levels, the default SOA storage that
-// is inherent to the TACO format is optimized into an AOS storage where
-// all coordinates of a stored element appear consecutively.  In such cases,
-// a special operation (sparse_tensor.coordinates_buffer) must be used to
-// access the AOS coordinates array. In the code below, the method `getCOOStart`
-// is used to find the start of the "trailing COO region".
-//
-// If the sparse tensor is a slice (produced by `tensor.extract_slice`
-// operation), instead of allocating a new sparse tensor for it, it reuses the
-// same sets of MemRefs but attaching a additional set of slicing-metadata for
-// per-dimension slice offset and stride.
-//
-// Examples.
-//
-// #CSR storage of 2-dim matrix yields
-//   memref<?xindex>                           ; positions-1
-//   memref<?xindex>                           ; coordinates-1
-//   memref<?xf64>                             ; values
-//   struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
-//
-// #COO storage of 2-dim matrix yields
-//   memref<?xindex>,                          ; positions-0, essentially [0,sz]
-//   memref<?xindex>                           ; AOS coordinates storage
-//   memref<?xf64>                             ; values
-//   struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
-//
-// Slice on #COO storage of 2-dim matrix yields
-//   ;; Inherited from the original sparse tensors
-//   memref<?xindex>,                          ; positions-0, essentially [0,sz]
-//   memref<?xindex>                           ; AOS coordinates storage
-//   memref<?xf64>                             ; values
-//   struct<(array<2 x i64>, array<3 x i64>,   ; lvl0, lvl1, 3xsizes
-//   ;; Extra slicing-metadata
-//           array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride.
-//
 //===----------------------------------------------------------------------===//
 
-enum class SparseTensorFieldKind : uint32_t {
-  StorageSpec = 0,
-  PosMemRef = 1,
-  CrdMemRef = 2,
-  ValMemRef = 3
-};
-
-static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PosMemRef) ==
-              static_cast<uint32_t>(StorageSpecifierKind::PosMemSize));
-static_assert(static_cast<uint32_t>(SparseTensorFieldKind::CrdMemRef) ==
-              static_cast<uint32_t>(StorageSpecifierKind::CrdMemSize));
-static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
-              static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
-
-/// The type of field indices.  This alias is to help code be more
-/// self-documenting; unfortunately it is not type-checked, so it only
-/// provides documentation rather than doing anything to prevent mixups.
-using FieldIndex = unsigned;
-
-// TODO: Functions/methods marked with [NUMFIELDS] might should use
-// `FieldIndex` for their return type, via the same reasoning for why
-// `Dimension`/`Level` are used both for identifiers and ranks.
-
-/// For each field that will be allocated for the given sparse tensor
-/// encoding, calls the callback with the corresponding field index,
-/// field kind, level, and level-type (the last two are only for level
-/// memrefs).  The field index always starts with zero and increments
-/// by one between each callback invocation.  Ideally, all other methods
-/// should rely on this function to query a sparse tensor fields instead
-/// of relying on ad-hoc index computation.
-void foreachFieldInSparseTensor(
-    SparseTensorEncodingAttr,
-    llvm::function_ref<bool(
-        FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
-        Level /*lvl (if applicable)*/, DimLevelType /*DLT (if applicable)*/)>);
-
-/// Same as above, except that it also builds the Type for the corresponding
-/// field.
-void foreachFieldAndTypeInSparseTensor(
-    SparseTensorType,
-    llvm::function_ref<bool(Type /*fieldType*/, FieldIndex /*fieldIdx*/,
-                            SparseTensorFieldKind /*fieldKind*/,
-                            Level /*lvl (if applicable)*/,
-                            DimLevelType /*DLT (if applicable)*/)>);
-
-/// Gets the total number of fields for the given sparse tensor encoding.
-// TODO: See note [NUMFIELDS].
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the total number of data fields (coordinate arrays, position
-/// arrays, and a value array) for the given sparse tensor encoding.
-// TODO: See note [NUMFIELDS].
-unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
-  assert(kind != SparseTensorFieldKind::StorageSpec);
-  return static_cast<StorageSpecifierKind>(kind);
-}
-
-inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
-  assert(kind != StorageSpecifierKind::LvlSize);
-  return static_cast<SparseTensorFieldKind>(kind);
-}
-
-/// Provides methods to access fields of a sparse tensor with the given
-/// encoding.
-class StorageLayout {
-public:
-  explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
-
-  ///
-  /// Getters: get the field index for required field.
-  ///
-
-  FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
-                                 std::optional<Level> lvl) const {
-    return getFieldIndexAndStride(kind, lvl).first;
-  }
-
-  FieldIndex getMemRefFieldIndex(StorageSpecifierKind kind,
-                                 std::optional<Level> lvl) const {
-    return getMemRefFieldIndex(toFieldKind(kind), lvl);
-  }
-
-  // TODO: See note [NUMFIELDS].
-  static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-    return sparse_tensor::getNumFieldsFromEncoding(enc);
-  }
-
-  static void foreachFieldInSparseTensor(
-      const SparseTensorEncodingAttr enc,
-      llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
-                              DimLevelType)>
-          callback) {
-    return sparse_tensor::foreachFieldInSparseTensor(enc, callback);
-  }
-
-  std::pair<FieldIndex, unsigned>
-  getFieldIndexAndStride(SparseTensorFieldKind kind,
-                         std::optional<Level> lvl) const {
-    FieldIndex fieldIdx = -1u;
-    unsigned stride = 1;
-    if (kind == SparseTensorFieldKind::CrdMemRef) {
-      assert(lvl.has_value());
-      const Level cooStart = getCOOStart(enc);
-      const Level lvlRank = enc.getLvlRank();
-      if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
-        lvl = cooStart;
-        stride = lvlRank - cooStart;
-      }
-    }
-    foreachFieldInSparseTensor(
-        enc,
-        [lvl, kind, &fieldIdx](FieldIndex fIdx, SparseTensorFieldKind fKind,
-                               Level fLvl, DimLevelType dlt) -> bool {
-          if ((lvl && fLvl == lvl.value() && kind == fKind) ||
-              (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
-            fieldIdx = fIdx;
-            // Returns false to break the iteration.
-            return false;
-          }
-          return true;
-        });
-    assert(fieldIdx != -1u);
-    return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
-  }
-
-private:
-  SparseTensorEncodingAttr enc;
-};
-
 class SparseTensorSpecifier {
 public:
   explicit SparseTensorSpecifier(Value specifier)
@@ -249,10 +58,12 @@ class SparseTensorSpecifier {
 template <typename ValueArrayRef>
 class SparseTensorDescriptorImpl {
 protected:
+  // TODO: Functions/methods marked with [NUMFIELDS] might should use
+  // `FieldIndex` for their return type, via the same reasoning for why
+  // `Dimension`/`Level` are used both for identifiers and ranks.
   SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
-      : rType(stt), fields(fields) {
-    assert(stt.hasEncoding() &&
-           getNumFieldsFromEncoding(stt.getEncoding()) == getNumFields());
+      : rType(stt), fields(fields), layout(stt) {
+    assert(layout.getNumFields() == getNumFields());
     // We should make sure the class is trivially copyable (and should be small
     // enough) such that we can pass it by value.
     static_assert(std::is_trivially_copyable_v<
@@ -263,7 +74,6 @@ class SparseTensorDescriptorImpl {
   FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
                                  std::optional<Level> lvl) const {
     // Delegates to storage layout.
-    StorageLayout layout(rType.getEncoding());
     return layout.getMemRefFieldIndex(kind, lvl);
   }
 
@@ -336,7 +146,6 @@ class SparseTensorDescriptorImpl {
   }
 
   std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
-    StorageLayout layout(rType.getEncoding());
     return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
   }
 
@@ -352,6 +161,7 @@ class SparseTensorDescriptorImpl {
 protected:
   SparseTensorType rType;
   ValueArrayRef fields;
+  StorageLayout layout;
 };
 
 /// Uses ValueRange for immutable descriptors.
@@ -465,4 +275,4 @@ getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
 } // namespace sparse_tensor
 } // namespace mlir
 
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSODESCRIPTOR_H_

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1c0643380f2e7..e435b00cfa6dc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2403,6 +2403,7 @@ cc_library(
     srcs = ["lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp"],
     hdrs = [
         "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h",
+        "include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h",
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h",
     ],
     includes = ["include"],


        


More information about the Mlir-commits mailing list