[Mlir-commits] [mlir] ee3ee13 - [mlir][sparse] cleanup of enums header (#71090)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 13:00:32 PDT 2023


Author: Aart Bik
Date: 2023-11-02T13:00:27-07:00
New Revision: ee3ee1315a757345dfc50eed34b5074c6e87df2d

URL: https://github.com/llvm/llvm-project/commit/ee3ee1315a757345dfc50eed34b5074c6e87df2d
DIFF: https://github.com/llvm/llvm-project/commit/ee3ee1315a757345dfc50eed34b5074c6e87df2d.diff

LOG: [mlir][sparse] cleanup of enums header (#71090)

Some DLT-related methods leaked into SparseTensor.h; this moves them
back to the right header (Enums.h). Also, the asserts were incomplete and
some DLT methods were duplicated.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1e9aa2bdf45dbdb..9c277a0b23633d8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
 };
 
 /// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect.  We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique).  The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
-/// defined below, rather than relying on knowledge about the particular
-/// binary encoding.
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// both the "format" per se (dense, compressed, singleton, loose_compressed,
+/// two-out-of-four) as well as the "properties" (ordered, unique). The
+/// encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
 ///
 /// The `Undef` "format" is a special value used internally for cases
 /// where we need to store an undefined or indeterminate `DimLevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-///
-// TODO: We should generalize TwoOutOfFour to N out of M and use property to
-// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider LooseCompressed and
-// TwoOutOfFour as properties instead of formats.
 enum class DimLevelType : uint8_t {
   Undef = 0,                // 0b00000_00
   Dense = 4,                // 0b00001_00
@@ -257,44 +251,47 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
   return dlt == DimLevelType::Undef;
 }
 
-/// Check if the `DimLevelType` is dense.
+/// Check if the `DimLevelType` is dense (regardless of properties).
 constexpr bool isDenseDLT(DimLevelType dlt) {
-  return dlt == DimLevelType::Dense;
-}
-
-/// Check if the `DimLevelType` is 2:4
-constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
-  return dlt == DimLevelType::TwoOutOfFour;
+  return (static_cast<uint8_t>(dlt) & ~3) ==
+         static_cast<uint8_t>(DimLevelType::Dense);
 }
 
-// We use the idiom `(dlt & ~3) == format` in order to only return true
-// for valid DLTs.  Whereas the `dlt & format` idiom is a bit faster but
-// can return false-positives on invalid DLTs.
-
 /// Check if the `DimLevelType` is compressed (regardless of properties).
 constexpr bool isCompressedDLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::Compressed);
 }
 
-/// Check if the `DimLevelType` is loose compressed (regardless of properties).
-constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
-  return (static_cast<uint8_t>(dlt) & ~3) ==
-         static_cast<uint8_t>(DimLevelType::LooseCompressed);
-}
-
 /// Check if the `DimLevelType` is singleton (regardless of properties).
 constexpr bool isSingletonDLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::Singleton);
 }
 
+/// Check if the `DimLevelType` is loose compressed (regardless of properties).
+constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
+  return (static_cast<uint8_t>(dlt) & ~3) ==
+         static_cast<uint8_t>(DimLevelType::LooseCompressed);
+}
+
 /// Check if the `DimLevelType` is 2OutOf4 (regardless of properties).
 constexpr bool is2OutOf4DLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
 }
 
+/// Check if the `DimLevelType` needs a positions array.
+constexpr bool isDLTWithPos(DimLevelType dlt) {
+  return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt);
+}
+
+/// Check if the `DimLevelType` needs a coordinates array.
+constexpr bool isDLTWithCrd(DimLevelType dlt) {
+  return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+         isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
+}
+
 /// Check if the `DimLevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedDLT(DimLevelType dlt) {
   return !(static_cast<uint8_t>(dlt) & 2);
@@ -325,7 +322,10 @@ buildLevelType(LevelFormat lf, bool ordered, bool unique) {
   return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
 }
 
