[Mlir-commits] [mlir] [mlir][sparse] cleanup of enums header (PR #71090)

Aart Bik llvmlistbot at llvm.org
Thu Nov 2 12:04:51 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/71090

>From f067914fdb8c746b4ec94aa573e020c0a285d324 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 2 Nov 2023 11:29:20 -0700
Subject: [PATCH 1/3] [mlir][sparse] cleanup of enums header

Some DLT-related methods leaked into SparseTensor.h; this patch moves
them back to the right header (Enums.h). Also, the static asserts were
incomplete, and some DLT methods were duplicated.
---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      | 129 +++++++++++-------
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  10 --
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |   2 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |   1 +
 4 files changed, 81 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1e9aa2bdf45dbdb..a867b99c3bfa5ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
 };
 
 /// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect.  We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique).  The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
-/// defined below, rather than relying on knowledge about the particular
-/// binary encoding.
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// both the "format" per se (dense, compressed, singleton, loose_compressed,
+/// two-out-of-four) as well as the "properties" (ordered, unique). The
+/// encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
 ///
 /// The `Undef` "format" is a special value used internally for cases
 /// where we need to store an undefined or indeterminate `DimLevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-///
-// TODO: We should generalize TwoOutOfFour to N out of M and use property to
-// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider LooseCompressed and
-// TwoOutOfFour as properties instead of formats.
 enum class DimLevelType : uint8_t {
   Undef = 0,                // 0b00000_00
   Dense = 4,                // 0b00001_00
@@ -257,20 +251,12 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
   return dlt == DimLevelType::Undef;
 }
 
-/// Check if the `DimLevelType` is dense.
+/// Check if the `DimLevelType` is dense (regardless of properties).
 constexpr bool isDenseDLT(DimLevelType dlt) {
-  return dlt == DimLevelType::Dense;
-}
-
-/// Check if the `DimLevelType` is 2:4
-constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
-  return dlt == DimLevelType::TwoOutOfFour;
+  return (static_cast<uint8_t>(dlt) & ~3) ==
+         static_cast<uint8_t>(DimLevelType::Dense);
 }
 
-// We use the idiom `(dlt & ~3) == format` in order to only return true
-// for valid DLTs.  Whereas the `dlt & format` idiom is a bit faster but
-// can return false-positives on invalid DLTs.
-
 /// Check if the `DimLevelType` is compressed (regardless of properties).
 constexpr bool isCompressedDLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
@@ -295,6 +281,17 @@ constexpr bool is2OutOf4DLT(DimLevelType dlt) {
          static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
 }
 
+/// Check if the `DimLevelType` needs positions array.
+constexpr bool isDLTWithPos(DimLevelType dlt) {
+  return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
+}
+
+/// Check if the `DimLevelType` needs coordinates array.
+constexpr bool isDLTWithCrd(DimLevelType dlt) {
+  return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
+         isCompressedDLT(dlt);
+}
+
 /// Check if the `DimLevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedDLT(DimLevelType dlt) {
   return !(static_cast<uint8_t>(dlt) & 2);
