[Mlir-commits] [mlir] [mlir][sparse] allow multiple COO segments in sparse encodings. (PR #91786)

Peiming Liu llvmlistbot at llvm.org
Fri May 10 11:16:38 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/91786

**NOTE**: The implementation still has gaps when handling multiple COO segments in an encoding, but the format itself should be considered legal.

>From 245bfe15fb18a2c62657e771818232b85bc031e5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 10 May 2024 18:13:44 +0000
Subject: [PATCH] [mlir][sparse] allow multiple COO segments in sparse
 encodings.

---
 .../SparseTensor/IR/SparseTensorDialect.cpp       | 15 ++++++++++-----
 .../Dialect/SparseTensor/roundtrip_encoding.mlir  | 11 +++++++++++
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4cc6ee971d4a3..4adb1c19096a2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -788,24 +788,29 @@ LogicalResult SparseTensorEncodingAttr::verify(
     return emitError() << "unexpected position bitwidth: " << posWidth;
   if (!acceptBitWidth(crdWidth))
     return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
-  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
-      it != std::end(lvlTypes)) {
+
+  // Verify every COO segment.
+  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
+  while (it != lvlTypes.end()) {
     if (it == lvlTypes.begin() ||
-        (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
+        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
       return emitError() << "expected compressed or loose_compressed level "
                             "before singleton level";
-    if (!std::all_of(it, lvlTypes.end(),
+
+    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
+    if (!std::all_of(it, curCOOEnd,
                      [](LevelType i) { return isSingletonLT(i); }))
       return emitError() << "expected all singleton lvlTypes "
                             "following a singleton level";
     // We can potentially support mixed SoA/AoS singleton levels.
-    if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
+    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
           return it->isa<LevelPropNonDefault::SoA>() ==
                  i.isa<LevelPropNonDefault::SoA>();
         })) {
       return emitError() << "expected all singleton lvlTypes stored in the "
                             "same memory layout (SoA vs AoS).";
     }
+    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
   }
 
   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 7fb1c76c1a1ff..44710cad246c6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -156,6 +156,17 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
 
 // -----
 
+#COO_DENSE = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2: dense)
+}>
+
+// CHECK-DAG: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2 : dense) }>
+// CHECK-LABEL: func private @sparse_coo_trailing_dense(
+// CHECK-SAME: tensor<?x?x1xf32, #[[$COO]]>)
+func.func private @sparse_coo_trailing_dense(tensor<?x?x1xf32, #COO_DENSE>)
+
+// -----
+
 #BCOO = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
 }>



More information about the Mlir-commits mailing list