[Mlir-commits] [mlir] cad4646 - [mlir][sparse] Improve handling of NEW_SYNTAX
wren romano
llvmlistbot at llvm.org
Fri Aug 4 17:53:41 PDT 2023
Author: wren romano
Date: 2023-08-04T17:53:34-07:00
New Revision: cad4646733092bb5909e2294da7d38da33c71c69
URL: https://github.com/llvm/llvm-project/commit/cad4646733092bb5909e2294da7d38da33c71c69
DIFF: https://github.com/llvm/llvm-project/commit/cad4646733092bb5909e2294da7d38da33c71c69.diff
LOG: [mlir][sparse] Improve handling of NEW_SYNTAX
Improves the conversion from `DimLvlMap` to STEA (`SparseTensorEncodingAttr`), in order to correct rank-mismatch issues in the roundtrip tests.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D157162
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 8f116e6355cf44..4e94fb6134fb46 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -73,6 +73,11 @@ def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
"int64_t" : $stride
);
+ let builders = [
+ // The nop slice (i.e., that includes everything).
+ AttrBuilder<(ins), [{ return $_get($_ctxt, 0, kDynamic, 1); }]>
+ ];
+
let extraClassDeclaration = [{
void print(llvm::raw_ostream &os) const;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index db31ae0f0433d2..ad3fe427509ec8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -213,6 +213,7 @@ std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
}
void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
+ assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
os << '(';
os << getStaticString(getOffset());
os << ", ";
@@ -528,10 +529,37 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
ir_detail::DimLvlMapParser cParser(parser);
auto res = cParser.parseDimLvlMap();
RETURN_ON_FAIL(res);
- // Proof of concept result.
- // TODO: use DimLvlMap directly as storage representation
- for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
- lvlTypes.push_back(res->getLvlType(lvl));
+ // TODO: use DimLvlMap directly as storage representation, rather
+ // than converting things over.
+ const auto &dlm = *res;
+
+ ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `NEW_SYNTAX`")
+ const Level lvlRank = dlm.getLvlRank();
+ for (Level lvl = 0; lvl < lvlRank; lvl++)
+ lvlTypes.push_back(dlm.getLvlType(lvl));
+
+ ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `NEW_SYNTAX`")
+ const Dimension dimRank = dlm.getDimRank();
+ for (Dimension dim = 0; dim < dimRank; dim++)
+ dimSlices.push_back(dlm.getDimSlice(dim));
+ // NOTE: the old syntax requires an all-or-nothing approach to
+ // `dimSlices`; therefore, if any slice actually exists then we need
+ // to convert null-DSA into default/nop DSA.
+ const auto isDefined = [](SparseTensorDimSliceAttr slice) {
+ return static_cast<bool>(slice.getImpl());
+ };
+ if (llvm::any_of(dimSlices, isDefined)) {
+ const auto defaultSlice =
+ SparseTensorDimSliceAttr::get(parser.getContext());
+ for (Dimension dim = 0; dim < dimRank; dim++)
+ if (!isDefined(dimSlices[dim]))
+ dimSlices[dim] = defaultSlice;
+ } else {
+ dimSlices.clear();
+ }
+
+ ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `NEW_SYNTAX`")
+ dimToLvl = dlm.getDimToLvlMap(parser.getContext());
}
// Only the last item can omit the comma
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 99b1ea1afd2d61..cf3a98f6e1f9c6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -156,9 +156,9 @@ func.func private @sparse_2_out_of_4(tensor<?x?xf64, #NV_24>)
(d0, d1) -> (d0 : dense, d1 : compressed)
}>
-// CHECK-LABEL: func private @foo(
+// CHECK-LABEL: func private @CSR_implicit(
// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #CSR_implicit>) {
+func.func private @CSR_implicit(%arg0: tensor<?x?xf64, #CSR_implicit>) {
return
}
@@ -169,9 +169,9 @@ func.func private @foo(%arg0: tensor<?x?xf64, #CSR_implicit>) {
{l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
-// CHECK-LABEL: func private @foo(
+// CHECK-LABEL: func private @CSR_explicit(
// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #CSR_explicit>) {
+func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
return
}
@@ -187,11 +187,9 @@ func.func private @foo(%arg0: tensor<?x?xf64, #CSR_explicit>) {
)
}>
-// FIXME: should not have to use 4 dims ;-)
-//
-// CHECK-LABEL: func private @foo(
-// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ] }>>
-func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_implicit>) {
+// CHECK-LABEL: func private @BCSR_implicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+func.func private @BCSR_implicit(%arg0: tensor<?x?xf64, #BCSR_implicit>) {
return
}
@@ -210,11 +208,9 @@ func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_implicit>) {
)
}>
-// FIXME: should not have to use 4 dims ;-)
-//
-// CHECK-LABEL: func private @foo(
-// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ] }>>
-func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_explicit>) {
+// CHECK-LABEL: func private @BCSR_explicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
return
}
@@ -229,9 +225,8 @@ func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_explicit>) {
)
}>
-//
-// CHECK-LABEL: func private @foo_2_out_of_4(
-// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ] }>>
-func.func private @foo_2_out_of_4(%arg0: tensor<?x?x?xf64, #NV_24>) {
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ], dimToLvl = affine_map<(d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)> }>>
+func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
return
}
More information about the Mlir-commits
mailing list