[Mlir-commits] [mlir] [mlir][sparse] simplify reader construction of new sparse tensor (PR #69036)
Aart Bik
llvmlistbot at llvm.org
Fri Oct 13 17:46:36 PDT 2023
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/69036
>From da85ab48e4c835a51d508ca6722e283e5a2c921b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 13 Oct 2023 17:15:31 -0700
Subject: [PATCH 1/2] [mlir][sparse] simplify reader construction of new sparse
tensor
Making the materialize-from-reader method part of the Swiss army knife
suite again removes a lot of redundant boilerplate code and unifies
the parameter setup into a single centralized utility. Furthermore,
we have now minimized the number of entry points into the library
that need a non-permutation map setup, simplifying what comes next.
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 1 +
.../ExecutionEngine/SparseTensorRuntime.h | 25 ----
.../Transforms/SparseTensorConversion.cpp | 33 ++---
.../ExecutionEngine/SparseTensorRuntime.cpp | 137 +-----------------
.../test/Dialect/SparseTensor/conversion.mlir | 30 ++--
5 files changed, 31 insertions(+), 195 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1434c649acd29b4..0caf83a63b531f2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -146,6 +146,7 @@ enum class Action : uint32_t {
kEmptyForward = 1,
kFromCOO = 2,
kSparseToSparse = 3,
+ kFromReader = 4,
kToCOO = 5,
kPack = 7,
kSortCOOInPlace = 8,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index e8dd50d6730c784..a470afc2f0c8cd1 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -115,16 +115,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
PrimaryType valTp);
-/// Constructs a new sparse-tensor storage object with the given encoding,
-/// initializes it by reading all the elements from the file, and then
-/// closes the file.
-MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
- void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
- StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
- OverheadType crdTp, PrimaryType valTp);
-
/// SparseTensorReader method to obtain direct access to the
/// dimension-sizes array.
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
@@ -197,24 +187,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
-/// Helper function to read the header of a file and return the
-/// shape/sizes, without parsing the elements of the file.
-MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
- std::vector<uint64_t> *out);
-
-/// Returns the rank of the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);
-
-/// Returns the is_symmetric bit for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT bool getSparseTensorReaderIsSymmetric(void *p);
-
/// Returns the number of stored elements for the sparse tensor being read.
MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNSE(void *p);
-/// Returns the size of a dimension for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
- index_type d);
-
/// Releases the SparseTensorReader and closes the associated file.
MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a76f81410aa87a0..638475a80343d91 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -199,12 +199,15 @@ class NewCallParams final {
/// type-level information such as the encoding and sizes), generating
/// MLIR buffers as needed, and returning `this` for method chaining.
NewCallParams &genBuffers(SparseTensorType stt,
- ArrayRef<Value> dimSizesValues) {
+ ArrayRef<Value> dimSizesValues,
+ Value dimSizesBuffer = Value()) {
assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
// Sparsity annotations.
params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
// Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
- params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
+ params[kParamDimSizes] = dimSizesBuffer
+ ? dimSizesBuffer
+ : allocaBuffer(builder, loc, dimSizesValues);
params[kParamLvlSizes] =
genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
params[kParamDim2Lvl], params[kParamLvl2Dim]);
@@ -342,33 +345,15 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
const auto stt = getSparseTensorType(op);
if (!stt.hasEncoding())
return failure();
- // Construct the reader opening method calls.
+ // Construct the `reader` opening method calls.
SmallVector<Value> dimShapesValues;
Value dimSizesBuffer;
Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
dimShapesValues, dimSizesBuffer);
- // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
- Value dim2lvlBuffer;
- Value lvl2dimBuffer;
- Value lvlSizesBuffer =
- genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
- dim2lvlBuffer, lvl2dimBuffer);
// Use the `reader` to parse the file.
- Type opaqueTp = getOpaquePointerType(rewriter);
- Type eltTp = stt.getElementType();
- Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
- SmallVector<Value, 8> params{
- reader,
- lvlSizesBuffer,
- genLvlTypesBuffer(rewriter, loc, stt),
- dim2lvlBuffer,
- lvl2dimBuffer,
- constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
- constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
- valTp};
- Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
- opaqueTp, params, EmitCInterface::On)
- .getResult(0);
+ Value tensor = NewCallParams(rewriter, loc)
+ .genBuffers(stt, dimShapesValues, dimSizesBuffer)
+ .genNewCall(Action::kFromReader, reader);
// Free the memory for `reader`.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index ae33a869497a01c..fbd98f6cf183793 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -138,6 +138,12 @@ extern "C" {
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
dimRank, tensor); \
} \
+ case Action::kFromReader: { \
+ assert(ptr && "Received nullptr for SparseTensorReader object"); \
+ SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
+ return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
+ lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
+ } \
case Action::kToCOO: { \
assert(ptr && "Received nullptr for SparseTensorStorage object"); \
auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
#undef IMPL_GETNEXT
-void *_mlir_ciface_newSparseTensorFromReader(
- void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
- StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
- OverheadType crdTp, PrimaryType valTp) {
- assert(p);
- SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
- ASSERT_NO_STRIDE(lvlSizesRef);
- ASSERT_NO_STRIDE(lvlTypesRef);
- ASSERT_NO_STRIDE(dim2lvlRef);
- ASSERT_NO_STRIDE(lvl2dimRef);
- const uint64_t dimRank = reader.getRank();
- const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
- ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
- ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
- ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
- (void)dimRank;
- const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
- const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
- const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
- const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
-#define CASE(p, c, v, P, C, V) \
- if (posTp == OverheadType::p && crdTp == OverheadType::c && \
- valTp == PrimaryType::v) \
- return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
- lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
-#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
- // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
- // This is safe because of the static_assert above.
- if (posTp == OverheadType::kIndex)
- posTp = OverheadType::kU64;
- if (crdTp == OverheadType::kIndex)
- crdTp = OverheadType::kU64;
- // Double matrices with all combinations of overhead storage.
- CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
- CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
- CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
- CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
- CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
- CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
- CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
- CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
- CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
- CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
- CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
- CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
- CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
- CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
- CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
- CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
- // Float matrices with all combinations of overhead storage.
- CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
- CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
- CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
- CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
- CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
- CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
- CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
- CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
- CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
- CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
- CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
- CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
- CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
- CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
- CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
- CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
- // Two-byte floats with both overheads of the same type.
- CASE_SECSAME(kU64, kF16, uint64_t, f16);
- CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
- CASE_SECSAME(kU32, kF16, uint32_t, f16);
- CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
- CASE_SECSAME(kU16, kF16, uint16_t, f16);
- CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
- CASE_SECSAME(kU8, kF16, uint8_t, f16);
- CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
- // Integral matrices with both overheads of the same type.
- CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
- CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
- CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
- CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
- CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
- CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
- CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
- CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
- CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
- CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
- CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
- CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
- CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
- CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
- CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
- CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
- // Complex matrices with wide overhead.
- CASE_SECSAME(kU64, kC64, uint64_t, complex64);
- CASE_SECSAME(kU64, kC32, uint64_t, complex32);
-
- // Unsupported case (add above if needed).
- MLIR_SPARSETENSOR_FATAL(
- "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
- static_cast<int>(posTp), static_cast<int>(crdTp),
- static_cast<int>(valTp));
-#undef CASE_SECSAME
-#undef CASE
-}
-
void _mlir_ciface_outSparseTensorWriterMetaData(
void *p, index_type dimRank, index_type nse,
StridedMemRefType<index_type, 1> *dimSizesRef) {
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
return env;
}
-void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
- assert(out && "Received nullptr for out-parameter");
- SparseTensorReader reader(filename);
- reader.openFile();
- reader.readHeader();
- reader.closeFile();
- const uint64_t dimRank = reader.getRank();
- const uint64_t *dimSizes = reader.getDimSizes();
- out->reserve(dimRank);
- out->assign(dimSizes, dimSizes + dimRank);
-}
-
-index_type getSparseTensorReaderRank(void *p) {
- return static_cast<SparseTensorReader *>(p)->getRank();
-}
-
-bool getSparseTensorReaderIsSymmetric(void *p) {
- return static_cast<SparseTensorReader *>(p)->isSymmetric();
-}
-
index_type getSparseTensorReaderNSE(void *p) {
return static_cast<SparseTensorReader *>(p)->getNSE();
}
-index_type getSparseTensorReaderDimSize(void *p, index_type d) {
- return static_cast<SparseTensorReader *>(p)->getDimSize(d);
-}
-
void delSparseTensorReader(void *p) {
delete static_cast<SparseTensorReader *>(p);
}
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 96300a98a6a4bc5..2ff4887dae7b8c9 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,11 +78,11 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
@@ -96,11 +96,11 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
>From 54d0fd06222284f97887fddea987af82727a1919 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 13 Oct 2023 17:45:56 -0700
Subject: [PATCH 2/2] clang-format
---
.../SparseTensor/Transforms/SparseTensorConversion.cpp | 4 ++--
mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 638475a80343d91..73f5e3eeb7d512e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -352,8 +352,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
dimShapesValues, dimSizesBuffer);
// Use the `reader` to parse the file.
Value tensor = NewCallParams(rewriter, loc)
- .genBuffers(stt, dimShapesValues, dimSizesBuffer)
- .genNewCall(Action::kFromReader, reader);
+ .genBuffers(stt, dimShapesValues, dimSizesBuffer)
+ .genNewCall(Action::kFromReader, reader);
// Free the memory for `reader`.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index fbd98f6cf183793..74ab65c143d63e8 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -142,7 +142,7 @@ extern "C" {
assert(ptr && "Received nullptr for SparseTensorReader object"); \
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
- lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
+ lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
} \
case Action::kToCOO: { \
assert(ptr && "Received nullptr for SparseTensorStorage object"); \
More information about the Mlir-commits
mailing list