@@ -336,35 +333,52 @@ static_assert(
      *getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
      *getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
      *getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
-     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::LooseCompressed) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNu) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
     "getLevelFormat conversion is broken");
 
 static_assert(
     (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         DimLevelType::TwoOutOfFour &&
-     *buildLevelType(LevelFormat::Compressed, true, true) ==
+     buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     buildLevelType(LevelFormat::Compressed, true, true) ==
          DimLevelType::Compressed &&
-     *buildLevelType(LevelFormat::Compressed, true, false) ==
+     buildLevelType(LevelFormat::Compressed, true, false) ==
          DimLevelType::CompressedNu &&
-     *buildLevelType(LevelFormat::Compressed, false, true) ==
+     buildLevelType(LevelFormat::Compressed, false, true) ==
          DimLevelType::CompressedNo &&
-     *buildLevelType(LevelFormat::Compressed, false, false) ==
+     buildLevelType(LevelFormat::Compressed, false, false) ==
          DimLevelType::CompressedNuNo &&
-     *buildLevelType(LevelFormat::Singleton, true, true) ==
+     buildLevelType(LevelFormat::Singleton, true, true) ==
          DimLevelType::Singleton &&
-     *buildLevelType(LevelFormat::Singleton, true, false) ==
+     buildLevelType(LevelFormat::Singleton, true, false) ==
          DimLevelType::SingletonNu &&
-     *buildLevelType(LevelFormat::Singleton, false, true) ==
+     buildLevelType(LevelFormat::Singleton, false, true) ==
          DimLevelType::SingletonNo &&
-     *buildLevelType(LevelFormat::Singleton, false, false) ==
-         DimLevelType::SingletonNuNo),
+     buildLevelType(LevelFormat::Singleton, false, false) ==
+         DimLevelType::SingletonNuNo &&
+     buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+         DimLevelType::LooseCompressed &&
+     buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+         DimLevelType::LooseCompressedNu &&
+     buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+         DimLevelType::LooseCompressedNo &&
+     buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+         DimLevelType::LooseCompressedNuNo &&
+     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+         DimLevelType::TwoOutOfFour),
     "buildLevelType conversion is broken");
 
 // Ensure the above predicates work as intended.
@@ -393,18 +407,28 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                !isCompressedDLT(DimLevelType::Singleton) &&
                !isCompressedDLT(DimLevelType::SingletonNu) &&
                !isCompressedDLT(DimLevelType::SingletonNo) &&
-               !isCompressedDLT(DimLevelType::SingletonNuNo)),
+               !isCompressedDLT(DimLevelType::SingletonNuNo) &&
+               !isCompressedDLT(DimLevelType::LooseCompressed) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isCompressedDLT(DimLevelType::TwoOutOfFour)),
               "isCompressedDLT definition is broken");
 
 static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
+               !isLooseCompressedDLT(DimLevelType::Compressed) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
+               !isLooseCompressedDLT(DimLevelType::Singleton) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedDLT(DimLevelType::Singleton) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
+               !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
               "isLooseCompressedDLT definition is broken");
 
 static_assert((!isSingletonDLT(DimLevelType::Dense) &&
@@ -415,11 +439,15 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
                isSingletonDLT(DimLevelType::Singleton) &&
                isSingletonDLT(DimLevelType::SingletonNu) &&
                isSingletonDLT(DimLevelType::SingletonNo) &&
-               isSingletonDLT(DimLevelType::SingletonNuNo)),
+               isSingletonDLT(DimLevelType::SingletonNuNo) &&
+               !isSingletonDLT(DimLevelType::LooseCompressed) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isSingletonDLT(DimLevelType::TwoOutOfFour)),
               "isSingletonDLT definition is broken");
 
 static_assert((isOrderedDLT(DimLevelType::Dense) &&
-               isOrderedDLT(DimLevelType::TwoOutOfFour) &&
                isOrderedDLT(DimLevelType::Compressed) &&
                isOrderedDLT(DimLevelType::CompressedNu) &&
                !isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +459,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
                isOrderedDLT(DimLevelType::LooseCompressed) &&
                isOrderedDLT(DimLevelType::LooseCompressedNu) &&
                !isOrderedDLT(DimLevelType::LooseCompressedNo) &&
-               !isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
+               !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
+               isOrderedDLT(DimLevelType::TwoOutOfFour)),
               "isOrderedDLT definition is broken");
 
 static_assert((isUniqueDLT(DimLevelType::Dense) &&
-               isUniqueDLT(DimLevelType::TwoOutOfFour) &&
                isUniqueDLT(DimLevelType::Compressed) &&
                !isUniqueDLT(DimLevelType::CompressedNu) &&
                isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +475,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
                isUniqueDLT(DimLevelType::LooseCompressed) &&
                !isUniqueDLT(DimLevelType::LooseCompressedNu) &&
                isUniqueDLT(DimLevelType::LooseCompressedNo) &&
-               !isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
+               !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
+               isUniqueDLT(DimLevelType::TwoOutOfFour)),
               "isUniqueDLT definition is broken");
 
 /// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 94e7d12b9ee915f..241d90a87165928 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Convenience method to query whether a given DLT needs both position and
