[Mlir-commits] [mlir] [mlir][sparse] simplify reader construction of new sparse tensor (PR #69036)

Aart Bik llvmlistbot at llvm.org
Fri Oct 13 17:46:36 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/69036

>From da85ab48e4c835a51d508ca6722e283e5a2c921b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 13 Oct 2023 17:15:31 -0700
Subject: [PATCH 1/2] [mlir][sparse] simplify reader construction of new sparse
 tensor

Making the materialize-from-reader method part of the swiss-army-knife
`newSparseTensor` suite again removes a lot of redundant boilerplate code
and unifies the parameter setup into a single centralized utility.
Furthermore, we have now minimized the number of entry points into the
library that need a non-permutation map setup, simplifying what comes next.
---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   1 +
 .../ExecutionEngine/SparseTensorRuntime.h     |  25 ----
 .../Transforms/SparseTensorConversion.cpp     |  33 ++---
 .../ExecutionEngine/SparseTensorRuntime.cpp   | 137 +-----------------
 .../test/Dialect/SparseTensor/conversion.mlir |  30 ++--
 5 files changed, 31 insertions(+), 195 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1434c649acd29b4..0caf83a63b531f2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -146,6 +146,7 @@ enum class Action : uint32_t {
   kEmptyForward = 1,
   kFromCOO = 2,
   kSparseToSparse = 3,
+  kFromReader = 4,
   kToCOO = 5,
   kPack = 7,
   kSortCOOInPlace = 8,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index e8dd50d6730c784..a470afc2f0c8cd1 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -115,16 +115,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
     char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
     PrimaryType valTp);
 
-/// Constructs a new sparse-tensor storage object with the given encoding,
-/// initializes it by reading all the elements from the file, and then
-/// closes the file.
-MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
-    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
-    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
-    StridedMemRefType<index_type, 1> *dim2lvlRef,
-    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
-    OverheadType crdTp, PrimaryType valTp);
-
 /// SparseTensorReader method to obtain direct access to the
 /// dimension-sizes array.
 MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
@@ -197,24 +187,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
 MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
 
-/// Helper function to read the header of a file and return the
-/// shape/sizes, without parsing the elements of the file.
-MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
-                                                    std::vector<uint64_t> *out);
-
-/// Returns the rank of the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);
-
-/// Returns the is_symmetric bit for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT bool getSparseTensorReaderIsSymmetric(void *p);
-
 /// Returns the number of stored elements for the sparse tensor being read.
 MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNSE(void *p);
 
-/// Returns the size of a dimension for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
-                                                                 index_type d);
-
 /// Releases the SparseTensorReader and closes the associated file.
 MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a76f81410aa87a0..638475a80343d91 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -199,12 +199,15 @@ class NewCallParams final {
   /// type-level information such as the encoding and sizes), generating
   /// MLIR buffers as needed, and returning `this` for method chaining.
   NewCallParams &genBuffers(SparseTensorType stt,
-                            ArrayRef<Value> dimSizesValues) {
+                            ArrayRef<Value> dimSizesValues,
+                            Value dimSizesBuffer = Value()) {
     assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
     // Sparsity annotations.
     params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
     // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
-    params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
+    params[kParamDimSizes] = dimSizesBuffer
+                                 ? dimSizesBuffer
+                                 : allocaBuffer(builder, loc, dimSizesValues);
     params[kParamLvlSizes] =
         genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                       params[kParamDim2Lvl], params[kParamLvl2Dim]);
@@ -342,33 +345,15 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     const auto stt = getSparseTensorType(op);
     if (!stt.hasEncoding())
       return failure();
-    // Construct the reader opening method calls.
+    // Construct the `reader` opening method calls.
     SmallVector<Value> dimShapesValues;
     Value dimSizesBuffer;
     Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                              dimShapesValues, dimSizesBuffer);
-    // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
-    Value dim2lvlBuffer;
-    Value lvl2dimBuffer;
-    Value lvlSizesBuffer =
-        genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
-                      dim2lvlBuffer, lvl2dimBuffer);
     // Use the `reader` to parse the file.
-    Type opaqueTp = getOpaquePointerType(rewriter);
-    Type eltTp = stt.getElementType();
-    Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
-    SmallVector<Value, 8> params{
-        reader,
-        lvlSizesBuffer,
-        genLvlTypesBuffer(rewriter, loc, stt),
-        dim2lvlBuffer,
-        lvl2dimBuffer,
-        constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
-        constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
-        valTp};
-    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
-                                  opaqueTp, params, EmitCInterface::On)
-                       .getResult(0);
+    Value tensor = NewCallParams(rewriter, loc)
+                 .genBuffers(stt, dimShapesValues, dimSizesBuffer)
+                 .genNewCall(Action::kFromReader, reader);
     // Free the memory for `reader`.
     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                    EmitCInterface::Off);
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index ae33a869497a01c..fbd98f6cf183793 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -138,6 +138,12 @@ extern "C" {
           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
           dimRank, tensor);                                                    \
     }                                                                          \
+    case Action::kFromReader: {                                                \
+      assert(ptr && "Received nullptr for SparseTensorReader object");         \
+      SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
+      return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
+        lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                       \
+    }                                                                          \
     case Action::kToCOO: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
 MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
 
