[Mlir-commits] [mlir] 1b27484 - [mlir][sparse] further implement singleton dimension level type

wren romano llvmlistbot at llvm.org
Wed Oct 5 16:15:02 PDT 2022


Author: wren romano
Date: 2022-10-05T16:14:52-07:00
New Revision: 1b27484a49ac12a5dad632e633c1c77b4281545d

URL: https://github.com/llvm/llvm-project/commit/1b27484a49ac12a5dad632e633c1c77b4281545d
DIFF: https://github.com/llvm/llvm-project/commit/1b27484a49ac12a5dad632e633c1c77b4281545d.diff

LOG: [mlir][sparse] further implement singleton dimension level type

Handle more cases of singleton DLT including direct sparse2sparse conversion.  (Followup to D134096)

Depends On D134926

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134933

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
    mlir/test/Dialect/SparseTensor/conversion_sparse2sparse.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
    mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 975f4435c73ef..f8633dedf54bd 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -53,6 +53,9 @@ class SparseTensorEnumeratorBase;
   assert(d < getRank() && "Dimension index is out of bounds");
 #define ASSERT_COMPRESSED_DIM(d)                                               \
   assert(isCompressedDim(d) && "Dimension is not compressed");
+#define ASSERT_COMPRESSED_OR_SINGLETON_DIM(d)                                  \
+  assert((isCompressedDim(d) || isSingletonDim(d)) &&                          \
+         "Dimension is neither compressed nor singleton");
 #define ASSERT_DENSE_DIM(d) assert(isDenseDim(d) && "Dimension is not dense");
 
 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
@@ -145,6 +148,7 @@ class SparseTensorStorageBase {
   virtual void getIndices(std::vector<I> **, uint64_t);
   MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETINDICES)
 #undef DECL_GETINDICES
+  virtual uint64_t getIndex(uint64_t d, uint64_t pos) const = 0;
 
   /// Gets primary storage.
 #define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
@@ -254,6 +258,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   }
   void getValues(std::vector<V> **out) final { *out = &values; }
 
+  uint64_t getIndex(uint64_t d, uint64_t pos) const final {
+    ASSERT_COMPRESSED_OR_SINGLETON_DIM(d);
+    assert(pos < indices[d].size() && "Index position is out of bounds");
+    return indices[d][pos]; // Converts the stored `I` into `uint64_t`.
+  }
+
   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *cursor, V val) final {
     // First, wrap up pending insertion path.
@@ -376,7 +386,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// does not check that `i` is semantically valid (i.e., in bounds
   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
-    ASSERT_COMPRESSED_DIM(d);
+    ASSERT_COMPRESSED_OR_SINGLETON_DIM(d);
     // Subscript assignment to `std::vector` requires that the `pos`-th
     // entry has been initialized; thus we must be sure to check `size()`
     // here, instead of `capacity()` as would be ideal.
@@ -397,8 +407,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
     if (isCompressedDim(d))
       return pointers[d][parentSz];
-    // else if dense:
-    return parentSz * getDimSizes()[d];
+    if (isSingletonDim(d))
+      return parentSz; // New size is same as the parent.
+    if (isDenseDim(d))
+      return parentSz * getDimSizes()[d];
+    MLIR_SPARSETENSOR_FATAL("unsupported dimension level type");
   }
 
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
@@ -446,7 +459,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     if (isCompressedDim(d)) {
       appendPointer(d, indices[d].size(), count);
     } else if (isSingletonDim(d)) {
-      return;
+      return; // Nothing to finalize.
     } else { // Dense dimension.
       ASSERT_DENSE_DIM(d);
       const uint64_t sz = getDimSizes()[d];
@@ -475,8 +488,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Continues a single insertion path, outer to inner.
   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
-    ASSERT_VALID_DIM(diff);
     const uint64_t rank = getRank();
+    assert(diff <= rank && "Dimension-diff is out of bounds");
     for (uint64_t d = diff; d < rank; ++d) {
       const uint64_t i = cursor[d];
       appendIndex(d, top, i);
@@ -509,6 +522,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
 };
 
+#undef ASSERT_COMPRESSED_OR_SINGLETON_DIM
 #undef ASSERT_COMPRESSED_DIM
 #undef ASSERT_VALID_DIM
 
@@ -637,7 +651,8 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
         forallElements(yield, pos, d + 1);
       }
     } else if (src.isSingletonDim(d)) {
-      MLIR_SPARSETENSOR_FATAL("unsupported dimension level type");
+      this->cursor[this->reord[d]] = src.getIndex(d, parentPos);
+      forallElements(yield, parentPos, d + 1);
     } else { // Dense dimension.
       assert(src.isDenseDim(d)); // TODO: reuse the ASSERT_DENSE_DIM message
       const uint64_t sz = src.getDimSizes()[d];
@@ -740,7 +755,6 @@ SparseTensorStorage<P, I, V> *SparseTensorStorage<P, I, V>::newSparseTensor(
 #endif
     return new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
   }
-  // else
   std::vector<uint64_t> permsz(rank);
   for (uint64_t r = 0; r < rank; ++r) {
     assert(shape[r] > 0 && "Dimension size zero has trivial storage");
@@ -848,8 +862,10 @@ SparseTensorStorage<P, I, V>::SparseTensorStorage(
       // That is, in the yieldPos loop we need random-access assignment
       // to `indices[r]`; however, `std::vector`'s subscript-assignment
       // only allows assigning to already-initialized positions.
-      if (isCompressedDim(r))
+      if (isCompressedDim(r) || isSingletonDim(r))
         indices[r].resize(parentSz, 0);
+      else
+        ASSERT_DENSE_DIM(r); // Future-proofing.
     }
     values.resize(parentSz, 0); // Both allocate and zero-initialize.
   }
@@ -872,6 +888,7 @@ SparseTensorStorage<P, I, V>::SparseTensorStorage(
         writeIndex(r, currentPos, ind[r]);
         parentPos = currentPos;
       } else if (isSingletonDim(r)) {
+        writeIndex(r, parentPos, ind[r]);
         // the new parentPos equals the old parentPos.
       } else { // Dense dimension.
         ASSERT_DENSE_DIM(r);
@@ -898,14 +915,19 @@ SparseTensorStorage<P, I, V>::SparseTensorStorage(
         pointers[r][parentPos] = pointers[r][parentPos - 1];
       }
       pointers[r][0] = 0;
+    } else {
+      // Both dense and singleton are no-ops for the finalizeYieldPos loop.
+      // This assertion is for future-proofing.
+      assert((isDenseDim(r) || isSingletonDim(r)) &&
+             "Dimension is neither dense nor singleton");
     }
     parentSz = assembledSize(parentSz, r);
   }
 }
 
