[Mlir-commits] [mlir] [mlir][sparse] support SoA COO in codegen path. (PR #82439)

Peiming Liu llvmlistbot at llvm.org
Tue Feb 20 15:12:12 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/82439

>From 4ec0d2fa1bf48609d7df0dab94385a0323870bf2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 20 Feb 2024 23:08:44 +0000
Subject: [PATCH] [mlir][sparse] support SoA COO in codegen path.

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |  8 ++++-
 .../SparseTensor/IR/SparseTensorType.h        |  5 ++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 14 ++++----
 .../Transforms/SparseTensorCodegen.cpp        |  6 ++--
 .../Transforms/SparseTensorRewriting.cpp      |  2 +-
 .../Transforms/Utils/CodegenUtils.cpp         |  2 +-
 .../Utils/SparseTensorDescriptor.cpp          |  2 +-
 .../Transforms/Utils/SparseTensorDescriptor.h |  2 +-
 .../SparseTensor/CPU/sparse_coo_test.mlir     | 32 +++++++++++--------
 9 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 41a14575ed1054..a00c9c31256c96 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -283,7 +283,13 @@ struct LevelType {
   }
   bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
 
-  LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
+  LevelType stripStorageIrrelevantProperties() const {
+    // Properties other than `SoA` do not change the storage scheme of the
+    // sparse tensor.
+    constexpr uint64_t mask =
+        0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA);
+    return LevelType(lvlBits & ~mask);
+  }
 
   /// Get N of NOutOfM level type.
   constexpr uint64_t getN() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 24a5640d820e43..1a090ddb782fdb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -24,6 +24,7 @@ struct COOSegment {
   std::pair<Level, Level> lvlRange; // [low, high)
   bool isSoA;
 
+  bool isAoS() const { return !isSoA; }
   bool isSegmentStart(Level l) const { return l == lvlRange.first; }
   bool inSegment(Level l) const {
     return l >= lvlRange.first && l < lvlRange.second;
@@ -337,7 +338,9 @@ class SparseTensorType {
   /// Returns the starting level of this sparse tensor type for a
   /// trailing COO region that spans **at least** two levels. If
   /// no such COO region is found, then returns the level-rank.
-  Level getCOOStart() const;
+  ///
+  /// DEPRECATED: use getCOOSegments instead.
+  Level getAoSCOOStart() const;
 
   /// Returns [un]ordered COO type for this sparse tensor type.
   RankedTensorType getCOOType(bool ordered) const;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 53e78d2c28b1d7..af7b85d458774d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -182,7 +182,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
   unsigned stride = 1;
   if (kind == SparseTensorFieldKind::CrdMemRef) {
     assert(lvl.has_value());
-    const Level cooStart = SparseTensorType(enc).getCOOStart();
+    const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
     const Level lvlRank = enc.getLvlRank();
     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
       lvl = cooStart;
@@ -811,10 +811,10 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
   return !isUnique || isUniqueLvl(lvlRank - 1);
 }
 
-Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
+Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
   SmallVector<COOSegment> coo = getCOOSegments();
-  if (!coo.empty()) {
-    assert(coo.size() == 1);
+  assert(coo.size() == 1 || coo.empty());
+  if (!coo.empty() && coo.front().isAoS()) {
     return coo.front().lvlRange.first;
   }
   return lvlRank;
@@ -1051,7 +1051,7 @@ static SparseTensorEncodingAttr
 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
   SmallVector<LevelType> lts;
   for (auto lt : enc.getLvlTypes())
-    lts.push_back(lt.stripProperties());
+    lts.push_back(lt.stripStorageIrrelevantProperties());
 
   return SparseTensorEncodingAttr::get(
       enc.getContext(), lts,
@@ -1137,7 +1137,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     return op->emitError("the sparse-tensor must have an encoding attribute");
 
   // Verifies the trailing COO.
-  Level cooStartLvl = stt.getCOOStart();
+  Level cooStartLvl = stt.getAoSCOOStart();
   if (cooStartLvl < stt.getLvlRank()) {
     // We only supports trailing COO for now, must be the last input.
     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
@@ -1452,7 +1452,7 @@ LogicalResult ToCoordinatesOp::verify() {
 
 LogicalResult ToCoordinatesBufferOp::verify() {
   auto stt = getSparseTensorType(getTensor());
-  if (stt.getCOOStart() >= stt.getLvlRank())
+  if (stt.getAoSCOOStart() >= stt.getLvlRank())
     return emitError("expected sparse tensor with a COO region");
   return success();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d4459c6ea1e521..0ccb11f3a6b858 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
       valHeuristic =
           builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
   } else if (sizeHint) {
-    if (stt.getCOOStart() == 0) {
+    if (stt.getAoSCOOStart() == 0) {
       posHeuristic = constantIndex(builder, loc, 2);
       crdHeuristic = builder.create<arith::MulIOp>(
           loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
@@ -1316,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     Value posBack = c0; // index to the last value in the position array
     Value memSize = c1; // memory size for current array
 
-    Level trailCOOStart = stt.getCOOStart();
+    Level trailCOOStart = stt.getAoSCOOStart();
     Level trailCOORank = stt.getLvlRank() - trailCOOStart;
     // Sets up SparseTensorSpecifier.
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
@@ -1453,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
     const auto dstTp = getSparseTensorType(op.getResult());
     // Creating COO with NewOp is handled by direct IR codegen. All other cases
     // are handled by rewriting.
-    if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
+    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
       return failure();
 
     // Implement as follows:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7326a6a3811284..2ccb2361b5efe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,7 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
-    if (!stt.hasEncoding() || stt.getCOOStart() == 0)
+    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
       return failure();
 
     // Implement the NewOp as follows:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index 75a43891491879..b888dfadb9c714 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -568,7 +568,7 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
   const auto srcTp = getSparseTensorType(tensor);
   const Type crdTp = srcTp.getCrdType();
   const Type memTp =
-      get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getCOOStart());
+      get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
   return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
                                          builder.getIndexAttr(lvl));
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 3ab4157475cd4c..6ac26ad550f9f3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
     OpBuilder &builder, Location loc, Level lvl) const {
-  const Level cooStart = rType.getCOOStart();
+  const Level cooStart = rType.getAoSCOOStart();
   if (lvl < cooStart)
     return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index 3a61ec7a2236f3..c2f631605bf4b2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
   }
 
   Value getAOSMemRef() const {
-    const Level cooStart = rType.getCOOStart();
+    const Level cooStart = rType.getAoSCOOStart();
     assert(cooStart < rType.getLvlRank());
     return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
   }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir
index aaf15ecc681fc2..16252c1005ebbb 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir
@@ -34,6 +34,10 @@
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
 }>
 
+#SortedCOOSoA = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
+}>
+
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed)
 }>
@@ -50,7 +54,7 @@
 
 module {
   func.func @add_coo_csr(%arga: tensor<8x8xf32, #CSR>,
-                         %argb: tensor<8x8xf32, #SortedCOO>)
+                         %argb: tensor<8x8xf32, #SortedCOOSoA>)
 		         -> tensor<8x8xf32> {
     %empty = tensor.empty() : tensor<8x8xf32>
     %zero = arith.constant 0.000000e+00 : f32
@@ -59,7 +63,7 @@ module {
         outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
     %0 = linalg.generic #trait
       ins(%arga, %argb: tensor<8x8xf32, #CSR>,
-                        tensor<8x8xf32, #SortedCOO>)
+                        tensor<8x8xf32, #SortedCOOSoA>)
       outs(%init: tensor<8x8xf32>) {
         ^bb(%a: f32, %b: f32, %x: f32):
           %0 = arith.addf %a, %b : f32
@@ -69,7 +73,7 @@ module {
   }
 
   func.func @add_coo_coo(%arga: tensor<8x8xf32, #SortedCOO>,
-                         %argb: tensor<8x8xf32, #SortedCOO>)
+                         %argb: tensor<8x8xf32, #SortedCOOSoA>)
 		         -> tensor<8x8xf32> {
     %empty = tensor.empty() : tensor<8x8xf32>
     %zero = arith.constant 0.000000e+00 : f32
@@ -78,7 +82,7 @@ module {
         outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
     %0 = linalg.generic #trait
       ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>,
-                        tensor<8x8xf32, #SortedCOO>)
+                        tensor<8x8xf32, #SortedCOOSoA>)
       outs(%init: tensor<8x8xf32>) {
         ^bb(%a: f32, %b: f32, %x: f32):
           %0 = arith.addf %a, %b : f32
@@ -88,12 +92,12 @@ module {
   }
 
   func.func @add_coo_coo_out_coo(%arga: tensor<8x8xf32, #SortedCOO>,
-                                 %argb: tensor<8x8xf32, #SortedCOO>)
+                                 %argb: tensor<8x8xf32, #SortedCOOSoA>)
 		                 -> tensor<8x8xf32, #SortedCOO> {
     %init = tensor.empty() : tensor<8x8xf32, #SortedCOO>
     %0 = linalg.generic #trait
       ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>,
-                        tensor<8x8xf32, #SortedCOO>)
+                        tensor<8x8xf32, #SortedCOOSoA>)
       outs(%init: tensor<8x8xf32, #SortedCOO>) {
         ^bb(%a: f32, %b: f32, %x: f32):
           %0 = arith.addf %a, %b : f32
@@ -104,7 +108,7 @@ module {
 
 
   func.func @add_coo_dense(%arga: tensor<8x8xf32>,
-                           %argb: tensor<8x8xf32, #SortedCOO>)
+                           %argb: tensor<8x8xf32, #SortedCOOSoA>)
   	    	         -> tensor<8x8xf32> {
     %empty = tensor.empty() : tensor<8x8xf32>
     %zero = arith.constant 0.000000e+00 : f32
@@ -113,7 +117,7 @@ module {
         outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
     %0 = linalg.generic #trait
       ins(%arga, %argb: tensor<8x8xf32>,
-                        tensor<8x8xf32, #SortedCOO>)
+                        tensor<8x8xf32, #SortedCOOSoA>)
       outs(%init: tensor<8x8xf32>) {
         ^bb(%a: f32, %b: f32, %x: f32):
           %0 = arith.addf %a, %b : f32
@@ -154,19 +158,19 @@ module {
     %COO_A = sparse_tensor.convert %A
       : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
     %COO_B = sparse_tensor.convert %B
-      : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
+      : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOOSoA>
 
     %C1 = call @add_coo_dense(%A, %COO_B) : (tensor<8x8xf32>,
-                                             tensor<8x8xf32, #SortedCOO>)
+                                             tensor<8x8xf32, #SortedCOOSoA>)
                                           -> tensor<8x8xf32>
     %C2 = call @add_coo_csr(%CSR_A, %COO_B) : (tensor<8x8xf32, #CSR>,
-                                               tensor<8x8xf32, #SortedCOO>)
+                                               tensor<8x8xf32, #SortedCOOSoA>)
                                             -> tensor<8x8xf32>
     %C3 = call @add_coo_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
-                                               tensor<8x8xf32, #SortedCOO>)
+                                               tensor<8x8xf32, #SortedCOOSoA>)
                                             -> tensor<8x8xf32>
     %COO_RET = call @add_coo_coo_out_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
-                                                            tensor<8x8xf32, #SortedCOO>)
+                                                            tensor<8x8xf32, #SortedCOOSoA>)
                                                          -> tensor<8x8xf32, #SortedCOO>
     %C4 = sparse_tensor.convert %COO_RET : tensor<8x8xf32, #SortedCOO> to tensor<8x8xf32>
     //
@@ -204,7 +208,7 @@ module {
     bufferization.dealloc_tensor %C4 : tensor<8x8xf32>
     bufferization.dealloc_tensor %CSR_A : tensor<8x8xf32, #CSR>
     bufferization.dealloc_tensor %COO_A : tensor<8x8xf32, #SortedCOO>
-    bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOO>
+    bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOOSoA>
     bufferization.dealloc_tensor %COO_RET : tensor<8x8xf32, #SortedCOO>
 
 



More information about the Mlir-commits mailing list