[Mlir-commits] [mlir] [mlir][sparse] introduce MapRef, unify conversion/codegen for reader (PR #68360)
Aart Bik
llvmlistbot at llvm.org
Thu Oct 5 15:18:58 PDT 2023
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/68360
>From 6094912685a0cfa5c13e023e8ec97238a84fca2f Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 5 Oct 2023 13:22:28 -0700
Subject: [PATCH 1/3] [mlir][sparse] introduce MapRef, unify conversion/codegen
for reader
This revision introduces a MapRef, which will support a future
generalization beyond permutations (e.g. block sparsity). This
revision also unifies the conversion/codegen paths for the
sparse_tensor.new operation from file (i.e. the readers), and
removes the now-unused getSparseTensorReaderNext* runtime entry
points. Note that more unification is planned, as is support for
general affine dim2lvl and lvl2dim maps (all marked with TODOs).
---
.../mlir/ExecutionEngine/SparseTensor/File.h | 156 ++++++----------
.../ExecutionEngine/SparseTensor/MapRef.h | 96 ++++++++++
.../ExecutionEngine/SparseTensor/Storage.h | 108 +----------
.../ExecutionEngine/SparseTensorRuntime.h | 8 -
.../SparseTensor/Transforms/CodegenUtils.cpp | 89 +++++++++
.../SparseTensor/Transforms/CodegenUtils.h | 18 ++
.../Transforms/SparseTensorCodegen.cpp | 73 ++------
.../Transforms/SparseTensorConversion.cpp | 111 ++---------
.../SparseTensor/CMakeLists.txt | 1 +
.../ExecutionEngine/SparseTensor/MapRef.cpp | 52 ++++++
.../ExecutionEngine/SparseTensorRuntime.cpp | 60 +++---
mlir/test/Dialect/SparseTensor/codegen.mlir | 172 +++++++++---------
.../test/Dialect/SparseTensor/conversion.mlir | 18 +-
13 files changed, 475 insertions(+), 487 deletions(-)
create mode 100644 mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
create mode 100644 mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 78c1a0544e3a521..9157bfa7e773239 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -20,6 +20,7 @@
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
+#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"
#include <fstream>
@@ -75,6 +76,10 @@ inline V readValue(char **linePtr, bool isPattern) {
} // namespace detail
+//===----------------------------------------------------------------------===//
+//
+// Reader class.
+//
//===----------------------------------------------------------------------===//
/// This class abstracts over the information stored in file headers,
@@ -132,6 +137,7 @@ class SparseTensorReader final {
/// Reads and parses the file's header.
void readHeader();
+ /// Returns the stored value kind.
ValueKind getValueKind() const { return valueKind_; }
/// Checks if a header has been successfully read.
@@ -185,58 +191,37 @@ class SparseTensorReader final {
/// valid after parsing the header.
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;
- /// Reads a sparse tensor element from the next line in the input file and
- /// returns the value of the element. Stores the coordinates of the element
- /// to the `dimCoords` array.
- template <typename V>
- V readElement(uint64_t dimRank, uint64_t *dimCoords) {
- assert(dimRank == getRank() && "rank mismatch");
- char *linePtr = readCoords(dimCoords);
- return detail::readValue<V>(&linePtr, isPattern());
- }
-
- /// Allocates a new COO object for `lvlSizes`, initializes it by reading
- /// all the elements from the file and applying `dim2lvl` to their
- /// dim-coordinates, and then closes the file. Templated on V only.
- template <typename V>
- SparseTensorCOO<V> *readCOO(uint64_t lvlRank, const uint64_t *lvlSizes,
- const uint64_t *dim2lvl);
-
/// Allocates a new sparse-tensor storage object with the given encoding,
/// initializes it by reading all the elements from the file, and then
/// closes the file. Templated on P, I, and V.
template <typename P, typename I, typename V>
SparseTensorStorage<P, I, V> *
readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
- const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
- const uint64_t *dim2lvl) {
- auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
+ const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+ const uint64_t *lvl2dim) {
+ const uint64_t dimRank = getRank();
+ MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
+ auto *coo = readCOO<V>(map, lvlSizes);
auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
- getRank(), getDimSizes(), lvlRank, lvlTypes, lvl2dim, *lvlCOO);
- delete lvlCOO;
+ dimRank, getDimSizes(), lvlRank, lvlTypes, lvl2dim, *coo);
+ delete coo;
return tensor;
}
/// Reads the COO tensor from the file, stores the coordinates and values to
/// the given buffers, returns a boolean value to indicate whether the COO
/// elements are sorted.
- /// Precondition: the buffers should have enough space to hold the elements.
template <typename C, typename V>
bool readToBuffers(uint64_t lvlRank, const uint64_t *dim2lvl,
- C *lvlCoordinates, V *values);
+ const uint64_t *lvl2dim, C *lvlCoordinates, V *values);
private:
- /// Attempts to read a line from the file. Is private because there's
- /// no reason for client code to call it.
+ /// Attempts to read a line from the file.
void readLine();
/// Reads the next line of the input file and parses the coordinates
/// into the `dimCoords` argument. Returns the position in the `line`
- /// buffer where the element's value should be parsed from. This method
- /// has been factored out from `readElement` to minimize code bloat
- /// for the generated library.
- ///
- /// Precondition: `dimCoords` is valid for `getRank()`.
+ /// buffer where the element's value should be parsed from.
template <typename C>
char *readCoords(C *dimCoords) {
readLine();
@@ -251,24 +236,20 @@ class SparseTensorReader final {
return linePtr;
}
- /// The internal implementation of `readCOO`. We template over
- /// `IsPattern` in order to perform LICM without needing to duplicate the
- /// source code.
- //
- // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
- // since that's what `readCOO` creates. Once we update `readCOO` to
- // functionalize the mapping, then this helper will just take that
- // same function.
+ /// Reads all the elements from the file while applying the given map.
+ template <typename V>
+ SparseTensorCOO<V> *readCOO(const MapRef &map, const uint64_t *lvlSizes);
+
+ /// The implementation of `readCOO` that is templated on `IsPattern` in order
+ /// to perform LICM without needing to duplicate the source code.
template <typename V, bool IsPattern>
- void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
- SparseTensorCOO<V> *lvlCOO);
+ void readCOOLoop(const MapRef &map, SparseTensorCOO<V> *coo);
- /// The internal implementation of `readToBuffers`. We template over
- /// `IsPattern` in order to perform LICM without needing to duplicate the
- /// source code.
+ /// The internal implementation of `readToBuffers`. We template over
+ /// `IsPattern` in order to perform LICM without needing to duplicate
+ /// the source code.
template <typename C, typename V, bool IsPattern>
- bool readToBuffersLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
- C *lvlCoordinates, V *values);
+ bool readToBuffersLoop(const MapRef &map, C *lvlCoordinates, V *values);
/// Reads the MME header of a general sparse matrix of type real.
void readMMEHeader();
@@ -288,96 +269,76 @@ class SparseTensorReader final {
char line[kColWidth];
};
+//===----------------------------------------------------------------------===//
+//
+// Reader class methods.
+//
//===----------------------------------------------------------------------===//
template <typename V>
-SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
- const uint64_t *lvlSizes,
- const uint64_t *dim2lvl) {
+SparseTensorCOO<V> *SparseTensorReader::readCOO(const MapRef &map,
+ const uint64_t *lvlSizes) {
assert(isValid() && "Attempt to readCOO() before readHeader()");
- const uint64_t dimRank = getRank();
- assert(lvlRank == dimRank && "Rank mismatch");
- detail::PermutationRef d2l(dimRank, dim2lvl);
// Prepare a COO object with the number of stored elems as initial capacity.
- auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNSE());
- // Do some manual LICM, to avoid assertions in the for-loop.
- const bool IsPattern = isPattern();
- if (IsPattern)
- readCOOLoop<V, true>(lvlRank, d2l, lvlCOO);
+ auto *coo = new SparseTensorCOO<V>(map.getLvlRank(), lvlSizes, getNSE());
+ // Enter the reading loop.
+ if (isPattern())
+ readCOOLoop<V, true>(map, coo);
else
- readCOOLoop<V, false>(lvlRank, d2l, lvlCOO);
+ readCOOLoop<V, false>(map, coo);
// Close the file and return the COO.
closeFile();
- return lvlCOO;
+ return coo;
}
template <typename V, bool IsPattern>
-void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
- detail::PermutationRef dim2lvl,
- SparseTensorCOO<V> *lvlCOO) {
- const uint64_t dimRank = getRank();
+void SparseTensorReader::readCOOLoop(const MapRef &map,
+ SparseTensorCOO<V> *coo) {
+ const uint64_t dimRank = map.getDimRank();
+ const uint64_t lvlRank = map.getLvlRank();
+ assert(dimRank == getRank());
std::vector<uint64_t> dimCoords(dimRank);
std::vector<uint64_t> lvlCoords(lvlRank);
- for (uint64_t nse = getNSE(), k = 0; k < nse; ++k) {
- // We inline `readElement` here in order to avoid redundant
- // assertions, since they're guaranteed by the call to `isValid()`
- // and the construction of `dimCoords` above.
+ for (uint64_t k = 0, nse = getNSE(); k < nse; k++) {
char *linePtr = readCoords(dimCoords.data());
const V value = detail::readValue<V, IsPattern>(&linePtr);
- dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoords.data());
- // TODO: <https://github.com/llvm/llvm-project/issues/54179>
- lvlCOO->add(lvlCoords, value);
+ map.pushforward(dimCoords.data(), lvlCoords.data());
+ coo->add(lvlCoords, value);
}
}
template <typename C, typename V>
bool SparseTensorReader::readToBuffers(uint64_t lvlRank,
const uint64_t *dim2lvl,
+ const uint64_t *lvl2dim,
C *lvlCoordinates, V *values) {
assert(isValid() && "Attempt to readCOO() before readHeader()");
- // Construct a `PermutationRef` for the `pushforward` below.
- // TODO: This specific implementation does not generalize to arbitrary
- // mappings, but once we functionalize the `dim2lvl` argument we can
- // simply use that function instead.
- const uint64_t dimRank = getRank();
- assert(lvlRank == dimRank && "Rank mismatch");
- detail::PermutationRef d2l(dimRank, dim2lvl);
- // Do some manual LICM, to avoid assertions in the for-loop.
+ MapRef map(getRank(), lvlRank, dim2lvl, lvl2dim);
bool isSorted =
- isPattern()
- ? readToBuffersLoop<C, V, true>(lvlRank, d2l, lvlCoordinates, values)
- : readToBuffersLoop<C, V, false>(lvlRank, d2l, lvlCoordinates,
- values);
-
- // Close the file and return isSorted.
+ isPattern() ? readToBuffersLoop<C, V, true>(map, lvlCoordinates, values)
+ : readToBuffersLoop<C, V, false>(map, lvlCoordinates, values);
closeFile();
return isSorted;
}
template <typename C, typename V, bool IsPattern>
-bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
- detail::PermutationRef dim2lvl,
- C *lvlCoordinates, V *values) {
- const uint64_t dimRank = getRank();
+bool SparseTensorReader::readToBuffersLoop(const MapRef &map, C *lvlCoordinates,
+ V *values) {
+ const uint64_t dimRank = map.getDimRank();
+ const uint64_t lvlRank = map.getLvlRank();
const uint64_t nse = getNSE();
+ assert(dimRank == getRank());
std::vector<C> dimCoords(dimRank);
- // Read the first element with isSorted=false as a way to avoid accessing its
- // previous element.
bool isSorted = false;
char *linePtr;
- // We inline `readElement` here in order to avoid redundant assertions,
- // since they're guaranteed by the call to `isValid()` and the construction
- // of `dimCoords` above.
const auto readNextElement = [&]() {
linePtr = readCoords<C>(dimCoords.data());
- dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoordinates);
+ map.pushforward(dimCoords.data(), lvlCoordinates);
*values = detail::readValue<V, IsPattern>(&linePtr);
if (isSorted) {
- // Note that isSorted was set to false while reading the first element,
+ // Note that isSorted is set to false when reading the first element,
// to guarantee the safeness of using prevLvlCoords.
C *prevLvlCoords = lvlCoordinates - lvlRank;
- // TODO: define a new CoordsLT which is like ElementLT but doesn't have
- // the V parameter, and use it here.
for (uint64_t l = 0; l < lvlRank; ++l) {
if (prevLvlCoords[l] != lvlCoordinates[l]) {
if (prevLvlCoords[l] > lvlCoordinates[l])
@@ -393,7 +354,6 @@ bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
isSorted = true;
for (uint64_t n = 1; n < nse; ++n)
readNextElement();
-
return isSorted;
}
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
new file mode 100644
index 000000000000000..1c155568802e579
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
@@ -0,0 +1,96 @@
+//===- MapRef.h - A dim2lvl/lvl2dim map encoding ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A dim2lvl/lvl2dim map encoding class, with utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
+#define MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
+
+#include <cinttypes>
+
+#include <cassert>
+
+namespace mlir {
+namespace sparse_tensor {
+
+/// A class for capturing the sparse tensor type map with a compact encoding.
+///
+/// Currently, the following situations are supported:
+/// (1) map is an identity
+/// (2) map is a permutation
+/// (3) map has affine ops (restricted set; not yet implemented)
+///
+/// The pushforward/backward operations are fast for (1) and (2) but
+/// incur some obvious overhead for situation (3).
+///
+class MapRef final {
+public:
+ MapRef(uint64_t d, uint64_t l, const uint64_t *d2l, const uint64_t *l2d);
+
+ //
+ // Push forward maps from dimensions to levels.
+ //
+
+ template <typename T> inline void pushforward(const T *in, T *out) const {
+ switch (kind) {
+ case MapKind::kIdentity:
+ for (uint64_t i = 0; i < dimRank; ++i)
+ out[i] = in[i]; // TODO: optimize with in == out ?
+ break;
+ case MapKind::kPermutation:
+ for (uint64_t i = 0; i < dimRank; ++i)
+ out[dim2lvl[i]] = in[i];
+ break;
+ case MapKind::kAffine:
+ assert(0 && "coming soon");
+ break;
+ }
+ }
+
+ //
+ // Push backward maps from levels to dimensions.
+ //
+
+ template <typename T> inline void pushbackward(const T *in, T *out) const {
+ switch (kind) {
+ case MapKind::kIdentity:
+ for (uint64_t i = 0; i < lvlRank; ++i)
+ out[i] = in[i];
+ break;
+ case MapKind::kPermutation:
+ for (uint64_t i = 0; i < lvlRank; ++i)
+ out[lvl2dim[i]] = in[i];
+ break;
+ case MapKind::kAffine:
+ assert(0 && "coming soon");
+ break;
+ }
+ }
+
+ uint64_t getDimRank() const { return dimRank; }
+ uint64_t getLvlRank() const { return lvlRank; }
+
+private:
+ enum class MapKind { kIdentity, kPermutation, kAffine };
+
+ bool isIdentity() const;
+ bool isPermutation() const;
+
+ MapKind kind;
+ const uint64_t dimRank;
+ const uint64_t lvlRank;
+ const uint64_t *const dim2lvl; // non-owning pointer
+ const uint64_t *const lvl2dim; // non-owning pointer
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 28c28c28109c3c7..37ad3c1b042313c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -49,103 +49,6 @@ class SparseTensorEnumeratorBase;
template <typename P, typename C, typename V>
class SparseTensorEnumerator;
-namespace detail {
-
-/// Checks whether the `perm` array is a permutation of `[0 .. size)`.
-inline bool isPermutation(uint64_t size, const uint64_t *perm) {
- assert(perm && "Got nullptr for permutation");
- std::vector<bool> seen(size, false);
- for (uint64_t i = 0; i < size; ++i) {
- const uint64_t j = perm[i];
- if (j >= size || seen[j])
- return false;
- seen[j] = true;
- }
- for (uint64_t i = 0; i < size; ++i)
- if (!seen[i])
- return false;
- return true;
-}
-
-/// Wrapper around `isPermutation` to ensure consistent error messages.
-inline void assertIsPermutation(uint64_t size, const uint64_t *perm) {
-#ifndef NDEBUG
- if (!isPermutation(size, perm))
- MLIR_SPARSETENSOR_FATAL("Not a permutation of [0..%" PRIu64 ")\n", size);
-#endif
-}
-
-/// A class for capturing the knowledge that `isPermutation` is true.
-class PermutationRef final {
-public:
- /// Asserts `isPermutation` and returns the witness to that being true.
- explicit PermutationRef(uint64_t size, const uint64_t *perm)
- : permSize(size), perm(perm) {
- assertIsPermutation(size, perm);
- }
-
- uint64_t size() const { return permSize; }
-
- const uint64_t *data() const { return perm; }
-
- const uint64_t &operator[](uint64_t i) const {
- assert(i < permSize && "index is out of bounds");
- return perm[i];
- }
-
- /// Constructs a pushforward array of values. This method is the inverse
- /// of `permute` in the sense that for all `p` and `xs` we have:
- /// * `p.permute(p.pushforward(xs)) == xs`
- /// * `p.pushforward(p.permute(xs)) == xs`
- template <typename T>
- inline std::vector<T> pushforward(const std::vector<T> &values) const {
- return pushforward(values.size(), values.data());
- }
-
- template <typename T>
- inline std::vector<T> pushforward(uint64_t size, const T *values) const {
- std::vector<T> out(permSize);
- pushforward(size, values, out.data());
- return out;
- }
-
- template <typename T>
- inline void pushforward(uint64_t size, const T *values, T *out) const {
- assert(size == permSize && "size mismatch");
- for (uint64_t i = 0; i < permSize; ++i)
- out[perm[i]] = values[i];
- }
-
- /// Constructs a permuted array of values. This method is the inverse
- /// of `pushforward` in the sense that for all `p` and `xs` we have:
- /// * `p.permute(p.pushforward(xs)) == xs`
- /// * `p.pushforward(p.permute(xs)) == xs`
- template <typename T>
- inline std::vector<T> permute(const std::vector<T> &values) const {
- return permute(values.size(), values.data());
- }
-
- template <typename T>
- inline std::vector<T> permute(uint64_t size, const T *values) const {
- std::vector<T> out(permSize);
- permute(size, values, out.data());
- return out;
- }
-
- template <typename T>
- inline void permute(uint64_t size, const T *values, T *out) const {
- assert(size == permSize && "size mismatch");
- for (uint64_t i = 0; i < permSize; ++i)
- out[i] = values[perm[i]];
- }
-
-private:
- const uint64_t permSize;
- const uint64_t *const perm; // non-owning pointer.
-};
-
-} // namespace detail
-
/// Abstract base class for `SparseTensorStorage<P,C,V>`. This class
/// takes responsibility for all the `<P,C,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
@@ -263,7 +166,7 @@ class SparseTensorStorageBase {
bool isUniqueLvl(uint64_t l) const { return isUniqueDLT(getLvlType(l)); }
/// Allocates a new enumerator. Callers must make sure to delete
- /// the enumerator when they're done with it. The first argument
+ /// the enumerator when they're done with it. The first argument
/// is the out-parameter for storing the newly allocated enumerator;
/// all other arguments are passed along to the `SparseTensorEnumerator`
/// ctor and must satisfy the preconditions/assertions thereof.
@@ -326,6 +229,7 @@ class SparseTensorStorageBase {
const std::vector<uint64_t> lvl2dim;
};
+
/// A memory-resident sparse tensor using a storage scheme based on
/// per-level sparse/dense annotations. This data structure provides
/// a bufferized form of a sparse tensor type. In contrast to generating
@@ -401,7 +305,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
const intptr_t *lvlBufs);
- /// Allocates a new empty sparse tensor. The preconditions/assertions
+ /// Allocates a new empty sparse tensor. The preconditions/assertions
/// are as per the `SparseTensorStorageBase` ctor; which is to say,
/// the `dimSizes` and `lvlSizes` must both be "sizes" not "shapes",
/// since there's nowhere to reconstruct dynamic sizes from.
@@ -577,6 +481,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
SparseTensorCOO<V> *toCOO(uint64_t trgRank, const uint64_t *trgSizes,
uint64_t srcRank, const uint64_t *src2trg) const {
// We inline `newEnumerator` to avoid virtual dispatch and allocation.
+ // TODO: use MapRef here too for the translation
SparseTensorEnumerator<P, C, V> enumerator(*this, trgRank, trgSizes,
srcRank, src2trg);
auto *coo = new SparseTensorCOO<V>(trgRank, trgSizes, values.size());
@@ -733,7 +638,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
}
}
- /// Continues a single insertion path, outer to inner. The first
+ /// Continues a single insertion path, outer to inner. The first
/// argument is the level-coordinates for the value being inserted.
void insPath(const uint64_t *lvlCoords, uint64_t diffLvl, uint64_t full,
V val) {
@@ -875,7 +780,8 @@ class SparseTensorEnumeratorBase {
//===----------------------------------------------------------------------===//
template <typename P, typename C, typename V>
-class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
+class SparseTensorEnumerator final
+ : public SparseTensorEnumeratorBase<V> {
using Base = SparseTensorEnumeratorBase<V>;
using StorageImpl = SparseTensorStorage<P, C, V>;
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index 8f320f04f23fc84..861b7eff65115b6 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -143,14 +143,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
StridedMemRefType<index_type, 1> *out, void *p);
-/// Returns the next element for the sparse tensor being read.
-#define DECL_GETNEXT(VNAME, V) \
- MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
- void *p, StridedMemRefType<index_type, 1> *dimCoordsRef, \
- StridedMemRefType<V, 0> *vref);
-MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETNEXT)
-#undef DECL_GETNEXT
-
/// Reads the sparse tensor, stores the coordinates and values to the given
/// memrefs. Returns a boolean to indicate whether the COO elements are sorted.
#define DECL_GETNEXT(VNAME, V, CNAME, C) \
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index ce77d7a519877d6..ffb1a550957edb8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -729,3 +729,92 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
return constantIndex(builder, loc, *stride);
return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}
+
+void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
+ SparseTensorType stt,
+ SmallVectorImpl<Value> &out) {
+ out.clear();
+ out.reserve(stt.getDimRank());
+ for (const DynSize sh : stt.getDimShape()) {
+ const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
+ out.push_back(constantIndex(builder, loc, s));
+ }
+}
+
+Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value tensor,
+ /*out*/ SmallVectorImpl<Value> &dimShapesValues,
+ /*out*/ Value &dimSizesBuffer) {
+ // Construct the dimShapes buffer. The buffer contains the static size
+ // per dimension, or otherwise a zero for a dynamic size.
+ fillDimShape(builder, loc, stt, dimShapesValues);
+ Value dimShapesBuffer = allocaBuffer(builder, loc, dimShapesValues);
+ // Create the `CheckedSparseTensorReader`. This reader performs a
+ // consistency check on the static sizes, but accepts any size
+ // of each dimension with a dynamic size.
+ Type opaqueTp = getOpaquePointerType(builder);
+ Type eltTp = stt.getElementType();
+ Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
+ Value reader =
+ createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
+ {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
+ .getResult(0);
+ // For static shapes, the shape buffer can be used right away. For dynamic
+ // shapes, use the information from the reader to construct a buffer that
+ // supplies the actual size for each dynamic dimension.
+ dimSizesBuffer = dimShapesBuffer;
+ if (stt.hasDynamicDimShape()) {
+ Type indexTp = builder.getIndexType();
+ auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+ dimSizesBuffer =
+ createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
+ reader, EmitCInterface::On)
+ .getResult(0);
+ }
+ return reader;
+}
+
+Value sparse_tensor::genReaderBuffers(OpBuilder &builder, Location loc,
+ SparseTensorType stt,
+ SmallVectorImpl<Value> &dimShapesValues,
+ Value dimSizesBuffer,
+ /*out*/ Value &dim2lvlBuffer,
+ /*out*/ Value &lvl2dimBuffer) {
+ const Dimension dimRank = stt.getDimRank();
+ const Level lvlRank = stt.getLvlRank();
+ // For an identity mapping, the dim2lvl and lvl2dim mappings are
+ // identical, as are dimSizes and lvlSizes, so buffers are reused
+ // as much as possible.
+ if (stt.isIdentity()) {
+ assert(dimRank == lvlRank);
+ SmallVector<Value> iotaValues;
+ iotaValues.reserve(lvlRank);
+ for (Level l = 0; l < lvlRank; l++)
+ iotaValues.push_back(constantIndex(builder, loc, l));
+ dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
+ return dimSizesBuffer;
+ }
+ // Otherwise, some code needs to be generated to set up the buffers.
+ // TODO: use the lvl2dim once available and deal with non-permutations!
+ const auto dimToLvl = stt.getDimToLvl();
+ assert(dimToLvl.isPermutation());
+ SmallVector<Value> dim2lvlValues(dimRank);
+ SmallVector<Value> lvl2dimValues(lvlRank);
+ SmallVector<Value> lvlSizesValues(lvlRank);
+ for (Level l = 0; l < lvlRank; l++) {
+ // The `d`th source variable occurs in the `l`th result position.
+ Dimension d = dimToLvl.getDimPosition(l);
+ Value lvl = constantIndex(builder, loc, l);
+ Value dim = constantIndex(builder, loc, d);
+ dim2lvlValues[d] = lvl;
+ lvl2dimValues[l] = dim;
+ if (stt.isDynamicDim(d))
+ lvlSizesValues[l] =
+ builder.create<memref::LoadOp>(loc, dimSizesBuffer, dim);
+ else
+ lvlSizesValues[l] = dimShapesValues[d];
+ }
+ dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
+ lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
+ return allocaBuffer(builder, loc, lvlSizesValues);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 8145446751b9938..08ea019d8224a73 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Builders.h"
@@ -341,6 +342,23 @@ Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
Dimension dim);
+/// Populates the array with the dimension-shape of the given
+/// `SparseTensorType`, where dynamic sizes are represented by zero.
+void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt,
+ SmallVectorImpl<Value> &out);
+
+/// Generates code that opens a reader and sets the dimension sizes.
+Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt,
+ Value tensor,
+ /*out*/ SmallVectorImpl<Value> &dimShapeValues,
+ /*out*/ Value &dimSizesBuffer);
+
+/// Generates code to set up the buffer parameters for a reader.
+Value genReaderBuffers(OpBuilder &builder, Location loc, SparseTensorType stt,
+ SmallVectorImpl<Value> &dimShapeValues,
+ Value dimSizesBuffer, /*out*/ Value &dim2lvlBuffer,
+ /*out*/ Value &lvl2dimBuffer);
+
//===----------------------------------------------------------------------===//
// Inlined constant generators.
//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7c362c086623b42..2c03f0a6020e6a8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1428,7 +1428,7 @@ struct SparseDisassembleOpConverter
}
};
-struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
+struct SparseNewConverter : public OpConversionPattern<NewOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(NewOp op, OpAdaptor adaptor,
@@ -1440,7 +1440,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
return failure();
- // Implement the NewOp(filename) as follows:
+ // Implement as follows:
// %reader = @createCheckedSparseTensorReader(%filename)
// %nse = @getSparseTensorNSE(%reader)
// %coo = bufferization.alloc_tensor an ordered COO with
@@ -1451,74 +1451,39 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
// if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
// update storage specifier
// @delSparseTensorReader(%reader)
+ SmallVector<Value> dimShapesValues;
+ Value dimSizesBuffer;
+ Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
+ dimShapesValues, dimSizesBuffer);
- // Allocate `SparseTensorReader` and perform all initial setup that
- // does not depend on lvlSizes (nor dimToLvl, lvlToDim, etc).
- const Type opaqueTp = getOpaquePointerType(rewriter);
- const Value fileName = op.getSource();
- SmallVector<Value> dimShapeValues;
- for (const DynSize sh : dstTp.getDimShape()) {
- const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
- dimShapeValues.push_back(constantIndex(rewriter, loc, s));
- }
- Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
- Value valTp =
- constantPrimaryTypeEncoding(rewriter, loc, dstTp.getElementType());
- Value reader =
- createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
- opaqueTp, {fileName, dimShapeBuffer, valTp},
- EmitCInterface::On)
- .getResult(0);
+ // Get the number of stored entries.
const Type indexTp = rewriter.getIndexType();
- const Dimension dimRank = dstTp.getDimRank();
- const Level lvlRank = dstTp.getLvlRank();
+ Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
+ {indexTp}, {reader}, EmitCInterface::Off)
+ .getResult(0);
- // If the result tensor has dynamic dimensions, get the dynamic sizes from
- // the sparse tensor reader.
+ // Construct allocation for each field.
SmallVector<Value> dynSizes;
if (dstTp.hasDynamicDimShape()) {
- auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
- Value dimSizesBuffer =
- createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
- reader, EmitCInterface::On)
- .getResult(0);
for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
if (ShapedType::isDynamic(d.value()))
dynSizes.push_back(rewriter.create<memref::LoadOp>(
loc, dimSizesBuffer, constantIndex(rewriter, loc, d.index())));
}
-
- // Get the number of stored entries.
- Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
- {indexTp}, {reader}, EmitCInterface::Off)
- .getResult(0);
- // Construct allocation for each field.
SmallVector<Value> fields;
createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
fields, nse);
MutSparseTensorDescriptor desc(dstTp, fields);
- // Construct the `dimToLvl` buffer for handing off to the runtime library.
- SmallVector<Value> dimToLvlValues(dimRank);
- if (!dstTp.isIdentity()) {
- const auto dimToLvl = dstTp.getDimToLvl();
- assert(dimToLvl.isPermutation() && "Got non-permutation");
- for (Level l = 0; l < lvlRank; l++) {
- const Dimension d = dimToLvl.getDimPosition(l);
- dimToLvlValues[d] = constantIndex(rewriter, loc, l);
- }
- } else {
- // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
- // when `isIdentity`; so no need to re-assert it here.
- for (Dimension d = 0; d < dimRank; d++)
- dimToLvlValues[d] = constantIndex(rewriter, loc, d);
- }
- Value dimToLvl = allocaBuffer(rewriter, loc, dimToLvlValues);
+ // Now construct the dim2lvl and lvl2dim buffers.
+ Value dim2lvlBuffer;
+ Value lvl2dimBuffer;
+ genReaderBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
+ dim2lvlBuffer, lvl2dimBuffer);
// Read the COO tensor data.
Value xs = desc.getAOSMemRef();
Value ys = desc.getValMemRef();
-
const Type boolTp = rewriter.getIntegerType(1);
const Type elemTp = dstTp.getElementType();
const Type crdTp = dstTp.getCrdType();
@@ -1527,11 +1492,13 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
primaryTypeFunctionSuffix(elemTp)};
Value isSorted =
createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
- {reader, dimToLvl, xs, ys}, EmitCInterface::On)
+ {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
+ EmitCInterface::On)
.getResult(0);
// If the destination tensor is a sorted COO, we need to sort the COO tensor
// data if the input elements aren't sorted yet.
+ const Level lvlRank = dstTp.getLvlRank();
if (dstTp.isOrderedLvl(lvlRank - 1)) {
Value kFalse = constantI1(rewriter, loc, false);
Value notSorted = rewriter.create<arith::CmpIOp>(
@@ -1593,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
StorageSpecifierKind::DimStride>,
SparseToPositionsConverter, SparseToCoordinatesConverter,
SparseToCoordinatesBufferConverter, SparseToValuesConverter,
- SparseConvertConverter, SparseNewOpConverter,
+ SparseConvertConverter, SparseNewConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
patterns.add<SparseTensorDeallocConverter>(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a3361c2cd48c6dd..eb0c5160e8d6193 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,8 +46,7 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
return std::nullopt;
}
-/// Replaces the `op` with a `CallOp` to the function reference returned
-/// by `getFunc()`.
+/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
StringRef name, TypeRange resultType,
ValueRange operands,
@@ -141,27 +140,6 @@ static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
return out;
}
-/// Populates the array with the dimension-shape of the given
-/// `SparseTensorType`, where dynamic sizes are represented by zero.
-static void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt,
- SmallVectorImpl<Value> &out) {
- out.clear();
- out.reserve(stt.getDimRank());
- for (const DynSize sh : stt.getDimShape()) {
- const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
- out.push_back(constantIndex(builder, loc, s));
- }
-}
-
-/// Returns an array with the dimension-shape of the given `SparseTensorType`,
-/// where dynamic sizes are represented by zero.
-static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
- SparseTensorType stt) {
- SmallVector<Value> out;
- fillDimShape(builder, loc, stt, out);
- return out;
-}
-
/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
@@ -503,84 +481,27 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
const auto stt = getSparseTensorType(op);
if (!stt.hasEncoding())
return failure();
- const Dimension dimRank = stt.getDimRank();
- const Level lvlRank = stt.getLvlRank();
- // Construct the dimShape.
- SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stt);
- Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
- // Allocate `SparseTensorReader` and perform all initial setup that
- // does not depend on lvlSizes (nor dimToLvl, lvlToDim, etc).
- Type opaqueTp = getOpaquePointerType(rewriter);
- Value valTp =
- constantPrimaryTypeEncoding(rewriter, loc, stt.getElementType());
- Value reader =
- createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
- opaqueTp,
- {adaptor.getOperands()[0], dimShapeBuffer, valTp},
- EmitCInterface::On)
- .getResult(0);
- // Construct the lvlSizes. If the dimShape is static, then it's
- // identical to dimSizes: so we can compute lvlSizes entirely at
- // compile-time. If dimShape is dynamic, then we'll need to generate
- // code for computing lvlSizes from the `reader`'s actual dimSizes.
- //
- // TODO: For now we're still assuming `dimToLvl` is a permutation.
- // But since we're computing lvlSizes here (rather than in the runtime),
- // we can easily generalize that simply by adjusting this code.
- //
- // FIXME: reduce redundancy vs `NewCallParams::genBuffers`.
+ // Construct the reader opening method calls.
+ SmallVector<Value> dimShapesValues;
Value dimSizesBuffer;
- if (stt.hasDynamicDimShape()) {
- Type indexTp = rewriter.getIndexType();
- auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
- dimSizesBuffer =
- createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
- reader, EmitCInterface::On)
- .getResult(0);
- }
- Value lvlSizesBuffer;
- Value lvlToDimBuffer;
- Value dimToLvlBuffer;
- if (!stt.isIdentity()) {
- const auto dimToLvl = stt.getDimToLvl();
- assert(dimToLvl.isPermutation() && "Got non-permutation");
- // We preinitialize `dimToLvlValues` since we need random-access writing.
- // And we preinitialize the others for stylistic consistency.
- SmallVector<Value> lvlSizeValues(lvlRank);
- SmallVector<Value> lvlToDimValues(lvlRank);
- SmallVector<Value> dimToLvlValues(dimRank);
- for (Level l = 0; l < lvlRank; l++) {
- // The `d`th source variable occurs in the `l`th result position.
- Dimension d = dimToLvl.getDimPosition(l);
- Value lvl = constantIndex(rewriter, loc, l);
- Value dim = constantIndex(rewriter, loc, d);
- dimToLvlValues[d] = lvl;
- lvlToDimValues[l] = dim;
- lvlSizeValues[l] =
- stt.isDynamicDim(d)
- ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
- : dimShapeValues[d];
- }
- lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues);
- lvlToDimBuffer = allocaBuffer(rewriter, loc, lvlToDimValues);
- dimToLvlBuffer = allocaBuffer(rewriter, loc, dimToLvlValues);
- } else {
- // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
- // when `isIdentity`; so no need to re-assert it here.
- SmallVector<Value> iotaValues;
- iotaValues.reserve(lvlRank);
- for (Level l = 0; l < lvlRank; l++)
- iotaValues.push_back(constantIndex(rewriter, loc, l));
- lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
- dimToLvlBuffer = lvlToDimBuffer = allocaBuffer(rewriter, loc, iotaValues);
- }
+ Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
+ dimShapesValues, dimSizesBuffer);
+ // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
+ Value dim2lvlBuffer;
+ Value lvl2dimBuffer;
+ Value lvlSizesBuffer =
+ genReaderBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
+ dim2lvlBuffer, lvl2dimBuffer);
// Use the `reader` to parse the file.
+ Type opaqueTp = getOpaquePointerType(rewriter);
+ Type eltTp = stt.getElementType();
+ Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
SmallVector<Value, 8> params{
reader,
lvlSizesBuffer,
genLvlTypesBuffer(rewriter, loc, stt),
- lvlToDimBuffer,
- dimToLvlBuffer,
+ dim2lvlBuffer,
+ lvl2dimBuffer,
constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
valTp};
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
index 085d83634a702a8..c48af17b2d94bb7 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
@@ -7,6 +7,7 @@
# that is reserved/intended for shared libraries only.
add_mlir_library(MLIRSparseTensorRuntime
File.cpp
+ MapRef.cpp
NNZ.cpp
Storage.cpp
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp b/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp
new file mode 100644
index 000000000000000..ed458afeae746bc
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp
@@ -0,0 +1,57 @@
+//===- MapRef.cpp - A dim2lvl/lvl2dim map reference wrapper ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
+
+#include <cassert>
+#include <vector>
+
+mlir::sparse_tensor::MapRef::MapRef(uint64_t d, uint64_t l, const uint64_t *d2l,
+                                    const uint64_t *l2d)
+    : dimRank(d), lvlRank(l), dim2lvl(d2l), lvl2dim(l2d) {
+  assert(d2l && l2d);
+  // Determine the kind of mapping. For the identity and permutation kinds,
+  // also assert that lvl2dim is the inverse of dim2lvl.
+  if (isIdentity()) {
+    kind = MapKind::kIdentity;
+    for (uint64_t i = 0; i < dimRank; i++)
+      assert(lvl2dim[i] == i);
+  } else if (isPermutation()) {
+    kind = MapKind::kPermutation;
+    for (uint64_t i = 0; i < dimRank; i++)
+      assert(lvl2dim[dim2lvl[i]] == i);
+  } else {
+    kind = MapKind::kAffine;
+  }
+}
+
+// Returns true iff the ranks agree and dim2lvl maps every i to itself.
+bool mlir::sparse_tensor::MapRef::isIdentity() const {
+  if (dimRank != lvlRank)
+    return false;
+  for (uint64_t i = 0; i < dimRank; i++) {
+    if (dim2lvl[i] != i)
+      return false;
+  }
+  return true;
+}
+
+// Returns true iff the ranks agree and dim2lvl is a bijection on [0, dimRank),
+// i.e. every entry is in range and no entry repeats.
+bool mlir::sparse_tensor::MapRef::isPermutation() const {
+  if (dimRank != lvlRank)
+    return false;
+  std::vector<bool> seen(dimRank, false);
+  for (uint64_t i = 0; i < dimRank; i++) {
+    const uint64_t j = dim2lvl[i];
+    if (j >= dimRank || seen[j])
+      return false;
+    seen[j] = true;
+  }
+  return true;
+}
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 82cb6d3aeefa35f..5b910716c0f9e59 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -226,11 +226,7 @@ extern "C" {
static_assert(std::is_same<index_type, uint64_t>::value,
"Expected index_type == uint64_t");
-// TODO: this swiss-army-knife should be split up into separate functions
-// for each action, since the various actions don't agree on (1) whether
-// the first two arguments are "sizes" vs "shapes", (2) whether the "lvl"
-// arguments are actually storage-levels vs target tensor-dimensions,
-// (3) whether all the arguments are actually used/required.
+// The Swiss-army-knife for sparse tensor creation.
void *_mlir_ciface_newSparseTensor( // NOLINT
StridedMemRefType<index_type, 1> *dimSizesRef,
StridedMemRefType<index_type, 1> *lvlSizesRef,
@@ -241,18 +237,18 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
ASSERT_NO_STRIDE(dimSizesRef);
ASSERT_NO_STRIDE(lvlSizesRef);
ASSERT_NO_STRIDE(lvlTypesRef);
- ASSERT_NO_STRIDE(lvl2dimRef);
ASSERT_NO_STRIDE(dim2lvlRef);
+ ASSERT_NO_STRIDE(lvl2dimRef);
const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
- ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
+ ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
- const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
+ const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
// This is safe because of the static_assert above.
@@ -403,10 +399,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES
#undef IMPL_GETOVERHEAD
-// TODO: while this API design will work for arbitrary dim2lvl mappings,
-// we should probably move the `dimCoords`-to-`lvlCoords` computation into
-// codegen (since that could enable optimizations to remove the intermediate
-// memref).
+// TODO: use MapRef here for translation of coordinates
#define IMPL_ADDELT(VNAME, V) \
void *_mlir_ciface_addElt##VNAME( \
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
@@ -506,44 +499,33 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
aliasIntoMemref(reader.getRank(), dimSizes, *out);
}
-#define IMPL_GETNEXT(VNAME, V) \
- void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
- void *p, StridedMemRefType<index_type, 1> *dimCoordsRef, \
- StridedMemRefType<V, 0> *vref) { \
- assert(p &&vref); \
- auto &reader = *static_cast<SparseTensorReader *>(p); \
- ASSERT_NO_STRIDE(dimCoordsRef); \
- const uint64_t dimRank = MEMREF_GET_USIZE(dimCoordsRef); \
- index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \
- V *value = MEMREF_GET_PAYLOAD(vref); \
- *value = reader.readElement<V>(dimRank, dimCoords); \
- }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
-#undef IMPL_GETNEXT
-
#define IMPL_GETNEXT(VNAME, V, CNAME, C) \
bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME( \
void *p, StridedMemRefType<index_type, 1> *dim2lvlRef, \
+ StridedMemRefType<index_type, 1> *lvl2dimRef, \
StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) { \
assert(p); \
auto &reader = *static_cast<SparseTensorReader *>(p); \
+ ASSERT_NO_STRIDE(dim2lvlRef); \
+ ASSERT_NO_STRIDE(lvl2dimRef); \
ASSERT_NO_STRIDE(cref); \
ASSERT_NO_STRIDE(vref); \
- ASSERT_NO_STRIDE(dim2lvlRef); \
+ const uint64_t dimRank = reader.getRank(); \
+ const uint64_t lvlRank = MEMREF_GET_USIZE(lvl2dimRef); \
const uint64_t cSize = MEMREF_GET_USIZE(cref); \
const uint64_t vSize = MEMREF_GET_USIZE(vref); \
- const uint64_t lvlRank = reader.getRank(); \
- assert(vSize *lvlRank <= cSize); \
+ ASSERT_USIZE_EQ(dim2lvlRef, dimRank); \
+ assert(cSize >= lvlRank * vSize); \
assert(vSize >= reader.getNSE() && "Not enough space in buffers"); \
- ASSERT_USIZE_EQ(dim2lvlRef, lvlRank); \
+ (void)dimRank; \
(void)cSize; \
(void)vSize; \
- (void)lvlRank; \
+ index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
+ index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); \
C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref); \
V *values = MEMREF_GET_PAYLOAD(vref); \
- index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
- return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvlCoordinates, \
- values); \
+ return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim, \
+ lvlCoordinates, values); \
}
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
#undef IMPL_GETNEXT
@@ -551,8 +533,8 @@ MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
void *_mlir_ciface_newSparseTensorFromReader(
void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType posTp,
+ StridedMemRefType<index_type, 1> *dim2lvlRef,
+ StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
OverheadType crdTp, PrimaryType valTp) {
assert(p);
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
@@ -568,13 +550,13 @@ void *_mlir_ciface_newSparseTensorFromReader(
(void)dimRank;
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
- const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
+ const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
#define CASE(p, c, v, P, C, V) \
if (posTp == OverheadType::p && crdTp == OverheadType::c && \
valTp == PrimaryType::v) \
return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
- lvlRank, lvlSizes, lvlTypes, lvl2dim, dim2lvl));
+ lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
// This is safe because of the static_assert above.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 4ac768c21aff8fc..ff523e70bfc914a 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -423,7 +423,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -471,7 +471,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A12:.*]] = arith.constant 1 : index
// CHECK: %[[A13:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -507,7 +507,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
return %1 : tensor<8x8xf64, #CSR>
}
-// CHECK-LABEL: func.func private @_insert_dense_compressed_nonordered_8_8_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_dense_compressed_no_8_8_f64_0_0(
// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -533,7 +533,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
// CHECK: %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
// CHECK: %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
-// CHECK: %[[A21:.*]]:4 = func.call @_insert_dense_compressed_nonordered_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A21:.*]]:4 = func.call @_insert_dense_compressed_no_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
// CHECK: memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
// CHECK: scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
@@ -611,7 +611,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
return %1 : tensor<128xf64, #SparseVector>
}
-// CHECK-LABEL: func.func private @_insert_compressed_nonunique_singleton_5_6_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_compressed_nu_singleton_5_6_f64_0_0(
// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -627,7 +627,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME: %[[A4:.*4]]: index,
// CHECK-SAME: %[[A5:.*5]]: f64)
-// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
+// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nu_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
@@ -665,90 +665,94 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
// CHECK-LABEL: func.func @sparse_new_coo(
// CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>) {
-// CHECK-DAG: %[[A1:.*]] = arith.constant false
-// CHECK-DAG: %[[A2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[A3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[A4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
-// CHECK: %[[D0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[D1:.*]] = memref.cast %[[D0]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[A3]], %[[D0]][%[[A3]]] : memref<2xindex
-// CHECK: memref.store %[[A3]], %[[D0]][%[[A2]]] : memref<2xindex>
-// CHECK: %[[A5:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[D1]], %[[C2]])
-// CHECK: %[[D2:.*]] = call @getSparseTensorReaderDimSizes(%0) : (!llvm.ptr<i8>) -> memref<?xindex>
-// CHECK: %[[A8:.*]] = memref.load %[[D2]]{{\[}}%[[A3]]] : memref<?xindex>
-// CHECK: %[[A9:.*]] = memref.load %[[D2]]{{\[}}%[[A2]]] : memref<?xindex>
-// CHECK: %[[A10:.*]] = call @getSparseTensorReaderNSE(%[[A5]])
-// CHECK: %[[A11:.*]] = arith.muli %[[A10]], %[[A4]] : index
-// CHECK: %[[A12:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: %[[A13:.*]] = memref.cast %[[A12]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[A14:.*]] = memref.alloc(%[[A11]]) : memref<?xindex>
-// CHECK: %[[A15:.*]] = memref.alloc(%[[A10]]) : memref<?xf32>
-// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A16]] lvl_sz at 0 with %[[A8]]
-// CHECK: %[[A19:.*]] = sparse_tensor.storage_specifier.get %[[A18]] pos_mem_sz at 0
-// CHECK: %[[A21:.*]], %[[A22:.*]] = sparse_tensor.push_back %[[A19]], %[[A13]], %[[A3]]
-// CHECK: %[[A24:.*]] = sparse_tensor.storage_specifier.set %[[A18]] pos_mem_sz at 0 with %[[A22]]
-// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A24]] lvl_sz at 1 with %[[A9]]
-// CHECK: %[[A27:.*]], %[[A28:.*]] = sparse_tensor.push_back %[[A22]], %[[A21]], %[[A3]], %[[A2]]
-// CHECK: %[[A30:.*]] = sparse_tensor.storage_specifier.set %[[A26]] pos_mem_sz at 0 with %[[A28]]
-// CHECK: %[[A31:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[A32:.*]] = memref.cast %[[A31]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[A3]], %[[A31]]{{\[}}%[[A3]]] : memref<2xindex>
-// CHECK: memref.store %[[A2]], %[[A31]]{{\[}}%[[A2]]] : memref<2xindex>
-// CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
-// CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
-// CHECK: scf.if %[[A34]] {
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
-// CHECK: }
-// CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
-// CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]] crd_mem_sz at 0 with %[[A11]]
-// CHECK: %[[A38:.*]] = sparse_tensor.storage_specifier.set %[[A36]] val_mem_sz with %[[A10]]
-// CHECK: call @delSparseTensorReader(%[[A5]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[A27]], %[[A14]], %[[A15]], %[[A38]]
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_6:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[VAL_4]], %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_4]], %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK: %[[VAL_8:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[VAL_7]], %[[VAL_2]]) : (!llvm.ptr<i8>, memref<?xindex>, i32) -> !llvm.ptr<i8>
+// CHECK: %[[VAL_9:.*]] = call @getSparseTensorReaderDimSizes(%[[VAL_8]]) : (!llvm.ptr<i8>) -> memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = call @getSparseTensorReaderNSE(%[[VAL_8]]) : (!llvm.ptr<i8>) -> index
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_10]], %[[VAL_5]] : index
+// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.alloc(%[[VAL_13]]) : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.alloc(%[[VAL_10]]) : memref<?xf32>
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.init
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 0 with %[[VAL_11]]
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.get %[[VAL_19]] pos_mem_sz at 0
+// CHECK: %[[VAL_21:.*]], %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_20]], %[[VAL_15]], %[[VAL_4]]
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] pos_mem_sz at 0 with %[[VAL_22]]
+// CHECK: %[[VAL_24:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]] lvl_sz at 1 with %[[VAL_12]]
+// CHECK: %[[VAL_25:.*]], %[[VAL_26:.*]] = sparse_tensor.push_back %[[VAL_22]], %[[VAL_21]], %[[VAL_4]], %[[VAL_3]]
+// CHECK: %[[VAL_27:.*]] = sparse_tensor.storage_specifier.set %[[VAL_24]] pos_mem_sz at 0 with %[[VAL_26]]
+// CHECK: %[[VAL_28:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[VAL_4]], %[[VAL_28]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_28]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK: %[[VAL_30:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[VAL_8]], %[[VAL_29]], %[[VAL_29]], %[[VAL_16]], %[[VAL_17]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) -> i1
+// CHECK: %[[VAL_31:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_1]] : i1
+// CHECK: scf.if %[[VAL_31]] {
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_10]], %[[VAL_16]] jointly %[[VAL_17]]
+// CHECK: }
+// CHECK: memref.store %[[VAL_10]], %[[VAL_25]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]] crd_mem_sz at 0 with %[[VAL_13]]
+// CHECK: %[[VAL_33:.*]] = sparse_tensor.storage_specifier.set %[[VAL_32]] val_mem_sz with %[[VAL_10]]
+// CHECK: call @delSparseTensorReader(%[[VAL_8]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: return %[[VAL_25]], %[[VAL_16]], %[[VAL_17]], %[[VAL_33]]
func.func @sparse_new_coo(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #Coo> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #Coo>
return %0 : tensor<?x?xf32, #Coo>
}
// CHECK-LABEL: func.func @sparse_new_coo_permute_no(
-// CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>) {
-// CHECK-DAG: %[[A1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[A2:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[A3:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
-// CHECK: %[[D0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[D1:.*]] = memref.cast %[[D0]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[A2]], %[[D0]][%[[A2]]] : memref<2xindex
-// CHECK: memref.store %[[A2]], %[[D0]][%[[A1]]] : memref<2xindex>
-// CHECK: %[[A4:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[D1]], %[[C2]])
-// CHECK: %[[D2:.*]] = call @getSparseTensorReaderDimSizes(%0) : (!llvm.ptr<i8>) -> memref<?xindex>
-// CHECK: %[[A7:.*]] = memref.load %[[D2]]{{\[}}%[[A2]]] : memref<?xindex>
-// CHECK: %[[A8:.*]] = memref.load %[[D2]]{{\[}}%[[A1]]] : memref<?xindex>
-// CHECK: %[[A9:.*]] = call @getSparseTensorReaderNSE(%[[A4]])
-// CHECK: %[[A10:.*]] = arith.muli %[[A9]], %[[A3]] : index
-// CHECK: %[[A11:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: %[[A12:.*]] = memref.cast %[[A11]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[A13:.*]] = memref.alloc(%[[A10]]) : memref<?xindex>
-// CHECK: %[[A14:.*]] = memref.alloc(%[[A9]]) : memref<?xf32>
-// CHECK: %[[A15:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[A17:.*]] = sparse_tensor.storage_specifier.set %[[A15]] lvl_sz at 0 with %[[A8]]
-// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.get %[[A17]] pos_mem_sz at 0
-// CHECK: %[[A20:.*]], %[[A21:.*]] = sparse_tensor.push_back %[[A18]], %[[A12]], %[[A2]]
-// CHECK: %[[A23:.*]] = sparse_tensor.storage_specifier.set %[[A17]] pos_mem_sz at 0 with %[[A21]]
-// CHECK: %[[A25:.*]] = sparse_tensor.storage_specifier.set %[[A23]] lvl_sz at 1 with %[[A7]]
-// CHECK: %[[A26:.*]], %[[A27:.*]] = sparse_tensor.push_back %[[A21]], %[[A20]], %[[A2]], %[[A1]]
-// CHECK: %[[A29:.*]] = sparse_tensor.storage_specifier.set %[[A25]] pos_mem_sz at 0 with %[[A27]]
-// CHECK: %[[A30:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[A31:.*]] = memref.cast %[[A30]] : memref<2xindex> to memref<?xindex>
-// CHECK: memref.store %[[A1]], %[[A30]]{{\[}}%[[A2]]] : memref<2xindex>
-// CHECK: memref.store %[[A2]], %[[A30]]{{\[}}%[[A1]]] : memref<2xindex>
-// CHECK: %[[A32:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A4]], %[[A31]], %[[A13]], %[[A14]])
-// CHECK: memref.store %[[A9]], %[[A26]]{{\[}}%[[A1]]] : memref<?xindex>
-// CHECK: %[[A34:.*]] = sparse_tensor.storage_specifier.set %[[A29]] crd_mem_sz at 0 with %[[A10]]
-// CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A34]] val_mem_sz with %[[A9]]
-// CHECK: call @delSparseTensorReader(%[[A4]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: return %[[A26]], %[[A13]], %[[A14]], %[[A36]]
+// CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_5:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_6:.*]] = memref.cast %[[VAL_5]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<2xindex>
+// CHECK: %[[VAL_7:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[VAL_6]], %[[VAL_1]]) : (!llvm.ptr<i8>, memref<?xindex>, i32) -> !llvm.ptr<i8>
+// CHECK: %[[VAL_8:.*]] = call @getSparseTensorReaderDimSizes(%[[VAL_7]]) : (!llvm.ptr<i8>) -> memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = call @getSparseTensorReaderNSE(%[[VAL_7]]) : (!llvm.ptr<i8>) -> index
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
+// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.alloc(%[[VAL_12]]) : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.alloc(%[[VAL_9]]) : memref<?xf32>
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.init
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] lvl_sz at 0 with %[[VAL_11]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.get %[[VAL_18]] pos_mem_sz at 0
+// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]] = sparse_tensor.push_back %[[VAL_19]], %[[VAL_14]], %[[VAL_3]]
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] pos_mem_sz at 0 with %[[VAL_21]]
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] lvl_sz at 1 with %[[VAL_10]]
+// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_21]], %[[VAL_20]], %[[VAL_3]], %[[VAL_2]]
+// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]] pos_mem_sz at 0 with %[[VAL_25]]
+// CHECK: %[[VAL_27:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_28:.*]] = memref.cast %[[VAL_27]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[VAL_2]], %[[VAL_27]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_27]]{{\[}}%[[VAL_2]]] : memref<2xindex>
+// CHECK: %[[VAL_29:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_30:.*]] = memref.cast %[[VAL_29]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[VAL_2]], %[[VAL_29]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_29]]{{\[}}%[[VAL_2]]] : memref<2xindex>
+// CHECK: %[[VAL_31:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[VAL_7]], %[[VAL_28]], %[[VAL_30]], %[[VAL_15]], %[[VAL_16]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) -> i1
+// CHECK: memref.store %[[VAL_9]], %[[VAL_24]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_26]] crd_mem_sz at 0 with %[[VAL_12]]
+// CHECK: %[[VAL_33:.*]] = sparse_tensor.storage_specifier.set %[[VAL_32]] val_mem_sz with %[[VAL_9]]
+// CHECK: call @delSparseTensorReader(%[[VAL_7]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: return %[[VAL_24]], %[[VAL_15]], %[[VAL_16]], %[[VAL_33]]
func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CooPNo> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CooPNo>
return %0 : tensor<?x?xf32, #CooPNo>
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 9d337b929fa423a..138736e26c1dfdd 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
-// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
>From d54b03e367ed34ebea5a0b06c6c6f2e4a04b93b7 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 5 Oct 2023 15:14:22 -0700
Subject: [PATCH 2/3] fix merge conflict
---
mlir/test/Dialect/SparseTensor/codegen.mlir | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index ff523e70bfc914a..adefceba7379f99 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -423,7 +423,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -471,7 +471,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A12:.*]] = arith.constant 1 : index
// CHECK: %[[A13:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -507,7 +507,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
return %1 : tensor<8x8xf64, #CSR>
}
-// CHECK-LABEL: func.func private @_insert_dense_compressed_no_8_8_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_dense_compressed_nonordered_8_8_f64_0_0(
// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -533,7 +533,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
// CHECK: %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
// CHECK: %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
-// CHECK: %[[A21:.*]]:4 = func.call @_insert_dense_compressed_no_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A21:.*]]:4 = func.call @_insert_dense_compressed_nonordered_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
// CHECK: memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
// CHECK: scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
@@ -611,7 +611,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
return %1 : tensor<128xf64, #SparseVector>
}
-// CHECK-LABEL: func.func private @_insert_compressed_nu_singleton_5_6_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_compressed_nonunique_singleton_5_6_f64_0_0(
// CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -627,7 +627,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME: %[[A4:.*4]]: index,
// CHECK-SAME: %[[A5:.*5]]: f64)
-// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nu_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
+// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
>From 5ecff8cfae4fb7790d41ac3e07a6b2dbb3a47403 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 5 Oct 2023 15:17:46 -0700
Subject: [PATCH 3/3] clang-format
---
mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
index 1c155568802e579..a1bd6798f150b43 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
@@ -38,7 +38,8 @@ class MapRef final {
// Push forward maps from dimensions to levels.
//
- template <typename T> inline void pushforward(const T *in, T *out) const {
+ template <typename T>
+ inline void pushforward(const T *in, T *out) const {
switch (kind) {
case MapKind::kIdentity:
for (uint64_t i = 0; i < dimRank; ++i)
@@ -58,7 +59,8 @@ class MapRef final {
// Push backward maps from levels to dimensions.
//
- template <typename T> inline void pushbackward(const T *in, T *out) const {
+ template <typename T>
+ inline void pushbackward(const T *in, T *out) const {
switch (kind) {
case MapKind::kIdentity:
for (uint64_t i = 0; i < lvlRank; ++i)
More information about the Mlir-commits
mailing list