+#undef ASSERT_DENSE_DIM
+
 } // namespace sparse_tensor
 } // namespace mlir
 
-#undef ASSERT_DENSE_DIM
-
 #endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 9b725ec2ee8c3..f6917235a58f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -449,17 +449,18 @@ static bool canUseDirectConversion(
     ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
   bool alreadyCompressed = false;
   for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
-    switch (dimTypes[r]) {
-    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    const DimLevelType dlt = dimLevelTypeEncoding(dimTypes[r]);
+    if (isCompressedDLT(dlt)) {
       if (alreadyCompressed)
         return false; // Multiple compressed dimensions not yet supported.
       alreadyCompressed = true;
-      break;
-    case SparseTensorEncodingAttr::DimLevelType::Dense:
+    } else if (isDenseDLT(dlt)) {
       if (alreadyCompressed)
         return false; // Dense after Compressed not yet supported.
-      break;
-    default: // TODO: investigate
+    } else if (isSingletonDLT(dlt)) {
+      // Direct conversion doesn't have any particular problems with
+      // singleton after compressed.
+    } else { // TODO: investigate
       return false;
     }
   }

diff  --git a/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp b/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
index da2784b7077fb..eb110def4402c 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
@@ -31,28 +31,29 @@ SparseTensorNNZ::SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
                                  const std::vector<DimLevelType> &sparsity)
     : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
   assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
-  bool uncompressed = true;
-  (void)uncompressed;
+  bool alreadyCompressed = false;
+  (void)alreadyCompressed;
   uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
   for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
-    switch (dimTypes[r]) {
-    case DimLevelType::kCompressed:
-      assert(uncompressed &&
-             "Multiple compressed layers not currently supported");
-      uncompressed = false;
+    const DimLevelType dlt = sparsity[r];
+    if (isCompressedDLT(dlt)) {
+      if (alreadyCompressed)
+        MLIR_SPARSETENSOR_FATAL(
+            "Multiple compressed layers not currently supported");
+      alreadyCompressed = true;
       nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
-      break;
-    case DimLevelType::kDense:
-      assert(uncompressed && "Dense after compressed not currently supported");
-      break;
-    case DimLevelType::kSingleton:
+    } else if (isDenseDLT(dlt)) {
+      if (alreadyCompressed)
+        MLIR_SPARSETENSOR_FATAL(
+            "Dense after compressed not currently supported");
+    } else if (isSingletonDLT(dlt)) {
       // Singleton after Compressed causes no problems for allocating
       // `nnz` nor for the yieldPos loop.  This remains true even
       // when adding support for multiple compressed dimensions or
       // for dense-after-compressed.
-      break;
-    default:
-      MLIR_SPARSETENSOR_FATAL("unsupported dimension level type");
+    } else {
+      MLIR_SPARSETENSOR_FATAL("unsupported dimension level type: %d\n",
+                              static_cast<uint8_t>(dlt));
     }
     sz = detail::checkedMul(sz, dimSizes[r]);
   }
