[Mlir-commits] [mlir] cad4646 - [mlir][sparse] Improve handling of NEW_SYNTAX

wren romano llvmlistbot at llvm.org
Fri Aug 4 17:53:41 PDT 2023


Author: wren romano
Date: 2023-08-04T17:53:34-07:00
New Revision: cad4646733092bb5909e2294da7d38da33c71c69

URL: https://github.com/llvm/llvm-project/commit/cad4646733092bb5909e2294da7d38da33c71c69
DIFF: https://github.com/llvm/llvm-project/commit/cad4646733092bb5909e2294da7d38da33c71c69.diff

LOG: [mlir][sparse] Improve handling of NEW_SYNTAX

Improves the conversion from `DimLvlMap` to STEA (`SparseTensorEncodingAttr`): the level types, dimension slices, and dim-to-lvl map are now all carried over, which corrects the rank-mismatch issues in the roundtrip tests.
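
To make the rank mismatch concrete: for the BCSR encodings in the roundtrip
tests, the dim-to-lvl map sends 2 dimensions to 4 levels, and the tensor type
must keep the dimension rank (tensor<?x?xf64, ...>) rather than the level rank.
The following standalone C++ sketch (illustration only, not part of this patch)
builds that map with the standard MLIR affine-map API and shows the two ranks:

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"
  #include <cassert>

  int main() {
    mlir::MLIRContext ctx;
    // The BCSR dim-to-lvl map from the roundtrip tests:
    //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
    mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
    mlir::AffineExpr d1 = mlir::getAffineDimExpr(1, &ctx);
    mlir::AffineMap dimToLvl = mlir::AffineMap::get(
        /*dimCount=*/2, /*symbolCount=*/0,
        {d0.floorDiv(2), d1.floorDiv(3), d0 % 2, d1 % 3}, &ctx);
    // Dimension rank and level rank differ; the parser now records the map
    // as `dimToLvl` instead of forcing the tensor type to use 4 dims.
    assert(dimToLvl.getNumDims() == 2);
    assert(dimToLvl.getNumResults() == 4);
    return 0;
  }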

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D157162

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 8f116e6355cf44..4e94fb6134fb46 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -73,6 +73,11 @@ def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
     "int64_t" : $stride
   );
 
+  let builders = [
+    // The nop slice (i.e., that includes everything).
+    AttrBuilder<(ins), [{ return $_get($_ctxt, 0, kDynamic, 1); }]>
+  ];
+
   let extraClassDeclaration = [{
     void print(llvm::raw_ostream &os) const;
 

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index db31ae0f0433d2..ad3fe427509ec8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -213,6 +213,7 @@ std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
 }
 
 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
+  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
   os << '(';
   os << getStaticString(getOffset());
   os << ", ";
@@ -528,10 +529,37 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       ir_detail::DimLvlMapParser cParser(parser);
       auto res = cParser.parseDimLvlMap();
       RETURN_ON_FAIL(res);
-      // Proof of concept result.
-      // TODO: use DimLvlMap directly as storage representation
-      for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
-        lvlTypes.push_back(res->getLvlType(lvl));
+      // TODO: use DimLvlMap directly as storage representation, rather
+      // than converting things over.
+      const auto &dlm = *res;
+
+      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `NEW_SYNTAX`")
+      const Level lvlRank = dlm.getLvlRank();
+      for (Level lvl = 0; lvl < lvlRank; lvl++)
+        lvlTypes.push_back(dlm.getLvlType(lvl));
+
+      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `NEW_SYNTAX`")
+      const Dimension dimRank = dlm.getDimRank();
+      for (Dimension dim = 0; dim < dimRank; dim++)
+        dimSlices.push_back(dlm.getDimSlice(dim));
+      // NOTE: the old syntax requires an all-or-nothing approach to
+      // `dimSlices`; therefore, if any slice actually exists then we need
+      // to convert null-DSA into default/nop DSA.
+      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
+        return static_cast<bool>(slice.getImpl());
+      };
+      if (llvm::any_of(dimSlices, isDefined)) {
+        const auto defaultSlice =
+            SparseTensorDimSliceAttr::get(parser.getContext());
+        for (Dimension dim = 0; dim < dimRank; dim++)
+          if (!isDefined(dimSlices[dim]))
+            dimSlices[dim] = defaultSlice;
+      } else {
+        dimSlices.clear();
+      }
+
+      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `NEW_SYNTAX`")
+      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
     }
 
     // Only the last item can omit the comma

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 99b1ea1afd2d61..cf3a98f6e1f9c6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -156,9 +156,9 @@ func.func private @sparse_2_out_of_4(tensor<?x?xf64, #NV_24>)
   (d0, d1) -> (d0 : dense, d1 : compressed)
 }>
 
-// CHECK-LABEL: func private @foo(
+// CHECK-LABEL: func private @CSR_implicit(
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #CSR_implicit>) {
+func.func private @CSR_implicit(%arg0: tensor<?x?xf64, #CSR_implicit>) {
   return
 }
 
@@ -169,9 +169,9 @@ func.func private @foo(%arg0: tensor<?x?xf64, #CSR_implicit>) {
   {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 
-// CHECK-LABEL: func private @foo(
+// CHECK-LABEL: func private @CSR_explicit(
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #CSR_explicit>) {
+func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
   return
 }
 
@@ -187,11 +187,9 @@ func.func private @foo(%arg0: tensor<?x?xf64, #CSR_explicit>) {
   )
 }>
 
-// FIXME: should not have to use 4 dims ;-)
-//
-// CHECK-LABEL: func private @foo(
-// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ] }>>
-func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_implicit>) {
+// CHECK-LABEL: func private @BCSR_implicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+func.func private @BCSR_implicit(%arg0: tensor<?x?xf64, #BCSR_implicit>) {
   return
 }
 
@@ -210,11 +208,9 @@ func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_implicit>) {
   )
 }>
 
-// FIXME: should not have to use 4 dims ;-)
-//
-// CHECK-LABEL: func private @foo(
-// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ] }>>
-func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_explicit>) {
+// CHECK-LABEL: func private @BCSR_explicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
   return
 }
 
@@ -229,9 +225,8 @@ func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_explicit>) {
   )
 }>
 
-//
-// CHECK-LABEL: func private @foo_2_out_of_4(
-// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ] }>>
-func.func private @foo_2_out_of_4(%arg0: tensor<?x?x?xf64, #NV_24>) {
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ], dimToLvl = affine_map<(d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)> }>>
+func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   return
 }


        