-/// Ensure the above conversion works as intended.
+//
+// Ensure the above methods work as intended.
+//
+
 static_assert(
     (getLevelFormat(DimLevelType::Undef) == std::nullopt &&
      *getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
@@ -336,7 +336,16 @@ static_assert(
      *getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
      *getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
      *getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
-     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::LooseCompressed) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNu) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
     "getLevelFormat conversion is broken");
 
 static_assert(
@@ -344,11 +353,6 @@ static_assert(
      buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
      *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         DimLevelType::TwoOutOfFour &&
      *buildLevelType(LevelFormat::Compressed, true, true) ==
          DimLevelType::Compressed &&
      *buildLevelType(LevelFormat::Compressed, true, false) ==
@@ -364,10 +368,22 @@ static_assert(
      *buildLevelType(LevelFormat::Singleton, false, true) ==
          DimLevelType::SingletonNo &&
      *buildLevelType(LevelFormat::Singleton, false, false) ==
-         DimLevelType::SingletonNuNo),
+         DimLevelType::SingletonNuNo &&
+     *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+         DimLevelType::LooseCompressed &&
+     *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+         DimLevelType::LooseCompressedNu &&
+     *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+         DimLevelType::LooseCompressedNo &&
+     *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+         DimLevelType::LooseCompressedNuNo &&
+     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
+     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+         DimLevelType::TwoOutOfFour),
     "buildLevelType conversion is broken");
 
-// Ensure the above predicates work as intended.
 static_assert((isValidDLT(DimLevelType::Undef) &&
                isValidDLT(DimLevelType::Dense) &&
                isValidDLT(DimLevelType::Compressed) &&
@@ -385,6 +401,22 @@ static_assert((isValidDLT(DimLevelType::Undef) &&
                isValidDLT(DimLevelType::TwoOutOfFour)),
               "isValidDLT definition is broken");
 
+static_assert((isDenseDLT(DimLevelType::Dense) &&
+               !isDenseDLT(DimLevelType::Compressed) &&
+               !isDenseDLT(DimLevelType::CompressedNu) &&
+               !isDenseDLT(DimLevelType::CompressedNo) &&
+               !isDenseDLT(DimLevelType::CompressedNuNo) &&
+               !isDenseDLT(DimLevelType::Singleton) &&
+               !isDenseDLT(DimLevelType::SingletonNu) &&
+               !isDenseDLT(DimLevelType::SingletonNo) &&
+               !isDenseDLT(DimLevelType::SingletonNuNo) &&
+               !isDenseDLT(DimLevelType::LooseCompressed) &&
+               !isDenseDLT(DimLevelType::LooseCompressedNu) &&
+               !isDenseDLT(DimLevelType::LooseCompressedNo) &&
+               !isDenseDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isDenseDLT(DimLevelType::TwoOutOfFour)),
+              "isDenseDLT definition is broken");
+
 static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                isCompressedDLT(DimLevelType::Compressed) &&
                isCompressedDLT(DimLevelType::CompressedNu) &&
@@ -393,20 +425,14 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                !isCompressedDLT(DimLevelType::Singleton) &&
                !isCompressedDLT(DimLevelType::SingletonNu) &&
                !isCompressedDLT(DimLevelType::SingletonNo) &&
-               !isCompressedDLT(DimLevelType::SingletonNuNo)),
+               !isCompressedDLT(DimLevelType::SingletonNuNo) &&
+               !isCompressedDLT(DimLevelType::LooseCompressed) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isCompressedDLT(DimLevelType::TwoOutOfFour)),
               "isCompressedDLT definition is broken");
 
-static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
-               isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
-               isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
-               isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
-               isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedDLT(DimLevelType::Singleton) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
-              "isLooseCompressedDLT definition is broken");
-
 static_assert((!isSingletonDLT(DimLevelType::Dense) &&
                !isSingletonDLT(DimLevelType::Compressed) &&
                !isSingletonDLT(DimLevelType::CompressedNu) &&
@@ -415,11 +441,47 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
                isSingletonDLT(DimLevelType::Singleton) &&
                isSingletonDLT(DimLevelType::SingletonNu) &&
                isSingletonDLT(DimLevelType::SingletonNo) &&
-               isSingletonDLT(DimLevelType::SingletonNuNo)),
+               isSingletonDLT(DimLevelType::SingletonNuNo) &&
+               !isSingletonDLT(DimLevelType::LooseCompressed) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isSingletonDLT(DimLevelType::TwoOutOfFour)),
               "isSingletonDLT definition is broken");
 