@@ -65,7 +66,7 @@ SparseTensorNNZ::SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
 void SparseTensorNNZ::forallIndices(uint64_t stopDim,
                                     SparseTensorNNZ::NNZConsumer yield) const {
   assert(stopDim < getRank() && "Dimension out of bounds");
-  assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
+  assert(isCompressedDLT(dimTypes[stopDim]) &&
          "Cannot look up non-compressed dimensions");
   forallIndices(yield, stopDim, 0, 0);
 }
@@ -78,7 +79,7 @@ void SparseTensorNNZ::forallIndices(uint64_t stopDim,
 void SparseTensorNNZ::add(const std::vector<uint64_t> &ind) {
   uint64_t parentPos = 0;
   for (uint64_t rank = getRank(), r = 0; r < rank; ++r) {
-    if (dimTypes[r] == DimLevelType::kCompressed)
+    if (isCompressedDLT(dimTypes[r]))
       nnz[r][parentPos]++;
     parentPos = parentPos * dimSizes[r] + ind[r];
   }

diff  --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2sparse.mlir
index 488f19cf495f5..fd38d69eedb3b 100644
--- a/mlir/test/Dialect/SparseTensor/conversion_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2sparse.mlir
@@ -47,3 +47,45 @@ func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
   return %0 : tensor<?xf32, #SparseVector32>
 }
+
+#SparseSingleton64 = #sparse_tensor.encoding<{
+  dimLevelType = ["singleton"],
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+#SparseSingleton32 = #sparse_tensor.encoding<{
+  dimLevelType = ["singleton"],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+// CHECK-COO-LABEL: func @sparse_convert_singleton(
+//  CHECK-COO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//  CHECK-COO-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
+//  CHECK-COO-DAG:  %[[FromCOO:.*]] = arith.constant 2 : i32
+//   CHECK-COO-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-COO-DAG: %[[Q:.*]] = memref.alloca() : memref<1xindex>
+//   CHECK-COO-DAG: %[[R:.*]] = memref.alloca() : memref<1xindex>
+//   CHECK-COO-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
+//   CHECK-COO-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xindex> to memref<?xindex>
+//   CHECK-COO-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
+//       CHECK-COO: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
+//       CHECK-COO: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
+//       CHECK-COO: call @delSparseTensorCOOF32(%[[C]])
+//       CHECK-COO: return %[[T]] : !llvm.ptr<i8>
+// CHECK-AUTO-LABEL: func @sparse_convert_singleton(
+//  CHECK-AUTO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//   CHECK-AUTO-DAG: %[[SparseToSparse:.*]] = arith.constant 3 : i32
+//   CHECK-AUTO-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-AUTO-DAG: %[[Q:.*]] = memref.alloca() : memref<1xindex>
+//   CHECK-AUTO-DAG: %[[R:.*]] = memref.alloca() : memref<1xindex>
+//   CHECK-AUTO-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
+//   CHECK-AUTO-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xindex> to memref<?xindex>
+//   CHECK-AUTO-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
+//       CHECK-AUTO: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[SparseToSparse]], %[[A]])
+//       CHECK-AUTO: return %[[T]] : !llvm.ptr<i8>
+func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) -> tensor<?xf32, #SparseSingleton32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseSingleton64> to tensor<?xf32, #SparseSingleton32>
+  return %0 : tensor<?xf32, #SparseSingleton32>
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
index bab083e75ed43..11bd380ada0db 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
@@ -19,6 +19,20 @@
   dimOrdering = affine_map<(i,j,k) -> (i,k,j)>
 }>
 
