[Mlir-commits] [mlir] [mlir][sparse] Introduce batch level format. (PR #83082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 26 15:35:23 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/83082.diff


11 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+5-4) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+23-5) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-1) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+2) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+2) 
- (modified) mlir/test/CAPI/sparse_tensor.c (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/invalid_encoding.mlir (+6) 
- (modified) mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir (+11) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+1-1) 
- (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+4-4) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 898d2f12779e39..52ca7ba8a1618f 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType;
 
 enum MlirSparseTensorLevelFormat {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
-  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
+  MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000,
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000,
 };
 
 enum MlirSparseTensorLevelPropertyNondefault {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1c81d80ea7ec4e..c8404f10686307 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,12 +154,27 @@ enum class Action : uint32_t {
 enum class LevelFormat : uint64_t {
   Undef = 0x00000000,
   Dense = 0x00010000,
-  Compressed = 0x00020000,
-  Singleton = 0x00040000,
-  LooseCompressed = 0x00080000,
-  NOutOfM = 0x00100000,
+  Batch = 0x00020000,
+  Compressed = 0x00040000,
+  Singleton = 0x00080000,
+  LooseCompressed = 0x00100000,
+  NOutOfM = 0x00200000,
 };
 
+constexpr bool encPowOfTwo(LevelFormat fmt) {
+  auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
+  // http://www.graphics.stanford.edu/~seander/bithacks.html#DetermineIfPowerOf2
+  return (enc & (enc - 1)) == 0;
+}
+
+// All LevelFormat must have only one bit set (power of two).
+static_assert(encPowOfTwo(LevelFormat::Dense) &&
+              encPowOfTwo(LevelFormat::Batch) &&
+              encPowOfTwo(LevelFormat::Compressed) &&
+              encPowOfTwo(LevelFormat::Singleton) &&
+              encPowOfTwo(LevelFormat::LooseCompressed) &&
+              encPowOfTwo(LevelFormat::NOutOfM));
+
 template <LevelFormat... targets>
 constexpr bool isAnyOfFmt(LevelFormat fmt) {
   return (... || (targets == fmt));
@@ -172,6 +187,8 @@ constexpr const char *toFormatString(LevelFormat lvlFmt) {
     return "undef";
   case LevelFormat::Dense:
     return "dense";
+  case LevelFormat::Batch:
+    return "batch";
   case LevelFormat::Compressed:
     return "compressed";
   case LevelFormat::Singleton:
@@ -228,7 +245,7 @@ struct LevelType {
     // If undefined/dense/NOutOfM, then must be unique and ordered.
     // Otherwise, the format must be one of the known ones.
     return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
-                       LevelFormat::NOutOfM>(fmt))
+                       LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
                ? (propertyBits == 0)
                : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
                              LevelFormat::LooseCompressed>(fmt));
@@ -375,6 +392,7 @@ inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
 }
 inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
 inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
+inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
 inline bool isCompressedLT(LevelType lt) {
   return lt.isa<LevelFormat::Compressed>();
 }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index f0b832571e68ec..ca98665256be5a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -141,7 +141,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     The supported level-formats are the following:
 
-    - **dense** : all entries along this level are stored
+    - **dense** : all entries along this level are stored and linearized.
+    - **batch** : all entries along this level are stored but not linearized.
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 455e90baf0a715..92e5efaa810497 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -62,6 +62,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   // Set the base bit for properties.
   if (base.compare("dense") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::Dense);
+  } else if (base.compare("batch") == 0) {
+    properties |= static_cast<uint64_t>(LevelFormat::Batch);
   } else if (base.compare("compressed") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::Compressed);
   } else if (base.compare("structured") == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index af7b85d458774d..fd0ed26fbde072 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -690,6 +690,10 @@ LogicalResult SparseTensorEncodingAttr::verify(
     }
   }
 
+  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
+    return emitError() << "Batch lvlType can only be leading levels.";
+
   // SoA property can only be applied on singleton level.
   auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
     return lt.isa<LevelPropNonDefault::SoA>();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 61a3703b73bf07..011d814cd90094 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1278,6 +1278,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   switch (lt.getLvlFmt()) {
   case LevelFormat::Dense:
     return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
+  case LevelFormat::Batch:
+    llvm_unreachable("not implemented");
   case LevelFormat::Compressed: {
     Value pos = genToPositions(b, l, t, lvl);
     Value crd = genToCoordinates(b, l, t, lvl);
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index a8b9f9048d5912..f241e0e5c2fb56 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
   mlirAffineMapDump(dimToLvl);
   // CHECK: level_type: 65536
-  // CHECK: level_type: 131072
-  // CHECK: level_type: 131072
+  // CHECK: level_type: 262144
+  // CHECK: level_type: 262144
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 9ed3cee2591475..8096c010ac935a 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -54,6 +54,12 @@ func.func private @tensor_dimlevel_size_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
+// expected-error@+1 {{Batch lvlType can only be leading levels}}
+#a = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : batch, d1 : compressed, d2: batch)}>
+func.func private @non_leading_batch(%arg0: tensor<?x?x?i32, #a>) -> ()
+
+// -----
+
 // expected-error@+1 {{use of undeclared identifier}}
 #a = #sparse_tensor.encoding<{map = (d0) -> (d0 : dense, d1 : compressed)}>
 func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 9d5118ceecc587..66e61afd897dd1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,6 +22,17 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
 
+#BCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
+}>
+
+// CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed) }>
+// CHECK-LABEL: func private @sparse_bcsr(
+// CHECK-SAME: tensor<?x?x?xf32, #[[$BCSR]]>)
+func.func private @sparse_bcsr(tensor<?x?x?xf32, #BCSR>)
+
+// -----
+
 #CSR_explicit = #sparse_tensor.encoding<{
   map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index d04fbe8ed5c220..6e8a26762d90fa 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 131072 : i64
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 262144 : i64
 // CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
 // CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
 // CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 2c0603216ef2c2..5666d090c3d5ee 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [131072]
+        # CHECK: lvl_types: [262144]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -71,9 +71,9 @@ def testEncodingAttrStructure():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [65536, 65536, 4406637494272]
+        # CHECK: lvl_types: [65536, 65536, 4406638542848]
         print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
+        # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
         print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
         # CHECK: structured_n: 2
         print(f"structured_n: {casted.structured_n}")
@@ -157,7 +157,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [65536, 131072]
+        # CHECK: lvl_types: [65536, 262144]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")

``````````

</details>


https://github.com/llvm/llvm-project/pull/83082


More information about the Mlir-commits mailing list