[Mlir-commits] [mlir] [mlir][sparse] Introduce batch level format. (PR #83082)
Peiming Liu
llvmlistbot at llvm.org
Mon Feb 26 15:34:53 PST 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/83082
None
>From f1865f7157e1275a87b11e30ec7d13a5088b53bc Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 26 Feb 2024 23:30:23 +0000
Subject: [PATCH] [mlir][sparse] Introduce batch level format.
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 9 +++---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 28 +++++++++++++++----
.../SparseTensor/IR/SparseTensorAttrDefs.td | 3 +-
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 2 ++
.../SparseTensor/IR/SparseTensorDialect.cpp | 4 +++
.../Transforms/Utils/SparseTensorLevel.cpp | 2 ++
mlir/test/CAPI/sparse_tensor.c | 4 +--
.../SparseTensor/invalid_encoding.mlir | 6 ++++
.../SparseTensor/roundtrip_encoding.mlir | 11 ++++++++
.../SparseTensor/sparse_fill_zero.mlir | 2 +-
.../python/dialects/sparse_tensor/dialect.py | 8 +++---
11 files changed, 62 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 898d2f12779e39..52ca7ba8a1618f 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType;
enum MlirSparseTensorLevelFormat {
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
- MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
+ MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000,
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000,
};
enum MlirSparseTensorLevelPropertyNondefault {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1c81d80ea7ec4e..c8404f10686307 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,12 +154,27 @@ enum class Action : uint32_t {
enum class LevelFormat : uint64_t {
Undef = 0x00000000,
Dense = 0x00010000,
- Compressed = 0x00020000,
- Singleton = 0x00040000,
- LooseCompressed = 0x00080000,
- NOutOfM = 0x00100000,
+ Batch = 0x00020000,
+ Compressed = 0x00040000,
+ Singleton = 0x00080000,
+ LooseCompressed = 0x00100000,
+ NOutOfM = 0x00200000,
};
+constexpr bool encPowOfTwo(LevelFormat fmt) {
+ auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
+ // http://www.graphics.stanford.edu/~seander/bithacks.html#DetermineIfPowerOf2
+ return (enc & (enc - 1)) == 0;
+}
+
+// Every LevelFormat must have exactly one bit set (i.e., be a power of two).
+static_assert(encPowOfTwo(LevelFormat::Dense) &&
+ encPowOfTwo(LevelFormat::Batch) &&
+ encPowOfTwo(LevelFormat::Compressed) &&
+ encPowOfTwo(LevelFormat::Singleton) &&
+ encPowOfTwo(LevelFormat::LooseCompressed) &&
+ encPowOfTwo(LevelFormat::NOutOfM));
+
template <LevelFormat... targets>
constexpr bool isAnyOfFmt(LevelFormat fmt) {
return (... || (targets == fmt));
@@ -172,6 +187,8 @@ constexpr const char *toFormatString(LevelFormat lvlFmt) {
return "undef";
case LevelFormat::Dense:
return "dense";
+ case LevelFormat::Batch:
+ return "batch";
case LevelFormat::Compressed:
return "compressed";
case LevelFormat::Singleton:
@@ -228,7 +245,7 @@ struct LevelType {
// If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
- LevelFormat::NOutOfM>(fmt))
+ LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
? (propertyBits == 0)
: (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
LevelFormat::LooseCompressed>(fmt));
@@ -375,6 +392,7 @@ inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
}
inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
+inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
inline bool isCompressedLT(LevelType lt) {
return lt.isa<LevelFormat::Compressed>();
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index f0b832571e68ec..ca98665256be5a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -141,7 +141,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
The supported level-formats are the following:
- - **dense** : all entries along this level are stored
+ - **dense** : all entries along this level are stored and linearized.
+ - **batch** : all entries along this level are stored but not linearized.
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 455e90baf0a715..92e5efaa810497 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -62,6 +62,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Dense);
+ } else if (base.compare("batch") == 0) {
+ properties |= static_cast<uint64_t>(LevelFormat::Batch);
} else if (base.compare("compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
} else if (base.compare("structured") == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index af7b85d458774d..fd0ed26fbde072 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -690,6 +690,10 @@ LogicalResult SparseTensorEncodingAttr::verify(
}
}
+ auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+ if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
+ return emitError() << "Batch lvlType can only be leading levels.";
+
// SoA property can only be applied on singleton level.
auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
return lt.isa<LevelPropNonDefault::SoA>();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 61a3703b73bf07..011d814cd90094 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1278,6 +1278,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
+ case LevelFormat::Batch:
+ llvm_unreachable("not implemented");
case LevelFormat::Compressed: {
Value pos = genToPositions(b, l, t, lvl);
Value crd = genToCoordinates(b, l, t, lvl);
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index a8b9f9048d5912..f241e0e5c2fb56 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
// CHECK: level_type: 65536
- // CHECK: level_type: 131072
- // CHECK: level_type: 131072
+ // CHECK: level_type: 262144
+ // CHECK: level_type: 262144
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 9ed3cee2591475..8096c010ac935a 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -54,6 +54,12 @@ func.func private @tensor_dimlevel_size_mismatch(%arg0: tensor<8xi32, #a>) -> ()
// -----
+// expected-error@+1 {{Batch lvlType can only be leading levels}}
+#a = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : batch, d1 : compressed, d2: batch)}>
+func.func private @non_leading_batch(%arg0: tensor<?x?x?xi32, #a>) -> ()
+
+// -----
+
// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : dense, d1 : compressed)}>
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 9d5118ceecc587..66e61afd897dd1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,6 +22,17 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
// -----
+#BCSR = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
+}>
+
+// CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed) }>
+// CHECK-LABEL: func private @sparse_bcsr(
+// CHECK-SAME: tensor<?x?x?xf32, #[[$BCSR]]>)
+func.func private @sparse_bcsr(tensor<?x?x?xf32, #BCSR>)
+
+// -----
+
#CSR_explicit = #sparse_tensor.encoding<{
map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index d04fbe8ed5c220..6e8a26762d90fa 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 2c0603216ef2c2..5666d090c3d5ee 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [131072]
+ # CHECK: lvl_types: [262144]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0) -> (d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -71,9 +71,9 @@ def testEncodingAttrStructure():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [65536, 65536, 4406637494272]
+ # CHECK: lvl_types: [65536, 65536, 4406638542848]
print(f"lvl_types: {casted.lvl_types}")
- # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
+ # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
# CHECK: structured_n: 2
print(f"structured_n: {casted.structured_n}")
@@ -157,7 +157,7 @@ def testEncodingAttr2D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [65536, 131072]
+ # CHECK: lvl_types: [65536, 262144]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
More information about the Mlir-commits
mailing list