[Mlir-commits] [mlir] [mlir][sparse] provide an AoS "view" into sparse runtime support lib (PR #87116)

Aart Bik llvmlistbot at llvm.org
Fri Mar 29 15:07:59 PDT 2024


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/87116

Note that even though the sparse runtime support lib always uses SoA storage for COO (and provides correct codegen by means of views into this storage), in some rare cases we need the true physical AoS layout as a single coordinate buffer. This PR provides that functionality by means of a (costly) coordinate buffer call.

Since this is currently only used for testing/debugging by means of the sparse_tensor.print method, this solution is acceptable. If we ever want a performant version of this, we should truly support AoS storage of COO in addition to the SoA storage used right now.
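For context, the transformation behind the new coordinate buffer call is a simple interleaving of the per-level SoA coordinate vectors into one linear AoS buffer. The sketch below is a standalone illustration with assumed names (interleaveAoS is not part of the support library); the actual implementation is SparseTensorStorage::getCoordinatesBuffer in the patch.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Hypothetical standalone helper (names assumed, not the library API):
  // flatten per-level SoA coordinate vectors into one AoS buffer so that
  // all coordinates of a single nonzero are stored consecutively.
  std::vector<uint64_t>
  interleaveAoS(const std::vector<std::vector<uint64_t>> &coordinates,
                uint64_t lvl, uint64_t nnz) {
    const uint64_t lvlRank = coordinates.size();
    std::vector<uint64_t> crdBuffer;
    crdBuffer.reserve(nnz * (lvlRank - lvl));
    for (uint64_t i = 0; i < nnz; i++) {          // for every stored nonzero
      for (uint64_t l = lvl; l < lvlRank; l++) {  // append its coordinates
        assert(i < coordinates[l].size());
        crdBuffer.push_back(coordinates[l][i]);
      }
    }
    return crdBuffer; // e.g. {{0,0,3},{0,2,2}}, lvl=0 -> {0,0, 0,2, 3,2}
  }

For the COO tensor in the updated sparse_print.mlir test, for example, the SoA views crd[0] = ( 0, 0, 3, 3, 3 ) and crd[1] = ( 0, 2, 2, 3, 5 ) interleave into the single AoS buffer ( 0, 0, 0, 2, 3, 2, 3, 3, 3, 5 ).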

From bc77d46c2181feec70550bf8c6cf7289792360ab Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 29 Mar 2024 14:56:39 -0700
Subject: [PATCH] [mlir][sparse] provide an AoS "view" into sparse runtime
 support lib

Note that even though the sparse runtime support lib always uses
SoA storage for COO (and provides correct codegen by means of
views into this storage), in some rare cases we need the true
physical AoS layout as a single coordinate buffer. This PR
provides that functionality by means of a (costly) coordinate
buffer call.

Since this is currently only used for testing/debugging by means
of the sparse_tensor.print method, this solution is acceptable.
If we ever want a performant version of this, we should truly
support AoS storage of COO in addition to the SoA storage used
right now.
---
 .../ExecutionEngine/SparseTensor/Storage.h    | 35 ++++++++++++
 .../ExecutionEngine/SparseTensorRuntime.h     |  8 +++
 .../Transforms/SparseTensorConversion.cpp     | 56 ++++++++++++++----
 .../Transforms/SparseTensorRewriting.cpp      |  4 +-
 .../ExecutionEngine/SparseTensor/Storage.cpp  |  8 +++
 .../ExecutionEngine/SparseTensorRuntime.cpp   |  7 +++
 .../SparseTensor/CPU/sparse_print.mlir        | 57 +++++++++++++++----
 7 files changed, 152 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 773957a8b51162..80e3fec22694fb 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -143,6 +143,12 @@ class SparseTensorStorageBase {
   MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
 #undef DECL_GETCOORDINATES
 
+  /// Gets coordinates-overhead storage buffer for the given level.
+#define DECL_GETCOORDINATESBUFFER(INAME, C)                                    \
+  virtual void getCoordinatesBuffer(std::vector<C> **, uint64_t);
+  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATESBUFFER)
+#undef DECL_GETCOORDINATESBUFFER
+
   /// Gets primary storage.
 #define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
@@ -251,6 +257,31 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     assert(lvl < getLvlRank());
     *out = &coordinates[lvl];
   }
+  void getCoordinatesBuffer(std::vector<C> **out, uint64_t lvl) final {
+    assert(out && "Received nullptr for out parameter");
+    assert(lvl < getLvlRank());
+    // Note that the sparse tensor support library always stores COO in SoA
+    // format, even when AoS is requested. This is never an issue, since all
+    // actual code/library generation requests "views" into the coordinate
+    // storage for the individual levels, which is trivially provided for
+    // both AoS and SoA (as well as all the other storage formats). The only
+    // exception is when the buffer version of coordinate storage is requested
+    // (currently only for printing). In that case, we do the following
+    // potentially expensive transformation to provide that view. If this
+    // operation becomes more common beyond debugging, we should consider
+    // implementing proper AoS in the support library as well.
+    uint64_t lvlRank = getLvlRank();
+    uint64_t nnz = values.size();
+    crdBuffer.clear();
+    crdBuffer.reserve(nnz * (lvlRank - lvl));
+    for (uint64_t i = 0; i < nnz; i++) {
+      for (uint64_t l = lvl; l < lvlRank; l++) {
+        assert(i < coordinates[l].size());
+        crdBuffer.push_back(coordinates[l][i]);
+      }
+    }
+    *out = &crdBuffer;
+  }
   void getValues(std::vector<V> **out) final {
     assert(out && "Received nullptr for out parameter");
     *out = &values;
@@ -529,10 +560,14 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     return -1u;
   }
 
+  // Sparse tensor storage components.
   std::vector<std::vector<P>> positions;
   std::vector<std::vector<C>> coordinates;
   std::vector<V> values;
+
+  // Auxiliary data structures.
   std::vector<uint64_t> lvlCursor;
+  std::vector<C> crdBuffer; // just for AoS view
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index d916186c835c2e..396f76fd8f921a 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -77,6 +77,14 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
 MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
 #undef DECL_SPARSECOORDINATES
 
+/// Tensor-storage method to obtain direct access to the coordinates array
+/// buffer for the given level (provides an AoS view into the library).
+#define DECL_SPARSECOORDINATESBUFFER(CNAME, C)                                \
+  MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_sparseCoordinatesBuffer##CNAME(   \
+      StridedMemRefType<C, 1> *out, void *tensor, index_type lvl);
+MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATESBUFFER)
+#undef DECL_SPARSECOORDINATESBUFFER
+
 /// Tensor-storage method to insert elements in lexicographical
 /// level-coordinate order.
 #define DECL_LEXINSERT(VNAME, V)                                               \
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 92c98b34af6027..c52fa3751e6b4a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
       .getResult(0);
 }
 
-/// Generates a call to obtain the coordindates array.
+/// Generates a call to obtain the coordinates array.
 static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value ptr, Level l) {
   Type crdTp = stt.getCrdType();
@@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
       .getResult(0);
 }
 
+/// Generates a call to obtain the coordinates array (AoS view).
+static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
+                                      SparseTensorType stt, Value ptr,
+                                      Level l) {
+  Type crdTp = stt.getCrdType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
+  Value lvl = constantIndex(builder, loc, l);
+  SmallString<25> name{"sparseCoordinatesBuffer",
+                       overheadTypeFunctionSuffix(crdTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+                        EmitCInterface::On)
+      .getResult(0);
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -518,13 +532,35 @@ class SparseTensorToCoordinatesConverter
   LogicalResult
   matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    const Location loc = op.getLoc();
+    auto stt = getSparseTensorType(op.getTensor());
+    auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
+                                   op.getLevel());
+    // Cast the MemRef type to the type expected by the users, though these
+    // two types should be compatible at runtime.
+    if (op.getType() != crds.getType())
+      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
+    rewriter.replaceOp(op, crds);
+    return success();
+  }
+};
+
+/// Sparse conversion rule for coordinate accesses (AoS style).
+class SparseToCoordinatesBufferConverter
+    : public OpConversionPattern<ToCoordinatesBufferOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getTensor());
-    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
-                                   adaptor.getTensor(), op.getLevel());
+    auto crds = genCoordinatesBufferCall(
+        rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
     // Cast the MemRef type to the type expected by the users, though these
     // two types should be compatible at runtime.
     if (op.getType() != crds.getType())
-      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
+      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
     rewriter.replaceOp(op, crds);
     return success();
   }
@@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
            SparseTensorAllocConverter, SparseTensorEmptyConverter,
            SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
            SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
-           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
-           SparseTensorLoadConverter, SparseTensorInsertConverter,
-           SparseTensorExpandConverter, SparseTensorCompressConverter,
-           SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
-           SparseHasRuntimeLibraryConverter>(typeConverter,
-                                             patterns.getContext());
+           SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
+           SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+           SparseTensorInsertConverter, SparseTensorExpandConverter,
+           SparseTensorCompressConverter, SparseTensorAssembleConverter,
+           SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
+          typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 17f70d0796ccfc..b117c1694e45b8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -648,7 +648,9 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
         Value crd = nullptr;
-        // TODO: eliminates ToCoordinateBufferOp!
+        // For COO AoS storage, we want to print a single, linear view of
+        // the full coordinate storage at this level. For any other storage,
+        // we show the coordinate storage for every individual level.
         if (stt.getAoSCOOStart() == l)
           crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
         else
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index aaa42a7e3a31bf..acb2d1bb5bed62 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -68,6 +68,14 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
 #undef IMPL_GETCOORDINATES
 
+#define IMPL_GETCOORDINATESBUFFER(CNAME, C)                                    \
+  void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **,        \
+                                                     uint64_t) {               \
+    FATAL_PIV("getCoordinatesBuffer" #CNAME);                                  \
+  }
+MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
+#undef IMPL_GETCOORDINATESBUFFER
+
 #define IMPL_GETVALUES(VNAME, V)                                               \
   void SparseTensorStorageBase::getValues(std::vector<V> **) {                 \
     FATAL_PIV("getValues" #VNAME);                                             \
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 8835056099d234..f160b0f40fb0a3 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -311,6 +311,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
     assert(v);                                                                 \
     aliasIntoMemref(v->size(), v->data(), *ref);                               \
   }
+
 #define IMPL_SPARSEPOSITIONS(PNAME, P)                                         \
   IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
 MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
@@ -320,6 +321,12 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
   IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
 MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
 #undef IMPL_SPARSECOORDINATES
+
+#define IMPL_SPARSECOORDINATESBUFFER(CNAME, C)                                 \
+  IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
+MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
+#undef IMPL_SPARSECOORDINATESBUFFER
+
 #undef IMPL_GETOVERHEAD
 
 #define IMPL_LEXINSERT(VNAME, V)                                               \
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
index 98d76ba350cbd9..7758ca77dce9ea 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
@@ -120,6 +120,14 @@
   )
 }>
 
+#COOAoS = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
+}>
+
+#COOSoA = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
+}>
+
 module {
 
   //
@@ -161,6 +169,8 @@ module {
     %h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
     %i = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR0>
     %j = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC0>
+    %AoS = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOAoS>
+    %SoA = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOSoA>
 
     // CHECK-NEXT: ---- Sparse Tensor ----
     // CHECK-NEXT: nse = 5
@@ -274,19 +284,42 @@ module {
     // CHECK-NEXT: ----
     sparse_tensor.print %j : tensor<4x8xi32, #BSC0>
 
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: dim = ( 4, 8 )
+    // CHECK-NEXT: lvl = ( 4, 8 )
+    // CHECK-NEXT: pos[0] : ( 0, 5,
+    // CHECK-NEXT: crd[0] : ( 0, 0, 0, 2, 3, 2, 3, 3, 3, 5,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %AoS : tensor<4x8xi32, #COOAoS>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: dim = ( 4, 8 )
+    // CHECK-NEXT: lvl = ( 4, 8 )
+    // CHECK-NEXT: pos[0] : ( 0, 5,
+    // CHECK-NEXT: crd[0] : ( 0, 0, 3, 3, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %SoA : tensor<4x8xi32, #COOSoA>
+
     // Release the resources.
-    bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
-    bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
-    bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
-    bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
-    bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
-    bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
-    bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
-    bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
-    bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
-    bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
-    bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
-    bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
+    bufferization.dealloc_tensor %XO  : tensor<4x8xi32, #AllDense>
+    bufferization.dealloc_tensor %XT  : tensor<4x8xi32, #AllDenseT>
+    bufferization.dealloc_tensor %a   : tensor<4x8xi32, #CSR>
+    bufferization.dealloc_tensor %b   : tensor<4x8xi32, #DCSR>
+    bufferization.dealloc_tensor %c   : tensor<4x8xi32, #CSC>
+    bufferization.dealloc_tensor %d   : tensor<4x8xi32, #DCSC>
+    bufferization.dealloc_tensor %e   : tensor<4x8xi32, #BSR>
+    bufferization.dealloc_tensor %f   : tensor<4x8xi32, #BSRC>
+    bufferization.dealloc_tensor %g   : tensor<4x8xi32, #BSC>
+    bufferization.dealloc_tensor %h   : tensor<4x8xi32, #BSCC>
+    bufferization.dealloc_tensor %i   : tensor<4x8xi32, #BSR0>
+    bufferization.dealloc_tensor %j   : tensor<4x8xi32, #BSC0>
+    bufferization.dealloc_tensor %AoS : tensor<4x8xi32, #COOAoS>
+    bufferization.dealloc_tensor %SoA : tensor<4x8xi32, #COOSoA>
 
     return
   }


