[Mlir-commits] [mlir] b958954 - [mlir][sparse] introduce a new compressed(hi) dimension level type

Peiming Liu llvmlistbot at llvm.org
Tue Apr 18 16:26:19 PDT 2023


Author: Peiming Liu
Date: 2023-04-18T23:26:11Z
New Revision: b9589545c45b8b42a7434075151a4b1d4b798a70

URL: https://github.com/llvm/llvm-project/commit/b9589545c45b8b42a7434075151a4b1d4b798a70
DIFF: https://github.com/llvm/llvm-project/commit/b9589545c45b8b42a7434075151a4b1d4b798a70.diff

LOG: [mlir][sparse] introduce a new compressed(hi) dimension level type

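`compressed(hi)` (spelled `"compressed-hi"` in the encoding syntax, with `-nu`/`-no` property variants) is similar to `compressed`, but instead of reusing the previous position high as the current position low, it uses an explicit pair of positions (low and high) for each sparse index, so consecutive segments do not have to be adjacent.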

This patch only introduces the definition (syntax); it does not provide a codegen implementation.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D148664

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Bindings/Python/DialectSparseTensor.cpp
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 7d560dd80a90a..8a6763b6ca89e 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,15 +26,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 enum MlirSparseTensorDimLevelType {
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4,             // 0b001_00
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8,        // 0b010_00
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9,     // 0b010_01
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10,    // 0b010_10
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16,        // 0b100_00
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17,     // 0b100_01
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18,     // 0b100_10
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19,  // 0b100_11
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4,                     // 0b0001_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8,                // 0b0010_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9,             // 0b0010_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10,            // 0b0010_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11,         // 0b0010_11
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16,                // 0b0100_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17,             // 0b0100_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18,             // 0b0100_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19,          // 0b0100_11
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32,       // 0b1000_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33,    // 0b1000_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34,    // 0b1000_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b1000_11
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 89065b15bf230..acb543adb81b8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -172,24 +172,29 @@ enum class Action : uint32_t {
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
 enum class DimLevelType : uint8_t {
-  Undef = 0,           // 0b000_00
-  Dense = 4,           // 0b001_00
-  Compressed = 8,      // 0b010_00
-  CompressedNu = 9,    // 0b010_01
-  CompressedNo = 10,   // 0b010_10
-  CompressedNuNo = 11, // 0b010_11
-  Singleton = 16,      // 0b100_00
-  SingletonNu = 17,    // 0b100_01
-  SingletonNo = 18,    // 0b100_10
-  SingletonNuNo = 19,  // 0b100_11
+  Undef = 0,                 // 0b0000_00
+  Dense = 4,                 // 0b0001_00
+  Compressed = 8,            // 0b0010_00
+  CompressedNu = 9,          // 0b0010_01
+  CompressedNo = 10,         // 0b0010_10
+  CompressedNuNo = 11,       // 0b0010_11
+  Singleton = 16,            // 0b0100_00
+  SingletonNu = 17,          // 0b0100_01
+  SingletonNo = 18,          // 0b0100_10
+  SingletonNuNo = 19,        // 0b0100_11
+  CompressedWithHi = 32,     // 0b1000_00
+  CompressedWithHiNu = 33,   // 0b1000_01
+  CompressedWithHiNo = 34,   // 0b1000_10
+  CompressedWithHiNuNo = 35, // 0b1000_11
 };
 
 /// This enum defines all the storage formats supported by the sparse compiler,
 /// without the level properties.
 enum class LevelFormat : uint8_t {
-  Dense = 4,      // 0b001_00
-  Compressed = 8, // 0b010_00
-  Singleton = 16, // 0b100_00
+  Dense = 4,             // 0b0001_00
+  Compressed = 8,        // 0b0010_00
+  Singleton = 16,        // 0b0100_00
+  CompressedWithHi = 32, // 0b1000_00
 };
 
 /// Returns string representation of the given dimension level type.
@@ -216,6 +221,14 @@ inline std::string toMLIRString(DimLevelType dlt) {
     return "singleton-no";
   case DimLevelType::SingletonNuNo:
     return "singleton-nu-no";
+  case DimLevelType::CompressedWithHi:
+    return "compressed-hi";
+  case DimLevelType::CompressedWithHiNu:
+    return "compressed-hi-nu";
+  case DimLevelType::CompressedWithHiNo:
+    return "compressed-hi-no";
+  case DimLevelType::CompressedWithHiNuNo:
+    return "compressed-hi-nu-no";
   }
   return "";
 }
@@ -226,8 +239,9 @@ constexpr bool isValidDLT(DimLevelType dlt) {
   const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3;
   // If undefined or dense, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
-  return (formatBits <= 1) ? (propertyBits == 0)
-                           : (formatBits == 2 || formatBits == 4);
+  return (formatBits <= 1)
+             ? (propertyBits == 0)
+             : (formatBits == 2 || formatBits == 4 || formatBits == 8);
 }
 
 /// Check if the `DimLevelType` is the special undefined value.
@@ -250,6 +264,12 @@ constexpr bool isCompressedDLT(DimLevelType dlt) {
          static_cast<uint8_t>(DimLevelType::Compressed);
 }
 
+/// Check if the `DimLevelType` is compressed-hi (regardless of properties).
+constexpr bool isCompressedWithHiDLT(DimLevelType dlt) {
+  return (static_cast<uint8_t>(dlt) & ~3) ==
+         static_cast<uint8_t>(DimLevelType::CompressedWithHi);
+}
+
 /// Check if the `DimLevelType` is singleton (regardless of properties).
 constexpr bool isSingletonDLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
@@ -333,7 +353,11 @@ static_assert((isValidDLT(DimLevelType::Undef) &&
                isValidDLT(DimLevelType::Singleton) &&
                isValidDLT(DimLevelType::SingletonNu) &&
                isValidDLT(DimLevelType::SingletonNo) &&
-               isValidDLT(DimLevelType::SingletonNuNo)),
+               isValidDLT(DimLevelType::SingletonNuNo) &&
+               isValidDLT(DimLevelType::CompressedWithHi) &&
+               isValidDLT(DimLevelType::CompressedWithHiNu) &&
+               isValidDLT(DimLevelType::CompressedWithHiNo) &&
+               isValidDLT(DimLevelType::CompressedWithHiNuNo)),
               "isValidDLT definition is broken");
 
 static_assert((!isCompressedDLT(DimLevelType::Dense) &&
@@ -347,6 +371,17 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                !isCompressedDLT(DimLevelType::SingletonNuNo)),
               "isCompressedDLT definition is broken");
 
+static_assert((!isCompressedWithHiDLT(DimLevelType::Dense) &&
+               isCompressedWithHiDLT(DimLevelType::CompressedWithHi) &&
+               isCompressedWithHiDLT(DimLevelType::CompressedWithHiNu) &&
+               isCompressedWithHiDLT(DimLevelType::CompressedWithHiNo) &&
+               isCompressedWithHiDLT(DimLevelType::CompressedWithHiNuNo) &&
+               !isCompressedWithHiDLT(DimLevelType::Singleton) &&
+               !isCompressedWithHiDLT(DimLevelType::SingletonNu) &&
+               !isCompressedWithHiDLT(DimLevelType::SingletonNo) &&
+               !isCompressedWithHiDLT(DimLevelType::SingletonNuNo)),
+              "isCompressedWithHiDLT definition is broken");
+
 static_assert((!isSingletonDLT(DimLevelType::Dense) &&
                !isSingletonDLT(DimLevelType::Compressed) &&
                !isSingletonDLT(DimLevelType::CompressedNu) &&
@@ -366,7 +401,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
                isOrderedDLT(DimLevelType::Singleton) &&
                isOrderedDLT(DimLevelType::SingletonNu) &&
                !isOrderedDLT(DimLevelType::SingletonNo) &&
-               !isOrderedDLT(DimLevelType::SingletonNuNo)),
+               !isOrderedDLT(DimLevelType::SingletonNuNo) &&
+               isOrderedDLT(DimLevelType::CompressedWithHi) &&
+               isOrderedDLT(DimLevelType::CompressedWithHiNu) &&
+               !isOrderedDLT(DimLevelType::CompressedWithHiNo) &&
+               !isOrderedDLT(DimLevelType::CompressedWithHiNuNo)),
               "isOrderedDLT definition is broken");
 
 static_assert((isUniqueDLT(DimLevelType::Dense) &&
@@ -377,7 +416,11 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
                isUniqueDLT(DimLevelType::Singleton) &&
                !isUniqueDLT(DimLevelType::SingletonNu) &&
                isUniqueDLT(DimLevelType::SingletonNo) &&
-               !isUniqueDLT(DimLevelType::SingletonNuNo)),
+               !isUniqueDLT(DimLevelType::SingletonNuNo) &&
+               isUniqueDLT(DimLevelType::CompressedWithHi) &&
+               !isUniqueDLT(DimLevelType::CompressedWithHiNu) &&
+               isUniqueDLT(DimLevelType::CompressedWithHiNo) &&
+               !isUniqueDLT(DimLevelType::CompressedWithHiNuNo)),
               "isUniqueDLT definition is broken");
 
 } // namespace sparse_tensor

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index a5b96a86596eb..5ba7ff8744795 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -337,6 +337,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
+    bool isCompressedWithHiLvl(::mlir::sparse_tensor::Level l) const { return isCompressedWithHiDLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }

diff  --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index e84937df89ba1..0e07f256344f9 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -26,7 +26,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON)
       .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU)
       .value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO)
-      .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO);
+      .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO)
+      .value("compressed-hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI)
+      .value("compressed-hi-nu",
+             MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU)
+      .value("compressed-hi-no",
+             MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO)
+      .value("compressed-hi-nu-no",
+             MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO);
 
   mlir_attribute_subclass(m, "EncodingAttr",
                           mlirAttributeIsASparseTensorEncodingAttr)

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 11e6ac81b7e14..1a93b14f780dd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -198,12 +198,19 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
-const static DimLevelType validDLTs[] = {
-    DimLevelType::Dense,          DimLevelType::Compressed,
-    DimLevelType::CompressedNu,   DimLevelType::CompressedNo,
-    DimLevelType::CompressedNuNo, DimLevelType::Singleton,
-    DimLevelType::SingletonNu,    DimLevelType::SingletonNo,
-    DimLevelType::SingletonNuNo};
+const static DimLevelType validDLTs[] = {DimLevelType::Dense,
+                                         DimLevelType::Compressed,
+                                         DimLevelType::CompressedNu,
+                                         DimLevelType::CompressedNo,
+                                         DimLevelType::CompressedNuNo,
+                                         DimLevelType::Singleton,
+                                         DimLevelType::SingletonNu,
+                                         DimLevelType::SingletonNo,
+                                         DimLevelType::SingletonNuNo,
+                                         DimLevelType::CompressedWithHi,
+                                         DimLevelType::CompressedWithHiNu,
+                                         DimLevelType::CompressedWithHiNo,
+                                         DimLevelType::CompressedWithHiNuNo};
 
 static std::optional<DimLevelType> parseDLT(StringRef str) {
   for (DimLevelType dlt : validDLTs)

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index d2d6fab60f124..087ce42f18779 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -55,6 +55,16 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
 
 // -----
 
+#BCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
+}>
+
+// CHECK-LABEL: func private @sparse_bcoo(
+// CHECK-SAME: tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ] }>>)
+func.func private @sparse_bcoo(tensor<?x?x?xf32, #BCOO>)
+
+// -----
+
 #SortedCOO = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed-nu", "singleton" ]
 }>


        


More information about the Mlir-commits mailing list