[Mlir-commits] [mlir] 63bdcaf - [mlir][sparse] Moving `delete coo` into codegen instead of runtime library

wren romano llvmlistbot at llvm.org
Fri Apr 1 11:08:59 PDT 2022


Author: wren romano
Date: 2022-04-01T11:08:52-07:00
New Revision: 63bdcaf92a5ed5dc8c9bdc0ea2d1a13b8ddb3c68

URL: https://github.com/llvm/llvm-project/commit/63bdcaf92a5ed5dc8c9bdc0ea2d1a13b8ddb3c68
DIFF: https://github.com/llvm/llvm-project/commit/63bdcaf92a5ed5dc8c9bdc0ea2d1a13b8ddb3c68.diff

LOG: [mlir][sparse] Moving `delete coo` into codegen instead of runtime library

Prior to this change, there were a number of places where the allocation and deallocation of SparseTensorCOO objects were not cleanly paired.  This led to inconsistencies about whether each function released its tensor/COO arguments, and made it easy to introduce memory leaks, use-after-free errors, or double frees.  This change cleans up the boundary between codegen and the runtime library to resolve those issues.  Now, the runtime library frees an object only when (a) the function is explicitly designed to do so, or (b) the allocated object is entirely local to the function and would leak if not released.  The codegen therefore takes complete responsibility for releasing any objects it caused to be allocated.
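
For illustration, the resulting discipline in the codegen looks roughly as follows (a condensed sketch of the conversion pattern from the SparseTensorConversion.cpp diff below; parameter setup and the element-insertion loop are elided):

    // Codegen allocates the COO, uses it to construct the sparse tensor
    // storage, and then emits the release call itself; the runtime no
    // longer frees the COO behind the caller's back.
    Value coo = genNewCall(rewriter, op, params);   // Action::kEmptyCOO
    // ... loop emitting addElt calls that populate `coo` ...
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);   // build storage from coo
    genDelCOOCall(rewriter, op, eltType, coo);      // paired release
    rewriter.replaceOp(op, dst);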

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122435

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 11329f6abc7be..d78061e0d7587 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -253,6 +253,14 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
   return val;
 }
 
+/// Generates a call to release/delete a `SparseTensorCOO`.
+static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
+                          Value coo) {
+  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
+  TypeRange noTp;
+  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
 /// In particular, this generates code like the following:
 ///   val = a[i1,..,ik];
@@ -501,7 +509,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
       params[6] = constantAction(rewriter, loc, Action::kFromCOO);
       params[7] = coo;
-      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
+      Value dst = genNewCall(rewriter, op, params);
+      genDelCOOCall(rewriter, op, stp.getElementType(), coo);
+      rewriter.replaceOp(op, dst);
       return success();
     }
     if (!encDst && encSrc) {
@@ -545,6 +555,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
       rewriter.create<scf::YieldOp>(loc);
       rewriter.setInsertionPointAfter(whileOp);
+      genDelCOOCall(rewriter, op, elemTp, iter);
       rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
       return success();
     }
@@ -584,7 +595,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     SmallVector<Value, 8> params;
     sizesFromSrc(rewriter, sizes, loc, src);
     newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
-    Value ptr = genNewCall(rewriter, op, params);
+    Value coo = genNewCall(rewriter, op, params);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value perm = params[2];
     SmallVector<Value> lo;
@@ -620,13 +631,15 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
                                             ivs, rank);
           else
             val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
-          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
+          genAddEltCall(rewriter, op, eltType, coo, val, ind, perm);
           return {};
         });
     // Final call to construct sparse tensor storage.
     params[6] = constantAction(rewriter, loc, Action::kFromCOO);
-    params[7] = ptr;
-    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
+    params[7] = coo;
+    Value dst = genNewCall(rewriter, op, params);
+    genDelCOOCall(rewriter, op, eltType, coo);
+    rewriter.replaceOp(op, dst);
     return success();
   }
 };
@@ -822,8 +835,9 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
     Type eltType = srcType.getElementType();
     SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
     TypeRange noTp;
-    replaceOpWithFuncCall(rewriter, op, name, noTp, params,
-                          EmitCInterface::Off);
+    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
+    genDelCOOCall(rewriter, op, eltType, coo);
+    rewriter.eraseOp(op);
     return success();
   }
 };

diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index b86f10316cb8f..d1a08f10ca905 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -400,7 +400,6 @@ class SparseTensorStorage : public SparseTensorStorageBase {
         assert(shape[r] == 0 || shape[r] == tensor->getSizes()[perm[r]]);
       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
                                            tensor);
-      delete tensor;
     } else {
       std::vector<uint64_t> permsz(rank);
       for (uint64_t r = 0; r < rank; r++) {
@@ -748,7 +747,6 @@ static void outSparseTensor(void *tensor, void *dest, bool sort) {
   file.flush();
   file.close();
   assert(file.good());
-  delete coo;
 }
 
 /// Initializes sparse tensor from an external COO-flavored format.
@@ -780,17 +778,19 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
 #endif
 
   // Convert external format to internal COO.
-  auto *tensor = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
+  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
   std::vector<uint64_t> idx(rank);
   for (uint64_t i = 0, base = 0; i < nse; i++) {
     for (uint64_t r = 0; r < rank; r++)
       idx[perm[r]] = indices[base + r];
-    tensor->add(idx, values[i]);
+    coo->add(idx, values[i]);
     base += rank;
   }
   // Return sparse tensor storage format as opaque pointer.
-  return SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
-      rank, shape, perm, sparsity, tensor);
+  auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
+      rank, shape, perm, sparsity, coo);
+  delete coo;
+  return tensor;
 }
 
 /// Converts a sparse tensor to an external COO-flavored format.
@@ -847,28 +847,31 @@ extern "C" {
 
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
-    SparseTensorCOO<V> *tensor = nullptr;                                      \
+    SparseTensorCOO<V> *coo = nullptr;                                         \
     if (action <= Action::kFromCOO) {                                          \
       if (action == Action::kFromFile) {                                       \
         char *filename = static_cast<char *>(ptr);                             \
-        tensor = openSparseTensorCOO<V>(filename, rank, shape, perm);          \
+        coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
       } else if (action == Action::kFromCOO) {                                 \
-        tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
+        coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
       } else {                                                                 \
         assert(action == Action::kEmpty);                                      \
       }                                                                        \
-      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
-                                                           sparsity, tensor);  \
+      auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
+          rank, shape, perm, sparsity, coo);                                   \
+      if (action == Action::kFromFile)                                         \
+        delete coo;                                                            \
+      return tensor;                                                           \
     }                                                                          \
     if (action == Action::kEmptyCOO)                                           \
       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
-    tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);    \
+    coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
     if (action == Action::kToIterator) {                                       \
-      tensor->startIterator();                                                 \
+      coo->startIterator();                                                    \
     } else {                                                                   \
       assert(action == Action::kToCOO);                                        \
     }                                                                          \
-    return tensor;                                                             \
+    return coo;                                                                \
   }
 
 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
@@ -924,10 +927,8 @@ extern "C" {
     const uint64_t isize = iref->sizes[0];                                     \
     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
     const Element<V> *elem = iter->getNext();                                  \
-    if (elem == nullptr) {                                                     \
-      delete iter;                                                             \
+    if (elem == nullptr)                                                       \
       return false;                                                            \
-    }                                                                          \
     for (uint64_t r = 0; r < isize; r++)                                       \
       indx[r] = elem->indices[r];                                              \
     *value = elem->value;                                                      \
@@ -1208,6 +1209,19 @@ void delSparseTensor(void *tensor) {
   delete static_cast<SparseTensorStorageBase *>(tensor);
 }
 
+/// Releases sparse tensor coordinate scheme.
+#define IMPL_DELCOO(VNAME, V)                                                  \
+  void delSparseTensorCOO##VNAME(void *coo) {                                  \
+    delete static_cast<SparseTensorCOO<V> *>(coo);                             \
+  }
+IMPL_DELCOO(F64, double)
+IMPL_DELCOO(F32, float)
+IMPL_DELCOO(I64, int64_t)
+IMPL_DELCOO(I32, int32_t)
+IMPL_DELCOO(I16, int16_t)
+IMPL_DELCOO(I8, int8_t)
+#undef IMPL_DELCOO
+
 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
 /// data structures. The expected parameters are:
 ///

diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index db96c8ce07193..a3078e4913fca 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -65,13 +65,14 @@ func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> index {
 
 // CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
 //   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -80,13 +81,14 @@ func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
 
 // CHECK-LABEL: func @sparse_new2d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
 //   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xindex> to memref<?xindex>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
@@ -95,13 +97,14 @@ func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
 
 // CHECK-LABEL: func @sparse_new3d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<3xi8>
 //   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<3xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<3xindex> to memref<?xindex>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>
@@ -111,6 +114,7 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
 // CHECK-LABEL: func @sparse_init(
 //  CHECK-SAME: %[[I:.*]]: index,
 //  CHECK-SAME: %[[J:.*]]: index) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
@@ -122,7 +126,7 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
 //   CHECK-DAG: memref.store %[[I]], %[[Q]][%[[C0]]] : memref<2xindex>
 //   CHECK-DAG: memref.store %[[J]], %[[Q]][%[[C1]]] : memref<2xindex>
 //       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
   %0 = sparse_tensor.init [%arg0, %arg1] : tensor<?x?xf64, #SparseMatrix>
@@ -164,6 +168,8 @@ func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #Sp
 
 // CHECK-LABEL: func @sparse_convert_1d(
 //  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
+//   CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
@@ -174,7 +180,7 @@ func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #Sp
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
@@ -182,7 +188,8 @@ func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #Sp
 //       CHECK:   memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
 //       CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
+//       CHECK: call @delSparseTensorCOOI32(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
@@ -191,14 +198,17 @@ func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
 
 // CHECK-LABEL: func @sparse_convert_1d_ss(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
+//  CHECK-DAG:  %[[FromCOO:.*]] = arith.constant 2 : i32
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
 //   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
+//       CHECK: call @delSparseTensorCOOF32(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
@@ -207,6 +217,8 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 
 // CHECK-LABEL: func @sparse_convert_2d(
 //  CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
+//   CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
@@ -216,7 +228,7 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
@@ -227,7 +239,8 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 //       CHECK:     call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
 //       CHECK:   }
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
+//       CHECK: call @delSparseTensorCOOF64(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
@@ -235,6 +248,8 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
 }
 
 // CHECK-LABEL: func @sparse_constant() -> !llvm.ptr<i8> {
+//   CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
+//   CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -245,7 +260,7 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -254,7 +269,8 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
 //       CHECK:   %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<2xf32>
 //       CHECK:   call @addEltF32(%{{.*}}, %[[V]], %[[N]], %{{.*}})
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
+//       CHECK: call @delSparseTensorCOOF32(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
   // Initialize a tensor.
@@ -266,6 +282,8 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 
 // CHECK-LABEL: func @sparse_convert_3d(
 //  CHECK-SAME: %[[A:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
+//   CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -279,7 +297,7 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
 //       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
@@ -293,7 +311,8 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 //       CHECK:     }
 //       CHECK:   }
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
+//       CHECK: call @delSparseTensorCOOF64(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
@@ -472,9 +491,11 @@ func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
 // CHECK-LABEL: func @sparse_out1(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
 //  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[C:.*]] = arith.constant false
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
-//       CHECK: call @outSparseTensorF64(%[[T]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
+//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
+//  CHECK-DAG:  %[[Sort:.*]] = arith.constant false
+//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
+//       CHECK: call @outSparseTensorF64(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
+//       CHECK: call @delSparseTensorCOOF64(%[[COO]])
 //       CHECK: return
 func @sparse_out1(%arg0: tensor<?x?xf64, #SparseMatrix>, %arg1: !llvm.ptr<i8>) {
   sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
@@ -484,9 +505,11 @@ func @sparse_out1(%arg0: tensor<?x?xf64, #SparseMatrix>, %arg1: !llvm.ptr<i8>) {
 // CHECK-LABEL: func @sparse_out2(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
 //  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[C:.*]] = arith.constant true
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
-//       CHECK: call @outSparseTensorF32(%[[T]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
+//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
+//  CHECK-DAG:  %[[Sort:.*]] = arith.constant true
+//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
+//       CHECK: call @outSparseTensorF32(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
+//       CHECK: call @delSparseTensorCOOF32(%[[COO]])
 //       CHECK: return
 func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr<i8>) {
   sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