-/// coordinates array or only coordinates array.
-constexpr inline bool isDLTWithPos(DimLevelType dlt) {
-  return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
-}
-constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
-  return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
-         isCompressedDLT(dlt);
-}
-
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
 bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 3c73b19319e588c..e7c6435e997ca00 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -371,7 +371,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     ::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
 
     bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
-    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
+    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
     bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6080317d07a64e0..97ef753aacf35b1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -10,6 +10,7 @@
 
 #include "Detail/DimLvlMapParser.h"
 
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

>From 8c3006b96b81818c8ba49f8d68245972f68076be Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 2 Nov 2023 11:45:02 -0700
Subject: [PATCH 2/3] 2:4 levels need a coordinates array (crd)

---
 mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index a867b99c3bfa5ba..4f18bcd815230a3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -289,7 +289,7 @@ constexpr bool isDLTWithPos(DimLevelType dlt) {
 /// Check if the `DimLevelType` needs coordinates array.
 constexpr bool isDLTWithCrd(DimLevelType dlt) {
   return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
-         isCompressedDLT(dlt);
+         isCompressedDLT(dlt) || is2OutOf4DLT(dlt);
 }
 
 /// Check if the `DimLevelType` is ordered (regardless of storage format).

>From 50eb4ed88f60f5689f81ca2ddfc776c1ac3d926b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 2 Nov 2023 12:04:05 -0700
Subject: [PATCH 3/3] typo: restore dereferences in buildLevelType asserts

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 4f18bcd815230a3..272f9b5b1756c23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -349,35 +349,35 @@ static_assert(
     (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
-     buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
-     buildLevelType(LevelFormat::Compressed, true, true) ==
+     *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     *buildLevelType(LevelFormat::Compressed, true, true) ==
          DimLevelType::Compressed &&
-     buildLevelType(LevelFormat::Compressed, true, false) ==
+     *buildLevelType(LevelFormat::Compressed, true, false) ==
          DimLevelType::CompressedNu &&
-     buildLevelType(LevelFormat::Compressed, false, true) ==
+     *buildLevelType(LevelFormat::Compressed, false, true) ==
          DimLevelType::CompressedNo &&
-     buildLevelType(LevelFormat::Compressed, false, false) ==
+     *buildLevelType(LevelFormat::Compressed, false, false) ==
          DimLevelType::CompressedNuNo &&
-     buildLevelType(LevelFormat::Singleton, true, true) ==
+     *buildLevelType(LevelFormat::Singleton, true, true) ==
          DimLevelType::Singleton &&
-     buildLevelType(LevelFormat::Singleton, true, false) ==
+     *buildLevelType(LevelFormat::Singleton, true, false) ==
          DimLevelType::SingletonNu &&
-     buildLevelType(LevelFormat::Singleton, false, true) ==
+     *buildLevelType(LevelFormat::Singleton, false, true) ==
          DimLevelType::SingletonNo &&
-     buildLevelType(LevelFormat::Singleton, false, false) ==
+     *buildLevelType(LevelFormat::Singleton, false, false) ==
          DimLevelType::SingletonNuNo &&
-     buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+     *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
          DimLevelType::LooseCompressed &&
-     buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+     *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
          DimLevelType::LooseCompressedNu &&
-     buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+     *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
          DimLevelType::LooseCompressedNo &&
-     buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+     *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
          DimLevelType::LooseCompressedNuNo &&
      buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
      buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
      buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
          DimLevelType::TwoOutOfFour),
     "buildLevelType conversion is broken");
 



More information about the Mlir-commits mailing list