[Mlir-commits] [mlir] [mlir][sparse] support type conversion from SoA COO to memrefs. (PR #82398)

Peiming Liu llvmlistbot at llvm.org
Tue Feb 20 10:39:13 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/82398

>From 58b0843b0532899a11b1686257a4a36c3f975f5a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 20 Feb 2024 18:22:49 +0000
Subject: [PATCH] [mlir][sparse] support type conversion from SoA COO to
 memrefs.

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |  4 +-
 .../SparseTensor/IR/SparseTensorType.h        | 15 +++++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 64 +++++++++++++++++--
 mlir/test/Dialect/SparseTensor/codegen.mlir   | 26 ++++++++
 4 files changed, 100 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index fa34bbe3c9d910..41a14575ed1054 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -303,9 +303,9 @@ struct LevelType {
   }
 
   /// Check if the `LevelType` is in the `LevelFormat`.
-  template <LevelFormat fmt>
+  template <LevelFormat... fmt>
   constexpr bool isa() const {
-    return getLvlFmt() == fmt;
+    return (... || (getLvlFmt() == fmt)) || false; // True if any fmt matches.
   }
 
   /// Check if the `LevelType` has the properties
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4e2b85d35c1ac1..c5d5748c53e4cc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -18,6 +18,18 @@
 namespace mlir {
 namespace sparse_tensor {
 
+/// A simple structure that encodes a range of levels in a sparse tensor that
+/// forms a COO segment: a [loose] compressed level followed by singletons.
+struct COOSegment {
+  std::pair<Level /*low*/, Level /*high, exclusive*/> lvlRange;
+  bool isSoA; // SoA: one memref per singleton level; AoS: one fused memref.
+
+  bool isSegmentStart(Level l) const { return l == lvlRange.first; }
+  bool inSegment(Level l) const {
+    return l >= lvlRange.first && l < lvlRange.second;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 /// A wrapper around `RankedTensorType`, which has three goals:
 ///
@@ -330,6 +342,9 @@ class SparseTensorType {
   /// Returns [un]ordered COO type for this sparse tensor type.
   RankedTensorType getCOOType(bool ordered) const;
 
+  /// Returns the COO segments of this sparse tensor type, ordered by level.
+  SmallVector<COOSegment> getCOOSegments() const;
+
 private:
   // These two must be const, to ensure coherence of the memoized fields.
   const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index db359b4b7a5d09..222f42c00193a6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -74,11 +74,13 @@ void StorageLayout::foreachField(
         callback) const {
   const auto lvlTypes = enc.getLvlTypes();
   const Level lvlRank = enc.getLvlRank();
-  const Level cooStart = SparseTensorType(enc).getCOOStart();
-  const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
+  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
   FieldIndex fieldIdx = kDataFieldStartingIdx;
+
+  Level l = 0;
+  ArrayRef cooSegsRef = cooSegs; // COO segments not yet handled.
   // Per-level storage.
-  for (Level l = 0; l < end; l++) {
+  while (l < lvlRank) {
     const auto lt = lvlTypes[l];
     if (isWithPosLT(lt)) {
       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
@@ -88,6 +90,21 @@ void StorageLayout::foreachField(
       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
         return;
     }
+    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
+      if (!cooSegsRef.front().isSoA) {
+        // AoS COO: all singleton levels are fused into one memref. Skip the
+        // entire COO segment.
+        l = cooSegsRef.front().lvlRange.second;
+      } else {
+        // SoA COO: each singleton level has its own memref.
+        l++;
+      }
+      // Drop the COO segment that has just been handled.
+      cooSegsRef = cooSegsRef.drop_front();
+    } else {
+      // Non-COO level.
+      l++;
+    }
   }
   // The values array.
   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
@@ -796,13 +813,46 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
 }
 
 Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
-  if (hasEncoding() && lvlRank > 1)
-    for (Level l = 0; l < lvlRank - 1; l++)
-      if (isCOOType(l, /*isUnique=*/false))
-        return l;
+  SmallVector<COOSegment> coo = getCOOSegments();
+  if (!coo.empty()) {
+    assert(coo.size() == 1); // At most one COO segment is supported here.
+    return coo.front().lvlRange.first;
+  }
   return lvlRank;
 }
 
+SmallVector<COOSegment>
+mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
+  SmallVector<COOSegment> ret;
+  if (!hasEncoding() || lvlRank <= 1) // A COO segment spans >= 2 levels.
+    return ret;
+
+  ArrayRef<LevelType> lts = getLvlTypes();
+  Level l = 0;
+  while (l < lvlRank) {
+    auto lt = lts[l];
+    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
+      auto cur = lts.begin() + l;
+      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
+        return !lt.isa<LevelFormat::Singleton>();
+      });
+      unsigned cooLen = std::distance(cur, end);
+      if (cooLen > 1) {
+        // To support mixed SoA/AoS COO, we should break the segment when the
+        // storage scheme changes; for now we assume that all consecutive
+        // singleton levels share the same storage format, as verified by
+        // STEA.
+        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
+                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
+      }
+      l += cooLen;
+    } else {
+      l++;
+    }
+  }
+  return ret;
+}
+
 RankedTensorType
 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
   SmallVector<LevelType> lvlTypes;
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index a3b26972d66ff5..c1a976c84fecca 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -48,6 +48,10 @@
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
 }>
 
+#SoACOO = #sparse_tensor.encoding<{ // COO with SoA singleton storage.
+  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
+}>
+
 #CooPNo = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton(nonordered))
 }>
@@ -67,6 +71,28 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_nop_aos_coo(
+//  CHECK-SAME: %[[POS:.*0]]: memref<?xindex>,
+//  CHECK-SAME: %[[AoS_CRD:.*1]]: memref<?xindex>,
+//  CHECK-SAME: %[[VAL:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
+//       CHECK: return %[[POS]], %[[AoS_CRD]], %[[VAL]], %[[A3]]
+func.func @sparse_nop_aos_coo(%arg0: tensor<?x?xf64, #Coo>) -> tensor<?x?xf64, #Coo> {
+  return %arg0 : tensor<?x?xf64, #Coo>
+}
+
+// CHECK-LABEL: func @sparse_nop_soa_coo(
+//  CHECK-SAME: %[[POS:.*0]]: memref<?xindex>,
+//  CHECK-SAME: %[[SoA_CRD_0:.*1]]: memref<?xindex>,
+//  CHECK-SAME: %[[SoA_CRD_1:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[VAL:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
+//       CHECK: return %[[POS]], %[[SoA_CRD_0]], %[[SoA_CRD_1]], %[[VAL]], %[[A3]]
+func.func @sparse_nop_soa_coo(%arg0: tensor<?x?xf64, #SoACOO>) -> tensor<?x?xf64, #SoACOO> {
+  return %arg0 : tensor<?x?xf64, #SoACOO>
+}
+
+
 // CHECK-LABEL: func @sparse_nop_multi_ret(
 //  CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,



More information about the Mlir-commits mailing list