[Mlir-commits] [mlir] 1dfb9a6 - [mlir][sparse] LICM for SparseTensorReader::readCOO
wren romano
llvmlistbot at llvm.org
Fri Dec 2 11:13:43 PST 2022
Author: wren romano
Date: 2022-12-02T11:13:36-08:00
New Revision: 1dfb9a64f1f0fd24eefb19a2ed0b1ae3039db413
URL: https://github.com/llvm/llvm-project/commit/1dfb9a64f1f0fd24eefb19a2ed0b1ae3039db413
DIFF: https://github.com/llvm/llvm-project/commit/1dfb9a64f1f0fd24eefb19a2ed0b1ae3039db413.diff
LOG: [mlir][sparse] LICM for SparseTensorReader::readCOO
This commit performs two related changes. First, we adjust `readCOOValue` to take the `IsPattern` bool as a template parameter rather than a function argument. Second, we factor `readCOOLoop` out from `readCOO`, and template it on `IsPattern` and `IsSymmetric`. Together, these changes move all the assertions and header-dependent conditionals out of the main for-loop of `readCOO`. The only remaining conditional is in the `IsSymmetric=true` variant: checking whether the element is on the diagonal or not (a check which cannot be lifted out of the loop).
Depends On D138363
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D138365
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 3734f96889ff..fa36b7e393ee 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -42,32 +42,46 @@ struct is_complex final : public std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> final : public std::true_type {};
-/// Reads an element of a non-complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<!is_complex<V>::value, V>
-readCOOValue(char **linePtr, bool is_pattern) {
+/// Returns an element-value of non-complex type. If `IsPattern` is true,
+/// then returns an arbitrary value. If `IsPattern` is false, then
+/// reads the value from the current line buffer beginning at `linePtr`.
+template <typename V, bool IsPattern>
+inline std::enable_if_t<!is_complex<V>::value, V> readCOOValue(char **linePtr) {
// The external formats always store these numerical values with the type
// double, but we cast these values to the sparse tensor object type.
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
- return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+ if constexpr (IsPattern)
+ return 1.0;
+ return strtod(*linePtr, linePtr);
}
-/// Reads an element of a complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
- bool is_pattern) {
+/// Returns an element-value of complex type. If `IsPattern` is true,
+/// then returns an arbitrary value. If `IsPattern` is false, then reads
+/// the value from the current line buffer beginning at `linePtr`.
+template <typename V, bool IsPattern>
+inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr) {
// Read two values to make a complex. The external formats always store
// numerical values with the type double, but we cast these values to the
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
// value 1 for all entries.
- double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
- double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+ if constexpr (IsPattern)
+ return V(1.0, 1.0);
+ double re = strtod(*linePtr, linePtr);
+ double im = strtod(*linePtr, linePtr);
// Avoiding brace-notation since that forbids narrowing to `float`.
return V(re, im);
}
+/// Returns an element-value. If `is_pattern` is true, then returns an
+/// arbitrary value. If `is_pattern` is false, then reads the value from
+/// the current line buffer beginning at `linePtr`.
+template <typename V>
+inline V readCOOValue(char **linePtr, bool is_pattern) {
+ if (is_pattern)
+ return readCOOValue<V, true>(linePtr);
+ return readCOOValue<V, false>(linePtr);
+}
+
} // namespace detail
//===----------------------------------------------------------------------===//
@@ -249,6 +263,18 @@ class SparseTensorReader final {
/// Precondition: `indices` is valid for `getRank()`.
char *readCOOIndices(uint64_t *indices);
+ /// The internal implementation of `readCOO`. We template over
+ /// `IsPattern` and `IsSymmetric` in order to perform LICM without
+ /// needing to duplicate the source code.
+ //
+ // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
+ // since that's what `readCOO` creates. Once we update `readCOO` to
+ // functionalize the mapping, then this helper will just take that
+ // same function.
+ template <typename V, bool IsPattern, bool IsSymmetric>
+ void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
+ SparseTensorCOO<V> *lvlCOO);
+
/// Reads the MME header of a general sparse matrix of type real.
void readMMEHeader();
@@ -282,36 +308,50 @@ SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
assert(lvlRank == dimRank && "Rank mismatch");
detail::PermutationRef d2l(dimRank, dim2lvl);
// Prepare a COO object with the number of nonzeros as initial capacity.
- const uint64_t nnz = getNNZ();
- auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
- // Read all nonzero elements.
+ auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNNZ());
+ // Do some manual LICM, to avoid assertions in the for-loop.
+ const bool IsPattern = isPattern();
+ const bool IsSymmetric = (isSymmetric() && getRank() == 2);
+ if (IsPattern && IsSymmetric)
+ readCOOLoop<V, true, true>(lvlRank, d2l, lvlCOO);
+ else if (IsPattern)
+ readCOOLoop<V, true, false>(lvlRank, d2l, lvlCOO);
+ else if (IsSymmetric)
+ readCOOLoop<V, false, true>(lvlRank, d2l, lvlCOO);
+ else
+ readCOOLoop<V, false, false>(lvlRank, d2l, lvlCOO);
+ // Close the file and return the COO.
+ closeFile();
+ return lvlCOO;
+}
+
+template <typename V, bool IsPattern, bool IsSymmetric>
+void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
+ detail::PermutationRef dim2lvl,
+ SparseTensorCOO<V> *lvlCOO) {
+ const uint64_t dimRank = getRank();
std::vector<uint64_t> dimInd(dimRank);
std::vector<uint64_t> lvlInd(lvlRank);
- // Do some manual LICM, to avoid assertions in the for-loop.
- const bool addSymmetric = (isSymmetric() && dimRank == 2);
- const bool isPattern_ = isPattern();
- for (uint64_t k = 0; k < nnz; ++k) {
+ for (uint64_t nnz = getNNZ(), k = 0; k < nnz; ++k) {
// We inline `readCOOElement` here in order to avoid redundant
// assertions, since they're guaranteed by the call to `isValid()`
// and the construction of `dimInd` above.
char *linePtr = readCOOIndices(dimInd.data());
- const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
- d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
+ const V value = detail::readCOOValue<V, IsPattern>(&linePtr);
+ dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
lvlCOO->add(lvlInd, value);
// We currently chose to deal with symmetric matrices by fully
// constructing them. In the future, we may want to make symmetry
// implicit for storage reasons.
- if (addSymmetric && dimInd[0] != dimInd[1]) {
- // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
- std::swap(dimInd[0], dimInd[1]);
- d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
- lvlCOO->add(lvlInd, value);
- }
+ if constexpr (IsSymmetric)
+ if (dimInd[0] != dimInd[1]) {
+ // Must recompute `lvlInd`, since arbitrary maps don't preserve swap.
+ std::swap(dimInd[0], dimInd[1]);
+ dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
+ lvlCOO->add(lvlInd, value);
+ }
}
- // Close the file and return the COO.
- closeFile();
- return lvlCOO;
}
/// Writes the sparse tensor to `filename` in extended FROSTT format.
More information about the Mlir-commits
mailing list