-void *_mlir_ciface_newSparseTensorFromReader(
-    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
-    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
-    StridedMemRefType<index_type, 1> *dim2lvlRef,
-    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
-    OverheadType crdTp, PrimaryType valTp) {
-  assert(p);
-  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
-  ASSERT_NO_STRIDE(lvlSizesRef);
-  ASSERT_NO_STRIDE(lvlTypesRef);
-  ASSERT_NO_STRIDE(dim2lvlRef);
-  ASSERT_NO_STRIDE(lvl2dimRef);
-  const uint64_t dimRank = reader.getRank();
-  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
-  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
-  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
-  ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
-  (void)dimRank;
-  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
-  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
-  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
-  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
-#define CASE(p, c, v, P, C, V)                                                 \
-  if (posTp == OverheadType::p && crdTp == OverheadType::c &&                  \
-      valTp == PrimaryType::v)                                                 \
-    return static_cast<void *>(reader.readSparseTensor<P, C, V>(               \
-        lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
-#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
-  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
-  // This is safe because of the static_assert above.
-  if (posTp == OverheadType::kIndex)
-    posTp = OverheadType::kU64;
-  if (crdTp == OverheadType::kIndex)
-    crdTp = OverheadType::kU64;
-  // Double matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
-  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
-  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
-  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
-  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
-  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
-  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
-  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
-  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
-  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
-  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
-  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
-  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
-  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
-  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
-  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
-  // Float matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
-  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
-  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
-  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
-  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
-  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
-  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
-  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
-  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
-  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
-  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
-  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
-  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
-  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
-  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
-  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
-  // Two-byte floats with both overheads of the same type.
-  CASE_SECSAME(kU64, kF16, uint64_t, f16);
-  CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
-  CASE_SECSAME(kU32, kF16, uint32_t, f16);
-  CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
-  CASE_SECSAME(kU16, kF16, uint16_t, f16);
-  CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
-  CASE_SECSAME(kU8, kF16, uint8_t, f16);
-  CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
-  // Integral matrices with both overheads of the same type.
-  CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
-  CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
-  CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
-  CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
-  CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
-  CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
-  CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
-  CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
-  CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
-  CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
-  CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
-  CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
-  CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
-  CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
-  CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
-  CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
-  // Complex matrices with wide overhead.
-  CASE_SECSAME(kU64, kC64, uint64_t, complex64);
-  CASE_SECSAME(kU64, kC32, uint64_t, complex32);
-
-  // Unsupported case (add above if needed).
-  MLIR_SPARSETENSOR_FATAL(
-      "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
-      static_cast<int>(posTp), static_cast<int>(crdTp),
-      static_cast<int>(valTp));
-#undef CASE_SECSAME
-#undef CASE
-}
-
 void _mlir_ciface_outSparseTensorWriterMetaData(
     void *p, index_type dimRank, index_type nse,
     StridedMemRefType<index_type, 1> *dimSizesRef) {
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
   return env;
 }
 
-void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
-  assert(out && "Received nullptr for out-parameter");
-  SparseTensorReader reader(filename);
-  reader.openFile();
-  reader.readHeader();
-  reader.closeFile();
-  const uint64_t dimRank = reader.getRank();
-  const uint64_t *dimSizes = reader.getDimSizes();
-  out->reserve(dimRank);
-  out->assign(dimSizes, dimSizes + dimRank);
-}
-
-index_type getSparseTensorReaderRank(void *p) {
-  return static_cast<SparseTensorReader *>(p)->getRank();
-}
-
-bool getSparseTensorReaderIsSymmetric(void *p) {
-  return static_cast<SparseTensorReader *>(p)->isSymmetric();
-}
-
 index_type getSparseTensorReaderNSE(void *p) {
   return static_cast<SparseTensorReader *>(p)->getNSE();
 }
 
-index_type getSparseTensorReaderDimSize(void *p, index_type d) {
-  return static_cast<SparseTensorReader *>(p)->getDimSize(d);
-}
-
 void delSparseTensorReader(void *p) {
   delete static_cast<SparseTensorReader *>(p);
 }
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 96300a98a6a4bc5..2ff4887dae7b8c9 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,11 +78,11 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 //   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-//   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
-//   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
 //   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
-//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+//   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
+//   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
 //       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
@@ -96,11 +96,11 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
 //   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
-//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+//   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
+//   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
 //       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//       CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-//       CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-//       CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-//       CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
-//       CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-//       CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-//       CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//       CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
-//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
+//   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
+//   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
+//   CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+//   CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
+//   CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
 //       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {

>From 54d0fd06222284f97887fddea987af82727a1919 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 13 Oct 2023 17:45:56 -0700
Subject: [PATCH 2/2] clang-format

---
 .../SparseTensor/Transforms/SparseTensorConversion.cpp        | 4 ++--
 mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp              | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 638475a80343d91..73f5e3eeb7d512e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -352,8 +352,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
                              dimShapesValues, dimSizesBuffer);
     // Use the `reader` to parse the file.
     Value tensor = NewCallParams(rewriter, loc)
-                 .genBuffers(stt, dimShapesValues, dimSizesBuffer)
-                 .genNewCall(Action::kFromReader, reader);
+                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
+                       .genNewCall(Action::kFromReader, reader);
     // Free the memory for `reader`.
     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                    EmitCInterface::Off);
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index fbd98f6cf183793..74ab65c143d63e8 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -142,7 +142,7 @@ extern "C" {
       assert(ptr && "Received nullptr for SparseTensorReader object");         \
       SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
       return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
-        lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                       \
+          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                     \
     }                                                                          \
     case Action::kToCOO: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \



More information about the Mlir-commits mailing list