+#SingletonTensor1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed", "singleton" ]
+}>
+
+// This also checks the compressed->dense conversion (when there are zeros).
+#SingletonTensor2 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "singleton" ]
+}>
+
+// This also checks the singleton->compressed conversion.
+#SingletonTensor3 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "compressed" ]
+}>
+
 module {
   //
   // Utilities for output and releasing memory.
@@ -36,9 +50,9 @@ module {
   }
 
   //
-  // Main driver.
+  // The first test suite (for non-singleton DimLevelTypes).
   //
-  func.func @entry() {
+  func.func @testNonSingleton() {
     //
     // Initialize a 3-dim dense tensor.
     //
@@ -97,4 +111,81 @@ module {
 
     return
   }
+
+  //
+  // The second test suite (for singleton DimLevelTypes).
+  //
+  func.func @testSingleton() {
+    //
+    // Initialize a 3-dim dense tensor with the 3rd dim being singleton.
+    //
+    %src = arith.constant dense<[
+       [  [  1.0,  0.0,  0.0,  0.0 ],
+          [  0.0,  6.0,  0.0,  0.0 ],
+          [  0.0,  0.0, 11.0,  0.0 ] ],
+       [  [  0.0, 14.0,  0.0,  0.0 ],
+          [  0.0,  0.0,  0.0, 20.0 ],
+          [ 21.0,  0.0,  0.0,  0.0 ] ]
+    ]> : tensor<2x3x4xf64>
+
+    //
+    // Convert dense tensor directly to various sparse tensors.
+    //
+    %s1 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #SingletonTensor1>
+    %s2 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #SingletonTensor2>
+    %s3 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #SingletonTensor3>
+
+    //
+    // Convert sparse tensor directly to another sparse format.
+    //
+    %t12 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64, #SingletonTensor2>
+    %t13 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64, #SingletonTensor3>
+    %t21 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64, #SingletonTensor1>
+    %t23 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64, #SingletonTensor3>
+    %t31 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64, #SingletonTensor1>
+    %t32 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64, #SingletonTensor2>
+
+    //
+    // Convert sparse tensor back to dense.
+    //
+    %d12 = sparse_tensor.convert %t12 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64>
+    %d13 = sparse_tensor.convert %t13 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64>
+    %d21 = sparse_tensor.convert %t21 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64>
+    %d23 = sparse_tensor.convert %t23 : tensor<2x3x4xf64, #SingletonTensor3> to tensor<2x3x4xf64>
+    %d31 = sparse_tensor.convert %t31 : tensor<2x3x4xf64, #SingletonTensor1> to tensor<2x3x4xf64>
+    %d32 = sparse_tensor.convert %t32 : tensor<2x3x4xf64, #SingletonTensor2> to tensor<2x3x4xf64>
+
+    //
+    // Check round-trip equality.  And release dense tensors.
+    //
+    // CHECK-COUNT-7: ( ( ( 1, 0, 0, 0 ), ( 0, 6, 0, 0 ), ( 0, 0, 11, 0 ) ), ( ( 0, 14, 0, 0 ), ( 0, 0, 0, 20 ), ( 21, 0, 0, 0 ) ) )
+    call @dump(%src) : (tensor<2x3x4xf64>) -> ()
+    call @dumpAndRelease_234(%d12) : (tensor<2x3x4xf64>) -> ()
+    call @dumpAndRelease_234(%d13) : (tensor<2x3x4xf64>) -> ()
+    call @dumpAndRelease_234(%d21) : (tensor<2x3x4xf64>) -> ()
+    call @dumpAndRelease_234(%d23) : (tensor<2x3x4xf64>) -> ()
+    call @dumpAndRelease_234(%d31) : (tensor<2x3x4xf64>) -> ()
+    call @dumpAndRelease_234(%d32) : (tensor<2x3x4xf64>) -> ()
+
+    //
+    // Release sparse tensors.
+    //
+    bufferization.dealloc_tensor %t12 : tensor<2x3x4xf64, #SingletonTensor2>
+    bufferization.dealloc_tensor %t13 : tensor<2x3x4xf64, #SingletonTensor3>
+    bufferization.dealloc_tensor %t21 : tensor<2x3x4xf64, #SingletonTensor1>
+    bufferization.dealloc_tensor %t23 : tensor<2x3x4xf64, #SingletonTensor3>
+    bufferization.dealloc_tensor %t31 : tensor<2x3x4xf64, #SingletonTensor1>
+    bufferization.dealloc_tensor %t32 : tensor<2x3x4xf64, #SingletonTensor2>
+
+    return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    call @testNonSingleton() : () -> ()
+    call @testSingleton() : () -> ()
+    return
+  }
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index ae4de073bc81e..d05cb400e235e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -189,6 +189,7 @@ def main():
     # TODO: While direct s2s is far too slow for per-commit testing,
     # we should have some framework ensure that we run this test with
     # `s2s=0` on a regular basis, to ensure that it does continue to work.
+    # TODO: be sure to test s2s=0 together with singletons.
     s2s = 1
     sparsification_options = (
         f'parallelization-strategy=none '
@@ -200,10 +201,12 @@ def main():
         options=sparsification_options, opt_level=0, shared_libs=[support_lib])
     f64 = ir.F64Type.get()
     # Be careful about increasing this because
-    #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
+    #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
     shape = range(2, 6)
     rank = len(shape)
     # All combinations.
+    # TODO: add singleton here too; which requires updating how `np_arg0`
+    # is initialized below.
     levels = list(itertools.product(*itertools.repeat(
       [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
     # All permutations.


        


More information about the Mlir-commits mailing list