+static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
+               !isLooseCompressedDLT(DimLevelType::Compressed) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
+               !isLooseCompressedDLT(DimLevelType::Singleton) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
+               isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
+               isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
+               isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
+               isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
+              "isLooseCompressedDLT definition is broken");
+
+static_assert((!is2OutOf4DLT(DimLevelType::Dense) &&
+               !is2OutOf4DLT(DimLevelType::Compressed) &&
+               !is2OutOf4DLT(DimLevelType::CompressedNu) &&
+               !is2OutOf4DLT(DimLevelType::CompressedNo) &&
+               !is2OutOf4DLT(DimLevelType::CompressedNuNo) &&
+               !is2OutOf4DLT(DimLevelType::Singleton) &&
+               !is2OutOf4DLT(DimLevelType::SingletonNu) &&
+               !is2OutOf4DLT(DimLevelType::SingletonNo) &&
+               !is2OutOf4DLT(DimLevelType::SingletonNuNo) &&
+               !is2OutOf4DLT(DimLevelType::LooseCompressed) &&
+               !is2OutOf4DLT(DimLevelType::LooseCompressedNu) &&
+               !is2OutOf4DLT(DimLevelType::LooseCompressedNo) &&
+               !is2OutOf4DLT(DimLevelType::LooseCompressedNuNo) &&
+               is2OutOf4DLT(DimLevelType::TwoOutOfFour)),
+              "is2OutOf4DLT definition is broken");
+
 static_assert((isOrderedDLT(DimLevelType::Dense) &&
-               isOrderedDLT(DimLevelType::TwoOutOfFour) &&
                isOrderedDLT(DimLevelType::Compressed) &&
                isOrderedDLT(DimLevelType::CompressedNu) &&
                !isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +493,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
                isOrderedDLT(DimLevelType::LooseCompressed) &&
                isOrderedDLT(DimLevelType::LooseCompressedNu) &&
                !isOrderedDLT(DimLevelType::LooseCompressedNo) &&
-               !isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
+               !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
+               isOrderedDLT(DimLevelType::TwoOutOfFour)),
               "isOrderedDLT definition is broken");
 
 static_assert((isUniqueDLT(DimLevelType::Dense) &&
-               isUniqueDLT(DimLevelType::TwoOutOfFour) &&
                isUniqueDLT(DimLevelType::Compressed) &&
                !isUniqueDLT(DimLevelType::CompressedNu) &&
                isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +509,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
                isUniqueDLT(DimLevelType::LooseCompressed) &&
                !isUniqueDLT(DimLevelType::LooseCompressedNu) &&
                isUniqueDLT(DimLevelType::LooseCompressedNo) &&
-               !isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
+               !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
+               isUniqueDLT(DimLevelType::TwoOutOfFour)),
               "isUniqueDLT definition is broken");
 
 /// Bit manipulations for affine encoding.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 94e7d12b9ee915f..241d90a87165928 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Convenience method to query whether a given DLT needs both position and
-/// coordinates array or only coordinates array.
-constexpr inline bool isDLTWithPos(DimLevelType dlt) {
-  return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
-}
-constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
-  return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
-         isCompressedDLT(dlt);
-}
-
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
 bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 3c73b19319e588c..a254f52aa86e7db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -371,10 +371,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     ::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
 
     bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
-    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
-    bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
+    bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
+    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 99214fadf4ba3db..c727b8d05c26d7d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -10,6 +10,7 @@
 
 #include "Detail/DimLvlMapParser.h"
 
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"


        


More information about the Mlir-commits mailing list