[Mlir-commits] [mlir] [mlir][sparse] Add more tests and verification for n:m (PR #81186)
Yinying Li
llvmlistbot at llvm.org
Thu Feb 8 12:22:26 PST 2024
https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/81186
>From 61bc16760fb362b42051000346f05b7a1475ee6f Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 7 Feb 2024 20:10:32 +0000
Subject: [PATCH 1/4] pybinding for n:m
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 11 ++++
.../Bindings/Python/DialectSparseTensor.cpp | 39 ++++++++++++-
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 18 ++++++
mlir/test/CAPI/sparse_tensor.c | 4 +-
.../python/dialects/sparse_tensor/dialect.py | 57 +++++++++++++++++++
5 files changed, 126 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 2c71b0008ad16a..d4cac0e326cffd 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -84,6 +84,17 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
MLIR_CAPI_EXPORTED int
mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr);
+MLIR_CAPI_EXPORTED unsigned
+mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType);
+
+MLIR_CAPI_EXPORTED unsigned
+mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
+
+MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
+mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
+ bool ordered, bool unique, unsigned n,
+ unsigned m);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 607534c6156439..ac2b6720c699b9 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -60,6 +60,16 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
py::arg("context") = py::none(),
"Gets a sparse_tensor.encoding from parameters.")
+ .def_classmethod(
+ "build_level_type",
+ [](py::object cls, MlirBaseSparseTensorLevelType lvlType,
+ bool ordered, bool unique, unsigned n, unsigned m) {
+ return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, ordered,
+ unique, n, m);
+ },
+ py::arg("cls"), py::arg("lvl_type"), py::arg("ordered") = true,
+ py::arg("unique") = true, py::arg("n") = 0, py::arg("m") = 0,
+ "Builds a sparse_tensor.encoding.level_type from parameters.")
.def_property_readonly(
"lvl_types",
[](MlirAttribute self) {
@@ -89,7 +99,34 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
.def_property_readonly("pos_width",
mlirSparseTensorEncodingAttrGetPosWidth)
.def_property_readonly("crd_width",
- mlirSparseTensorEncodingAttrGetCrdWidth);
+ mlirSparseTensorEncodingAttrGetCrdWidth)
+ .def_property_readonly(
+ "structured_n",
+ [](MlirAttribute self) -> unsigned {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ return mlirSparseTensorEncodingAttrGetStructuredN(
+ mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+ })
+ .def_property_readonly(
+ "structured_m",
+ [](MlirAttribute self) -> unsigned {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ return mlirSparseTensorEncodingAttrGetStructuredM(
+ mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+ })
+ .def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ std::vector<MlirBaseSparseTensorLevelType> ret;
+ ret.reserve(lvlRank);
+ for (int l = 0; l < lvlRank; l++) {
+ // Convert level type to 32 bits to ignore n and m for n_out_of_m
+ // format.
+ ret.push_back(
+ static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
+ mlirSparseTensorEncodingAttrGetLvlType(self, l))));
+ }
+ return ret;
+ });
}
PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index a34b9a29b0e90a..2a7d47d9ece90f 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -94,3 +94,21 @@ int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
}
+
+MlirSparseTensorLevelType
+mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
+ bool ordered, bool unique, unsigned n,
+ unsigned m) {
+ return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
+ *getLevelFormat(static_cast<LevelType>(lvlType)), ordered, unique, n, m));
+}
+
+unsigned
+mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) {
+ return getN(static_cast<LevelType>(lvlType));
+}
+
+unsigned
+mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) {
+ return getM(static_cast<LevelType>(lvlType));
+}
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index a8b9f9048d5912..5c21f0b1d8927d 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -26,7 +26,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
// clang-format off
const char *originalAsm =
"#sparse_tensor.encoding<{ "
- "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
+ "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : structured[2, 4]), "
"posWidth = 32, crdWidth = 64 }>";
// clang-format on
MlirAttribute originalAttr =
@@ -40,7 +40,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
mlirAffineMapDump(dimToLvl);
// CHECK: level_type: 65536
// CHECK: level_type: 131072
- // CHECK: level_type: 131072
+ // CHECK: level_type: 4406637494272
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 412c5797067b7a..ba7205d5162aa1 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -52,6 +52,63 @@ def testEncodingAttr1D():
print(f"created_pos_width: {created.pos_width}")
+# CHECK-LABEL: TEST: testEncodingAttr1DStructure
+@run
+def testEncodingAttr1DStructure():
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0) -> (d0 : structured[2, 4]),"
+ " posWidth = 16,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [4406637494272]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: lvl_types: [<LevelType.n_out_of_m: 1048576>]
+ print(f"lvl_types_enum: {casted.lvl_types_enum}")
+ # CHECK: structured_n: 2
+ print(f"structured_n: {casted.structured_n}")
+ # CHECK: structured_m: 4
+ print(f"structured_m: {casted.structured_m}")
+ # CHECK: dim_to_lvl: (d0) -> (d0)
+ print(f"dim_to_lvl: {casted.dim_to_lvl}")
+ # CHECK: lvl_to_dim: (d0) -> (d0)
+ print(f"lvl_to_dim: {casted.lvl_to_dim}")
+ # CHECK: pos_width: 16
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+ print(created)
+ # CHECK: created_equal: False
+ print(f"created_equal: {created == casted}")
+
+ built_type = st.EncodingAttr.build_level_type(
+ st.LevelType.n_out_of_m, True, True, 2, 4
+ )
+ built = st.EncodingAttr.get([built_type], None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+ print(built)
+ # CHECK: built_equal: True
+ print(f"built_equal: {built == created}")
+
+ # Verify that the factory creates an instance of the proper type.
+ # CHECK: is_proper_instance: True
+ print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+ # CHECK: created_pos_width: 0
+ print(f"created_pos_width: {created.pos_width}")
+
+
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
>From 00b2bf5febd20f5b891335cedb95073527f7ad9d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 7 Feb 2024 22:42:43 +0000
Subject: [PATCH 2/4] verification
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 5 +-
.../Bindings/Python/DialectSparseTensor.cpp | 8 +--
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 6 +-
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 28 +++++++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 10 ++++
mlir/test/CAPI/sparse_tensor.c | 4 +-
.../SparseTensor/invalid_encoding.mlir | 60 +++++++++++++++++++
.../python/dialects/sparse_tensor/dialect.py | 57 +++++++++++++-----
8 files changed, 146 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index d4cac0e326cffd..d549f5dddc1318 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -91,9 +91,8 @@ MLIR_CAPI_EXPORTED unsigned
mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
- bool ordered, bool unique, unsigned n,
- unsigned m);
+mlirSparseTensorEncodingAttrBuildLvlType(
+ enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m);
#ifdef __cplusplus
}
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index ac2b6720c699b9..fe6e5e4e5b2b26 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -62,13 +62,13 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
"Gets a sparse_tensor.encoding from parameters.")
.def_classmethod(
"build_level_type",
- [](py::object cls, MlirBaseSparseTensorLevelType lvlType,
- bool ordered, bool unique, unsigned n, unsigned m) {
+ [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
+ unsigned m) {
return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, ordered,
unique, n, m);
},
- py::arg("cls"), py::arg("lvl_type"), py::arg("ordered") = true,
- py::arg("unique") = true, py::arg("n") = 0, py::arg("m") = 0,
+ py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
+ py::arg("m") = 0,
"Builds a sparse_tensor.encoding.level_type from parameters.")
.def_property_readonly(
"lvl_types",
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 2a7d47d9ece90f..4e1bd45863fdac 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -97,10 +97,10 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
MlirSparseTensorLevelType
mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
- bool ordered, bool unique, unsigned n,
- unsigned m) {
+ unsigned n, unsigned m) {
+ LevelType lt = static_cast<LevelType>(lvlType);
return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
- *getLevelFormat(static_cast<LevelType>(lvlType)), ordered, unique, n, m));
+ *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m));
}
unsigned
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 752d6e6481dfee..49189a82fb0704 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -45,6 +45,22 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
FAILURE_IF_FAILED(res)
}
+ if (base.compare("structured") == 0) {
+ ParseResult res = parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::OptionalSquare,
+ [&]() -> ParseResult { return parseStructure(parser, &structure); },
+ " in structure n out of m");
+ FAILURE_IF_FAILED(res)
+ if (structure.size() != 2) {
+ parser.emitError(loc, "expected exactly 2 structure sizes");
+ return failure();
+ }
+ if (structure[0] > structure[1]) {
+ parser.emitError(loc, "expected n <= m in n_out_of_m");
+ return failure();
+ }
+ }
+
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalParen,
[&]() -> ParseResult { return parseProperty(parser, &properties); },
@@ -57,10 +73,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
} else if (base.compare("compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
} else if (base.compare("structured") == 0) {
- if (structure.size() != 2) {
- parser.emitError(loc, "expected exactly 2 structure sizes");
- return failure();
- }
properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
properties |= nToBits(structure[0]) | mToBits(structure[1]);
} else if (base.compare("loose_compressed") == 0) {
@@ -102,13 +114,17 @@ LvlTypeParser::parseStructure(AsmParser &parser,
OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
if (intValParseResult.has_value()) {
if (failed(*intValParseResult)) {
- parser.emitError(loc, "failed to parse block size");
+ parser.emitError(loc, "failed to parse structure size");
+ return failure();
+ }
+ if (intVal < 0) {
+ parser.emitError(loc, "expected structure size to be >= 0");
return failure();
}
structure->push_back(intVal);
return success();
}
- parser.emitError(loc, "expected valid integer for block size");
+ parser.emitError(loc, "expected valid integer for structure size");
return failure();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 67b1d7974fa259..c7a063f020794e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -657,6 +657,16 @@ LogicalResult SparseTensorEncodingAttr::verify(
return emitError() << "expected all singleton lvlTypes "
"following a singleton level";
}
+ // TODO: audit formats that actually are supported by backend.
+ if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
+ it != std::end(lvlTypes)) {
+ if (it != lvlTypes.end() - 1)
+ return emitError() << "expected n_out_of_m to be the last level type";
+ if (!std::all_of(lvlTypes.begin(), it,
+ [](LevelType i) { return isDenseLT(i); }))
+ return emitError() << "expected all dense lvlTypes "
+ "before a n_out_of_m level";
+ }
// Before we can check that the level-rank is consistent/coherent
// across all fields, we need to define it. The source-of-truth for
// the `getLvlRank` method is the length of the level-types array,
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 5c21f0b1d8927d..580a619d1f7d99 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -26,7 +26,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
// clang-format off
const char *originalAsm =
"#sparse_tensor.encoding<{ "
- "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : structured[2, 4]), "
+ "map = [s0](d0, d1) -> (s0 : dense, d0 : dense, d1 : structured[2, 4]), "
"posWidth = 32, crdWidth = 64 }>";
// clang-format on
MlirAttribute originalAttr =
@@ -39,7 +39,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
// CHECK: level_type: 65536
- // CHECK: level_type: 131072
+ // CHECK: level_type: 65536
// CHECK: level_type: 4406637494272
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 2d189cc94c15e2..49d6e7a6d866a9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -315,3 +315,63 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
return
}
+
+// -----
+
+// expected-error@+6 {{expected structure size to be >= 0}}
+#NOutOfM = #sparse_tensor.encoding<{
+ map = ( i, j, k ) ->
+ ( i : dense,
+ k floordiv 4 : dense,
+ j : dense,
+ k mod 4 : structured[-2, 4]
+ )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+ return
+}
+
+// -----
+
+// expected-error@+6 {{expected n <= m in n_out_of_m}}
+#NOutOfM = #sparse_tensor.encoding<{
+ map = ( i, j, k ) ->
+ ( i : dense,
+ k floordiv 4 : dense,
+ j : dense,
+ k mod 4 : structured[5, 4]
+ )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+ return
+}
+
+// -----
+
+// expected-error@+1 {{expected all dense lvlTypes before a n_out_of_m level}}
+#NOutOfM = #sparse_tensor.encoding<{
+ map = ( i, j, k ) ->
+ ( i : dense,
+ k floordiv 4 : compressed,
+ j : dense,
+ k mod 4 : structured[2, 4]
+ )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+ return
+}
+
+// -----
+
+// expected-error@+1 {{expected n_out_of_m to be the last level type}}
+#NOutOfM = #sparse_tensor.encoding<{
+ map = ( i, j, k ) ->
+ ( i : dense,
+ k floordiv 4 : structured[2, 4],
+ j : dense,
+ k mod 4 : compressed
+ )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+ return
+}
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index ba7205d5162aa1..b066ba00dddae5 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -52,52 +52,81 @@ def testEncodingAttr1D():
print(f"created_pos_width: {created.pos_width}")
-# CHECK-LABEL: TEST: testEncodingAttr1DStructure
+# CHECK-LABEL: TEST: testEncodingAttrStructure
@run
-def testEncodingAttr1DStructure():
+def testEncodingAttrStructure():
with Context() as ctx:
parsed = Attribute.parse(
"#sparse_tensor.encoding<{"
- " map = (d0) -> (d0 : structured[2, 4]),"
+ " map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,"
+ " d1 mod 4 : structured[2, 4]),"
" posWidth = 16,"
" crdWidth = 32"
"}>"
)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
print(parsed)
casted = st.EncodingAttr(parsed)
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [4406637494272]
+ # CHECK: lvl_types: [65536, 65536, 4406637494272]
print(f"lvl_types: {casted.lvl_types}")
- # CHECK: lvl_types: [<LevelType.n_out_of_m: 1048576>]
+ # CHECK: lvl_types_enum: [<LevelType.dense: 65536>, <LevelType.dense: 65536>, <LevelType.n_out_of_m: 1048576>]
print(f"lvl_types_enum: {casted.lvl_types_enum}")
# CHECK: structured_n: 2
print(f"structured_n: {casted.structured_n}")
# CHECK: structured_m: 4
print(f"structured_m: {casted.structured_m}")
- # CHECK: dim_to_lvl: (d0) -> (d0)
+ # CHECK: dim_to_lvl: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
- # CHECK: lvl_to_dim: (d0) -> (d0)
+ # CHECK: lvl_to_dim: (d0, d1, d2) -> (d0, d1 * 4 + d2)
print(f"lvl_to_dim: {casted.lvl_to_dim}")
# CHECK: pos_width: 16
print(f"pos_width: {casted.pos_width}")
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")
- created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+ created = st.EncodingAttr.get(
+ casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 0, 0
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
print(created)
# CHECK: created_equal: False
print(f"created_equal: {created == casted}")
- built_type = st.EncodingAttr.build_level_type(
- st.LevelType.n_out_of_m, True, True, 2, 4
+ built_2_4 = st.EncodingAttr.build_level_type(
+ st.LevelType.n_out_of_m, 2, 4
+ )
+ dim_to_lvl = AffineMap.get(
+ 2,
+ 0,
+ [
+ AffineExpr.get_dim(0),
+ AffineExpr.get_floor_div(AffineExpr.get_dim(1), 4),
+ AffineExpr.get_mod(AffineExpr.get_dim(1), 4),
+ ],
+ )
+ lvl_to_dim = AffineMap.get(
+ 3,
+ 0,
+ [
+ AffineExpr.get_dim(0),
+ AffineExpr.get_add(
+ AffineExpr.get_mul(AffineExpr.get_dim(1), 4),
+ AffineExpr.get_dim(2),
+ ),
+ ],
+ )
+ built = st.EncodingAttr.get(
+ [st.LevelType.dense, st.LevelType.dense, built_2_4],
+ dim_to_lvl,
+ lvl_to_dim,
+ 0,
+ 0,
)
- built = st.EncodingAttr.get([built_type], None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
print(built)
# CHECK: built_equal: True
print(f"built_equal: {built == created}")
>From bab45ddc09e96982dc27f2edd0262812e3398c17 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 8 Feb 2024 20:11:39 +0000
Subject: [PATCH 3/4] remove redundant code
---
mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp | 8 --------
1 file changed, 8 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 49189a82fb0704..a585928c3fa3ee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -37,14 +37,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
uint64_t properties = 0;
SmallVector<unsigned> structure;
- if (base.compare("structured") == 0) {
- ParseResult res = parser.parseCommaSeparatedList(
- mlir::OpAsmParser::Delimiter::OptionalSquare,
- [&]() -> ParseResult { return parseStructure(parser, &structure); },
- " in block n out of m");
- FAILURE_IF_FAILED(res)
- }
-
if (base.compare("structured") == 0) {
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalSquare,
>From c944dd970b03165f9b8a0c55e8f52303d80898e7 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 8 Feb 2024 20:22:08 +0000
Subject: [PATCH 4/4] format
---
mlir/test/CAPI/sparse_tensor.c | 6 +++---
mlir/test/python/dialects/sparse_tensor/dialect.py | 4 +---
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 580a619d1f7d99..a8b9f9048d5912 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -26,7 +26,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
// clang-format off
const char *originalAsm =
"#sparse_tensor.encoding<{ "
- "map = [s0](d0, d1) -> (s0 : dense, d0 : dense, d1 : structured[2, 4]), "
+ "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
"posWidth = 32, crdWidth = 64 }>";
// clang-format on
MlirAttribute originalAttr =
@@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
// CHECK: level_type: 65536
- // CHECK: level_type: 65536
- // CHECK: level_type: 4406637494272
+ // CHECK: level_type: 131072
+ // CHECK: level_type: 131072
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index b066ba00dddae5..1fa7030ca1be91 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -96,9 +96,7 @@ def testEncodingAttrStructure():
# CHECK: created_equal: False
print(f"created_equal: {created == casted}")
- built_2_4 = st.EncodingAttr.build_level_type(
- st.LevelType.n_out_of_m, 2, 4
- )
+ built_2_4 = st.EncodingAttr.build_level_type(st.LevelType.n_out_of_m, 2, 4)
dim_to_lvl = AffineMap.get(
2,
0,
More information about the Mlir-commits
mailing list