[Mlir-commits] [mlir] 2af2e4d - [mlir][sparse] Breaking up openSparseTensor to better support non-permutations

wren romano llvmlistbot at llvm.org
Fri Dec 2 11:11:05 PST 2022


Author: wren romano
Date: 2022-12-02T11:10:57-08:00
New Revision: 2af2e4dbb790daafd3cbbf6189a7a27145cf4c12

URL: https://github.com/llvm/llvm-project/commit/2af2e4dbb790daafd3cbbf6189a7a27145cf4c12
DIFF: https://github.com/llvm/llvm-project/commit/2af2e4dbb790daafd3cbbf6189a7a27145cf4c12.diff

LOG: [mlir][sparse] Breaking up openSparseTensor to better support non-permutations

This commit updates how the `SparseTensorConversion` pass handles `NewOp`.  It breaks up the underlying `openSparseTensor` function into two parts (`SparseTensorReader::create` and `SparseTensorReader::readSparseTensor`) so that the pass can inject code for constructing `lvlSizes` between those two parts.  Migrating the construction of `lvlSizes` out of the runtime and into the pass is a necessary first step toward fully supporting non-permutations.  (The alternative would be for the pass to generate a `FuncOp` for performing the construction and then pass that to the runtime, which doesn't seem to have any benefits over the design of this commit.)  And since the pass now generates the code to call these two functions, this change also removes the `Action::kFromFile` value from the enum used by `_mlir_ciface_newSparseTensor`.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138363

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
    mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
    mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/ExecutionEngine/SparseTensor/File.cpp
    mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 2f475b4dd913..62118130a3fa 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -122,7 +122,8 @@ constexpr bool isComplexPrimaryType(PrimaryType valTy) {
 /// The actions performed by @newSparseTensor.
 enum class Action : uint32_t {
   kEmpty = 0,
-  kFromFile = 1,
+  // newSparseTensor no longer handles `kFromFile=1`, so we leave this
+  // number reserved to help catch any code that still needs updating.
   kFromCOO = 2,
   kSparseToSparse = 3,
   kEmptyCOO = 4,

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 7e8428f15ec4..3734f96889ff 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -103,6 +103,23 @@ class SparseTensorReader final {
   SparseTensorReader(const SparseTensorReader &) = delete;
   SparseTensorReader &operator=(const SparseTensorReader &) = delete;
 
+  /// Factory method to allocate a new reader, open the file, read the
+  /// header, and validate that the actual contents of the file match
+  /// the expected `dimShape` and `valTp`.
+  static SparseTensorReader *create(const char *filename, uint64_t dimRank,
+                                    const uint64_t *dimShape,
+                                    PrimaryType valTp) {
+    SparseTensorReader *reader = new SparseTensorReader(filename);
+    reader->openFile();
+    reader->readHeader();
+    if (!reader->canReadAs(valTp))
+      MLIR_SPARSETENSOR_FATAL(
+          "Tensor element type %d not compatible with values in file %s\n",
+          static_cast<int>(valTp), filename);
+    reader->assertMatchesShape(dimRank, dimShape);
+    return reader;
+  }
+
   // This dtor tries to avoid leaking the `file`.  (Though it's better
   // to call `closeFile` explicitly when possible, since there are
   // circumstances where dtors are not called reliably.)
@@ -173,10 +190,51 @@ class SparseTensorReader final {
   /// to the `indices` array.
   template <typename V>
   V readCOOElement(uint64_t rank, uint64_t *indices) {
-    char *linePtr = readCOOIndices(rank, indices);
+    assert(rank == getRank() && "rank mismatch");
+    char *linePtr = readCOOIndices(indices);
     return detail::readCOOValue<V>(&linePtr, isPattern());
   }
 
+  /// Allocates a new COO object for `lvlSizes`, initializes it by reading
+  /// all the elements from the file and applying `dim2lvl` to their indices,
+  /// and then closes the file.
+  ///
+  /// Preconditions:
+  /// * `lvlSizes` must be valid for `lvlRank`.
+  /// * `dim2lvl` must be valid for `getRank()`.
+  /// * `dim2lvl` maps indices valid for `getDimSizes()` to indices
+  ///   valid for `lvlSizes`.
+  /// * the file's actual value type can be read as `V`.
+  ///
+  /// Asserts:
+  /// * `isValid()`
+  /// * `dim2lvl` is a permutation, and therefore also `lvlRank == getRank()`.
+  ///   (This requirement will be lifted once we functionalize `dim2lvl`.)
+  //
+  // NOTE: This method is factored out of `readSparseTensor` primarily to
+  // reduce code bloat (since the bulk of the code doesn't care about the
+  // `<P,I>` type template parameters).  But we leave it public since it's
+  // perfectly reasonable for clients to use.
+  template <typename V>
+  SparseTensorCOO<V> *readCOO(uint64_t lvlRank, const uint64_t *lvlSizes,
+                              const uint64_t *dim2lvl);
+
+  /// Allocates a new sparse-tensor storage object with the given encoding,
+  /// initializes it by reading all the elements from the file, and then
+  /// closes the file.  Preconditions/assertions are as per `readCOO`
+  /// and `SparseTensorStorage::newFromCOO`.
+  template <typename P, typename I, typename V>
+  SparseTensorStorage<P, I, V> *
+  readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
+                   const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
+                   const uint64_t *dim2lvl) {
+    auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
+    auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
+        getRank(), getDimSizes(), lvlRank, lvlTypes, lvl2dim, *lvlCOO);
+    delete lvlCOO;
+    return tensor;
+  }
+
 private:
   /// Attempts to read a line from the file.  Is private because there's
   /// no reason for client code to call it.
@@ -187,7 +245,9 @@ class SparseTensorReader final {
   /// buffer where the element's value should be parsed from.  This method
   /// has been factored out from `readCOOElement` to minimize code bloat
   /// for the generated library.
-  char *readCOOIndices(uint64_t rank, uint64_t *indices);
+  ///
+  /// Precondition: `indices` is valid for `getRank()`.
+  char *readCOOIndices(uint64_t *indices);
 
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
@@ -209,72 +269,49 @@ class SparseTensorReader final {
 
 //===----------------------------------------------------------------------===//
 
-/// Reads a sparse tensor with the given filename into a memory-resident
-/// sparse tensor.
-///
-/// Preconditions:
-/// * `dimShape` and `dim2lvl` must be valid for `dimRank`.
-/// * `lvlTypes` and `lvl2dim` must be valid for `lvlRank`.
-/// * `dim2lvl` is the inverse of `lvl2dim`.
-///
-/// Asserts:
-/// * the file's actual value type can be read as `valTp`.
-/// * the file's actual dimension-sizes match the expected `dimShape`.
-/// * `dim2lvl` is a permutation, and therefore also `dimRank == lvlRank`.
-//
-// TODO: As currently written, this function uses `dim2lvl` in two
-// places: first, to construct the level-sizes from the file's actual
-// dimension-sizes; and second, to map the file's dimension-indices into
-// level-indices.  The latter can easily generalize to arbitrary mappings,
-// however the former cannot.  Thus, once we functionalize the mappings,
-// this function will need both the sizes-to-sizes and indices-to-indices
-// variants of the `dim2lvl` mapping.  For the `lvl2dim` direction we only
-// need the indices-to-indices variant, for handing off to `newFromCOO`.
-template <typename P, typename I, typename V>
-inline SparseTensorStorage<P, I, V> *
-openSparseTensor(uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-                 const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                 const uint64_t *dim2lvl, const char *filename,
-                 PrimaryType valTp) {
-  // Read the file's header and check the file's actual element type and
-  // dimension-sizes against the expected element type and dimension-shape.
-  SparseTensorReader stfile(filename);
-  stfile.openFile();
-  stfile.readHeader();
-  if (!stfile.canReadAs(valTp))
-    MLIR_SPARSETENSOR_FATAL(
-        "Tensor element type %d not compatible with values in file %s\n",
-        static_cast<int>(valTp), filename);
-  stfile.assertMatchesShape(dimRank, dimShape);
-  const uint64_t *dimSizes = stfile.getDimSizes();
-  // Construct the level-sizes from the file's dimension-sizes
-  // TODO: This doesn't generalize to arbitrary mappings. (See above.)
-  assert(dimRank == lvlRank && "Rank mismatch");
+template <typename V>
+SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
+                                                const uint64_t *lvlSizes,
+                                                const uint64_t *dim2lvl) {
+  assert(isValid() && "Attempt to readCOO() before readHeader()");
+  // Construct a `PermutationRef` for the `pushforward` below.
+  // TODO: This specific implementation does not generalize to arbitrary
+  // mappings, but once we functionalize the `dim2lvl` argument we can
+  // simply use that function instead.
+  const uint64_t dimRank = getRank();
+  assert(lvlRank == dimRank && "Rank mismatch");
   detail::PermutationRef d2l(dimRank, dim2lvl);
-  std::vector<uint64_t> lvlSizes = d2l.pushforward(dimRank, dimSizes);
   // Prepare a COO object with the number of nonzeros as initial capacity.
-  uint64_t nnz = stfile.getNNZ();
-  auto *lvlCOO = new SparseTensorCOO<V>(lvlSizes, nnz);
+  const uint64_t nnz = getNNZ();
+  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
   // Read all nonzero elements.
   std::vector<uint64_t> dimInd(dimRank);
   std::vector<uint64_t> lvlInd(lvlRank);
+  // Do some manual LICM, to avoid assertions in the for-loop.
+  const bool addSymmetric = (isSymmetric() && dimRank == 2);
+  const bool isPattern_ = isPattern();
   for (uint64_t k = 0; k < nnz; ++k) {
-    const V value = stfile.readCOOElement<V>(dimRank, dimInd.data());
+    // We inline `readCOOElement` here in order to avoid redundant
+    // assertions, since they're guaranteed by the call to `isValid()`
+    // and the construction of `dimInd` above.
+    char *linePtr = readCOOIndices(dimInd.data());
+    const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
     d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
     // TODO: <https://github.com/llvm/llvm-project/issues/54179>
     lvlCOO->add(lvlInd, value);
     // We currently chose to deal with symmetric matrices by fully
     // constructing them.  In the future, we may want to make symmetry
     // implicit for storage reasons.
-    if (stfile.isSymmetric() && lvlInd[0] != lvlInd[1])
-      lvlCOO->add({lvlInd[1], lvlInd[0]}, value);
+    if (addSymmetric && dimInd[0] != dimInd[1]) {
+      // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
+      std::swap(dimInd[0], dimInd[1]);
+      d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
+      lvlCOO->add(lvlInd, value);
+    }
   }
-  // Close the file, convert the COO to SparseTensorStorage, and return.
-  stfile.closeFile();
-  auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
-      dimRank, dimSizes, lvlRank, lvlTypes, lvl2dim, *lvlCOO);
-  delete lvlCOO;
-  return tensor;
+  // Close the file and return the COO.
+  closeFile();
+  return lvlCOO;
 }
 
 /// Writes the sparse tensor to `filename` in extended FROSTT format.

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index c16efccfb433..558799528c4c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -220,8 +220,28 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_CONVERTFROMMLIRSPARSETENSOR)
 /// Creates a SparseTensorReader for reading a sparse tensor from a file with
 /// the given file name. This opens the file and read the header meta data based
 /// of the sparse tensor format derived from the suffix of the file name.
+//
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_createCheckedSparseTensorReader` instead.
 MLIR_CRUNNERUTILS_EXPORT void *createSparseTensorReader(char *filename);
 
+/// Constructs a new SparseTensorReader object, opens the file, reads the
+/// header, and validates that the actual contents of the file match
+/// the expected `dimShapeRef` and `valTp`.
+MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
+    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
+    PrimaryType valTp);
+
+/// Constructs a new sparse-tensor storage object with the given encoding,
+/// initializes it by reading all the elements from the file, and then
+/// closes the file.
+MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
+    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
+    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
+    StridedMemRefType<index_type, 1> *lvl2dimRef,
+    StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
+    OverheadType indTp, PrimaryType valTp);
+
 /// Returns the rank of the sparse tensor being read.
 MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);
 
@@ -235,10 +255,19 @@ MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNNZ(void *p);
 MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
                                                                  index_type d);
 
-/// Returns all dimension sizes for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
+/// SparseTensorReader method to copy the dimension-sizes into the
+/// provided memref.
+//
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_getSparseTensorReaderDimSizes` instead.
+MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_copySparseTensorReaderDimSizes(
     void *p, StridedMemRefType<index_type, 1> *dref);
 
+/// SparseTensorReader method to obtain direct access to the
+/// dimension-sizes array.
+MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
+    StridedMemRefType<index_type, 1> *out, void *p);
+
 /// Releases the SparseTensorReader. This also closes the file associated with
 /// the reader.
 MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 7522e26b5853..40570fc42cd3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -95,7 +95,9 @@ static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
 static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                          Location loc, SparseTensorEncodingAttr &enc,
                          ShapedType stp, Value src) {
-  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+  unsigned rank = stp.getRank();
+  sizes.reserve(rank);
+  for (unsigned i = 0; i < rank; i++)
     sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
 }
 
@@ -103,7 +105,9 @@ static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
 static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp) {
   auto shape = stp.getShape();
-  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
+  unsigned rank = stp.getRank();
+  sizes.reserve(rank);
+  for (unsigned i = 0; i < rank; i++) {
     uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
     sizes.push_back(constantIndex(builder, loc, s));
   }
@@ -167,6 +171,17 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
   return buffer;
 }
 
+/// Generates a temporary buffer for the level-types of the given encoding.
+static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
+                               SparseTensorEncodingAttr enc) {
+  SmallVector<Value> lvlTypes;
+  auto dlts = enc.getDimLevelType();
+  lvlTypes.reserve(dlts.size());
+  for (auto dlt : dlts)
+    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
+  return genBuffer(builder, loc, lvlTypes);
+}
+
 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
 /// the "swiss army knife" method of the sparse runtime support library
 /// for materializing sparse tensors into the computation.  This abstraction
@@ -262,11 +277,7 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
   const unsigned lvlRank = enc.getDimLevelType().size();
   const unsigned dimRank = stp.getRank();
   // Sparsity annotations.
-  SmallVector<Value> lvlTypes;
-  for (auto dlt : enc.getDimLevelType())
-    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
-  assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
-  params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
+  params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, enc);
   // Dimension-sizes array of the enveloping tensor.  Useful for either
   // verification of external data, or for construction of internal data.
   assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
@@ -715,19 +726,98 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type resType = op.getType();
-    auto enc = getSparseTensorEncoding(resType);
+    auto stp = op.getType().cast<ShapedType>();
+    auto enc = getSparseTensorEncoding(stp);
     if (!enc)
       return failure();
-    // Generate the call to construct tensor from ptr. The sizes are
-    // inferred from the result type of the new operator.
-    SmallVector<Value> sizes;
-    ShapedType stp = resType.cast<ShapedType>();
-    sizesFromType(rewriter, sizes, loc, stp);
-    Value ptr = adaptor.getOperands()[0];
-    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
-                               .genBuffers(enc, sizes, stp)
-                               .genNewCall(Action::kFromFile, ptr));
+    const unsigned dimRank = stp.getRank();
+    const unsigned lvlRank = enc.getDimLevelType().size();
+    // Construct the dimShape.
+    const auto dimShape = stp.getShape();
+    SmallVector<Value> dimShapeValues;
+    sizesFromType(rewriter, dimShapeValues, loc, stp);
+    Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
+    // Allocate `SparseTensorReader` and perform all initial setup that
+    // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Value valTp =
+        constantPrimaryTypeEncoding(rewriter, loc, stp.getElementType());
+    Value reader =
+        createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
+                       opaqueTp,
+                       {adaptor.getOperands()[0], dimShapeBuffer, valTp},
+                       EmitCInterface::On)
+            .getResult(0);
+    // Construct the lvlSizes.  If the dimShape is static, then it's
+    // identical to dimSizes: so we can compute lvlSizes entirely at
+    // compile-time.  If dimShape is dynamic, then we'll need to generate
+    // code for computing lvlSizes from the `reader`'s actual dimSizes.
+    //
+    // TODO: For now we're still assuming `dim2lvl` is a permutation.
+    // But since we're computing lvlSizes here (rather than in the runtime),
+    // we can easily generalize that simply by adjusting this code.
+    //
+    // FIXME: reduce redundancy vs `NewCallParams::genBuffers`.
+    Value dimSizesBuffer;
+    if (!stp.hasStaticShape()) {
+      Type indexTp = rewriter.getIndexType();
+      auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+      dimSizesBuffer =
+          createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
+                         reader, EmitCInterface::On)
+              .getResult(0);
+    }
+    Value lvlSizesBuffer;
+    Value lvl2dimBuffer;
+    Value dim2lvlBuffer;
+    if (auto dimOrder = enc.getDimOrdering()) {
+      assert(dimOrder.isPermutation() && "Got non-permutation");
+      // We preinitialize `dim2lvlValues` since we need random-access writing.
+      // And we preinitialize the others for stylistic consistency.
+      SmallVector<Value> lvlSizeValues(lvlRank);
+      SmallVector<Value> lvl2dimValues(lvlRank);
+      SmallVector<Value> dim2lvlValues(dimRank);
+      for (unsigned l = 0; l < lvlRank; l++) {
+        // The `d`th source variable occurs in the `l`th result position.
+        uint64_t d = dimOrder.getDimPosition(l);
+        Value lvl = constantIndex(rewriter, loc, l);
+        Value dim = constantIndex(rewriter, loc, d);
+        dim2lvlValues[d] = lvl;
+        lvl2dimValues[l] = dim;
+        lvlSizeValues[l] =
+            (dimShape[d] == ShapedType::kDynamic)
+                ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
+                : dimShapeValues[d];
+      }
+      lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
+      lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
+      dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
+    } else {
+      assert(dimRank == lvlRank && "Rank mismatch");
+      SmallVector<Value> iotaValues;
+      iotaValues.reserve(lvlRank);
+      for (unsigned i = 0; i < lvlRank; i++)
+        iotaValues.push_back(constantIndex(rewriter, loc, i));
+      lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
+      dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
+    }
+    // Use the `reader` to parse the file.
+    SmallVector<Value, 8> params{
+        reader,
+        lvlSizesBuffer,
+        genLvlTypesBuffer(rewriter, loc, enc),
+        lvl2dimBuffer,
+        dim2lvlBuffer,
+        constantPointerTypeEncoding(rewriter, loc, enc),
+        constantIndexTypeEncoding(rewriter, loc, enc),
+        valTp};
+    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
+                                  opaqueTp, params, EmitCInterface::On)
+                       .getResult(0);
+    // Free the memory for `reader`.
+    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
+                   EmitCInterface::Off);
+    rewriter.replaceOp(op, tensor);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b929ed25a291..b65bcaccf40d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -913,9 +913,8 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     Location loc = op.getLoc();
     auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
-    if (!encDst) {
+    if (!encDst)
       return failure();
-    }
 
     // Create a sparse tensor reader.
     Value fileName = op.getSource();
@@ -933,7 +932,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     // the sparse tensor reader.
     SmallVector<Value> dynSizesArray;
     if (!dstTp.hasStaticShape()) {
-      createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", {},
+      createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
                      {reader, dimSizes}, EmitCInterface::On)
           .getResult(0);
       ArrayRef<int64_t> dstShape = dstTp.getShape();
@@ -977,7 +976,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
                                                    ArrayRef<Value>(cooBuffer));
     rewriter.setInsertionPointToStart(forOp.getBody());
 
-    SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
+    SmallString<29> getNextFuncName{"getSparseTensorReaderNext",
                                     primaryTypeFunctionSuffix(eltTp)};
     Value indices = dimSizes; // Reuse the indices memref to store indices.
     createFuncCall(rewriter, loc, getNextFuncName, {}, {reader, indices, value},
@@ -1060,7 +1059,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 
     Value indices = dimSizes; // Reuse the dimSizes buffer for indices.
     Type eltTp = srcTp.getElementType();
-    SmallString<18> outNextFuncName{"outSparseTensorWriterNext",
+    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                     primaryTypeFunctionSuffix(eltTp)};
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     ModuleOp module = op->getParentOfType<ModuleOp>();

diff  --git a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
index e750bf77b89c..19fbaf23687b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
@@ -53,12 +53,11 @@ void SparseTensorReader::readLine() {
     MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename);
 }
 
-char *SparseTensorReader::readCOOIndices(uint64_t rank, uint64_t *indices) {
-  assert(rank == getRank() && "Rank mismatch");
+char *SparseTensorReader::readCOOIndices(uint64_t *indices) {
   readLine();
   // Local variable for tracking the parser's position in the `line` buffer.
   char *linePtr = line;
-  for (uint64_t r = 0; r < rank; ++r) {
+  for (uint64_t rank = getRank(), r = 0; r < rank; ++r) {
     // Parse the 1-based index.
     uint64_t idx = strtoul(linePtr, &linePtr, 10);
     // Store the 0-based index.

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index e5ea01420293..277f531ecea6 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -243,15 +243,29 @@ fromMLIRSparseTensor(const SparseTensorStorage<uint64_t, uint64_t, V> *tensor,
 
 #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
 
-// We make this a function rather than a macro mainly for type safety
-// reasons.  This function does not modify the vector, but it cannot
-// be marked `const` because it is stored into the non-`const` memref.
-template <typename T>
-static void vectorToMemref(std::vector<T> &v, StridedMemRefType<T, 1> &ref) {
-  ref.basePtr = ref.data = v.data();
+/// Initializes the memref with the provided size and data pointer.  This
+/// is designed for functions which want to "return" a memref that aliases
+/// into memory owned by some other object (e.g., `SparseTensorStorage`),
+/// without doing any actual copying.  (The "return" is in scarequotes
+/// because the `_mlir_ciface_` calling convention migrates any returned
+/// memrefs into an out-parameter passed before all the other function
+/// parameters.)
+///
+/// We make this a function rather than a macro mainly for type safety
+/// reasons.  This function does not modify the data pointer, but it
+/// cannot be marked `const` because it is stored into the (necessarily)
+/// non-`const` memref.  This function is templated over the `DataSizeT`
+/// to work around signedness warnings due to many data types having
+/// varying signedness across different platforms.  The templating allows
+/// this function to ensure that it does the right thing and never
+/// introduces errors due to implicit conversions.
+template <typename DataSizeT, typename T>
+static inline void aliasIntoMemref(DataSizeT size, T *data,
+                                   StridedMemRefType<T, 1> &ref) {
+  ref.basePtr = ref.data = data;
   ref.offset = 0;
-  using SizeT = typename std::remove_reference_t<decltype(ref.sizes[0])>;
-  ref.sizes[0] = detail::checkOverflowCast<SizeT>(v.size());
+  using MemrefSizeT = typename std::remove_reference_t<decltype(ref.sizes[0])>;
+  ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
   ref.strides[0] = 1;
 }
 
@@ -272,11 +286,6 @@ extern "C" {
     case Action::kEmpty:                                                       \
       return SparseTensorStorage<P, I, V>::newEmpty(                           \
           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, lvl2dim);            \
-    case Action::kFromFile: {                                                  \
-      char *filename = static_cast<char *>(ptr);                               \
-      return openSparseTensor<P, I, V>(dimRank, dimSizes, lvlRank, lvlTypes,   \
-                                       lvl2dim, dim2lvl, filename, v);         \
-    }                                                                          \
     case Action::kFromCOO: {                                                   \
       assert(ptr && "Received nullptr for SparseTensorCOO object");            \
       auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                     \
@@ -468,7 +477,7 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
     std::vector<V> *v;                                                         \
     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
     assert(v);                                                                 \
-    vectorToMemref(*v, *ref);                                                  \
+    aliasIntoMemref(v->size(), v->data(), *ref);                               \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
 #undef IMPL_SPARSEVALUES
@@ -480,7 +489,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
     assert(v);                                                                 \
-    vectorToMemref(*v, *ref);                                                  \
+    aliasIntoMemref(v->size(), v->data(), *ref);                               \
   }
 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
@@ -574,16 +583,37 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
 #undef IMPL_EXPINSERT
 
-void _mlir_ciface_getSparseTensorReaderDimSizes(
+void *_mlir_ciface_createCheckedSparseTensorReader(
+    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
+    PrimaryType valTp) {
+  ASSERT_NO_STRIDE(dimShapeRef);
+  const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
+  const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
+  auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
+  return static_cast<void *>(reader);
+}
+
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_getSparseTensorReaderDimSizes` instead.
+void _mlir_ciface_copySparseTensorReaderDimSizes(
     void *p, StridedMemRefType<index_type, 1> *dref) {
   assert(p);
+  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
   ASSERT_NO_STRIDE(dref);
+  const uint64_t dimRank = MEMREF_GET_USIZE(dref);
+  ASSERT_USIZE_EQ(dref, reader.getRank());
   index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
-  SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
-  const index_type *sizes = file.getDimSizes();
-  index_type rank = file.getRank();
-  for (index_type r = 0; r < rank; ++r)
-    dimSizes[r] = sizes[r];
+  const index_type *fileSizes = reader.getDimSizes();
+  for (uint64_t d = 0; d < dimRank; ++d)
+    dimSizes[d] = fileSizes[d];
+}
+
+void _mlir_ciface_getSparseTensorReaderDimSizes(
+    StridedMemRefType<index_type, 1> *out, void *p) {
+  assert(out && p);
+  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
+  auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
+  aliasIntoMemref(reader.getRank(), dimSizes, *out);
 }
 
 #define IMPL_GETNEXT(VNAME, V)                                                 \
@@ -591,16 +621,126 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
       void *p, StridedMemRefType<index_type, 1> *iref,                         \
       StridedMemRefType<V, 0> *vref) {                                         \
     assert(p &&vref);                                                          \
+    auto &reader = *static_cast<SparseTensorReader *>(p);                      \
     ASSERT_NO_STRIDE(iref);                                                    \
+    const uint64_t rank = MEMREF_GET_USIZE(iref);                              \
     index_type *indices = MEMREF_GET_PAYLOAD(iref);                            \
-    SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
-    index_type rank = stfile->getRank();                                       \
     V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    *value = stfile->readCOOElement<V>(rank, indices);                         \
+    *value = reader.readCOOElement<V>(rank, indices);                          \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
 
+void *_mlir_ciface_newSparseTensorFromReader(
+    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
+    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
+    StridedMemRefType<index_type, 1> *lvl2dimRef,
+    StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
+    OverheadType indTp, PrimaryType valTp) {
+  assert(p);
+  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
+  ASSERT_NO_STRIDE(lvlSizesRef);
+  ASSERT_NO_STRIDE(lvlTypesRef);
+  ASSERT_NO_STRIDE(lvl2dimRef);
+  ASSERT_NO_STRIDE(dim2lvlRef);
+  const uint64_t dimRank = reader.getRank();
+  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
+  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
+  ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
+  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
+  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
+  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
+  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
+  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
+  //
+  // FIXME(wrengr): Really need to define a separate x-macro for handling
+  // all this. (Or ideally some better, entirely-different approach)
+#define CASE(p, i, v, P, I, V)                                                 \
+  if (ptrTp == OverheadType::p && indTp == OverheadType::i &&                  \
+      valTp == PrimaryType::v)                                                 \
+    return static_cast<void *>(reader.readSparseTensor<P, I, V>(               \
+        lvlRank, lvlSizes, lvlTypes, lvl2dim, dim2lvl));
+#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
+  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+  // This is safe because of the static_assert above.
+  if (ptrTp == OverheadType::kIndex)
+    ptrTp = OverheadType::kU64;
+  if (indTp == OverheadType::kIndex)
+    indTp = OverheadType::kU64;
+  // Double matrices with all combinations of overhead storage.
+  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
+  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
+  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
+  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
+  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
+  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
+  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
+  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
+  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
+  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
+  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
+  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
+  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
+  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
+  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
+  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
+  // Float matrices with all combinations of overhead storage.
+  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
+  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
+  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
+  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
+  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
+  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
+  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
+  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
+  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
+  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
+  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
+  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
+  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
+  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
+  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
+  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
+  // Two-byte floats with both overheads of the same type.
+  CASE_SECSAME(kU64, kF16, uint64_t, f16);
+  CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
+  CASE_SECSAME(kU32, kF16, uint32_t, f16);
+  CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
+  CASE_SECSAME(kU16, kF16, uint16_t, f16);
+  CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
+  CASE_SECSAME(kU8, kF16, uint8_t, f16);
+  CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
+  // Integral matrices with both overheads of the same type.
+  CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
+  CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
+  CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
+  CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
+  CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
+  CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
+  CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
+  CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
+  CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
+  CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
+  CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
+  CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
+  CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
+  CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
+  CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
+  CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
+  // Complex matrices with wide overhead.
+  CASE_SECSAME(kU64, kC64, uint64_t, complex64);
+  CASE_SECSAME(kU64, kC32, uint64_t, complex32);
+
+  // Unsupported case (add above if needed).
+  // TODO: better pretty-printing of enum values!
+  MLIR_SPARSETENSOR_FATAL(
+      "unsupported combination of types: <P=%d, I=%d, V=%d>\n",
+      static_cast<int>(ptrTp), static_cast<int>(indTp),
+      static_cast<int>(valTp));
+#undef CASE_SECSAME
+#undef CASE
+}
+
 void _mlir_ciface_outSparseTensorWriterMetaData(
     void *p, index_type rank, index_type nnz,
     StridedMemRefType<index_type, 1> *dref) {
@@ -686,14 +826,14 @@ char *getTensorFilename(index_type id) {
 
 void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
   assert(out && "Received nullptr for out-parameter");
-  SparseTensorReader stfile(filename);
-  stfile.openFile();
-  stfile.readHeader();
-  stfile.closeFile();
-  const uint64_t rank = stfile.getRank();
-  const uint64_t *dimSizes = stfile.getDimSizes();
-  out->reserve(rank);
-  out->assign(dimSizes, dimSizes + rank);
+  SparseTensorReader reader(filename);
+  reader.openFile();
+  reader.readHeader();
+  reader.closeFile();
+  const uint64_t dimRank = reader.getRank();
+  const uint64_t *dimSizes = reader.getDimSizes();
+  out->reserve(dimRank);
+  out->assign(dimSizes, dimSizes + dimRank);
 }
 
 // We can't use `static_cast` here because `DimLevelType` is an enum-class.
@@ -718,11 +858,13 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
 
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_createCheckedSparseTensorReader` instead.
 void *createSparseTensorReader(char *filename) {
-  SparseTensorReader *stfile = new SparseTensorReader(filename);
-  stfile->openFile();
-  stfile->readHeader();
-  return static_cast<void *>(stfile);
+  SparseTensorReader *reader = new SparseTensorReader(filename);
+  reader->openFile();
+  reader->readHeader();
+  return static_cast<void *>(reader);
 }
 
 index_type getSparseTensorReaderRank(void *p) {

diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 0dea82678612..226406635316 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -77,16 +77,15 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 
 // CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//   CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
-//   CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<1xindex>
-//   CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<1xindex>
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
+//   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
+//       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
-//   CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<1xindex> to memref<?xindex>
-//   CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<1xindex> to memref<?xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+//       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -95,16 +94,16 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
 
 // CHECK-LABEL: func @sparse_new2d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//   CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
-//   CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<2xindex>
+//   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
+//       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
+//       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<2xindex> to memref<?xindex>
-//   CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<2xindex> to memref<?xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+//       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
@@ -113,18 +112,20 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
 
 // CHECK-LABEL: func @sparse_new3d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//   CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
-//   CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<3xindex>
+//   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<3xindex>
+//   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
+//       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
 //   CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-//   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-//   CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+//   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}})
+//       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>

diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index 8c69066fb176..ceb9c085dca5 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -12,7 +12,7 @@
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK:         %[[R:.*]] = call @createSparseTensorReader(%[[A]])
 // CHECK:         %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
-// CHECK:         call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
+// CHECK:         call @copySparseTensorReaderDimSizes(%[[R]], %[[DS]])
 // CHECK:         %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK:         %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK:         %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
@@ -51,7 +51,7 @@ func.func @sparse_new_symmetry(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK:         %[[R:.*]] = call @createSparseTensorReader(%[[A]])
 // CHECK:         %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
-// CHECK:         call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
+// CHECK:         call @copySparseTensorReaderDimSizes(%[[R]], %[[DS]])
 // CHECK:         %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK:         %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK:         %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
index 06cca8a02e6e..595c7718d86a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
@@ -26,7 +26,7 @@ module {
   func.func private @getSparseTensorReaderRank(!TensorReader) -> (index)
   func.func private @getSparseTensorReaderNNZ(!TensorReader) -> (index)
   func.func private @getSparseTensorReaderIsSymmetric(!TensorReader) -> (i1)
-  func.func private @getSparseTensorReaderDimSizes(!TensorReader,
+  func.func private @copySparseTensorReaderDimSizes(!TensorReader,
     memref<?xindex>) -> () attributes { llvm.emit_c_interface }
   func.func private @getSparseTensorReaderNextF32(!TensorReader,
     memref<?xindex>, memref<f32>) -> () attributes { llvm.emit_c_interface }
@@ -98,7 +98,7 @@ module {
       : (!TensorReader) -> i1
     vector.print %symmetric : i1
     %dimSizes = memref.alloc(%rank) : memref<?xindex>
-    func.call @getSparseTensorReaderDimSizes(%tensor, %dimSizes)
+    func.call @copySparseTensorReaderDimSizes(%tensor, %dimSizes)
       : (!TensorReader, memref<?xindex>) -> ()
     call @dumpi(%dimSizes) : (memref<?xindex>) -> ()
     %x0s, %x1s, %vs = call @readTensorFile(%tensor)
@@ -132,7 +132,7 @@ module {
     %rank = call @getSparseTensorReaderRank(%tensor0) : (!TensorReader) -> index
     %nnz = call @getSparseTensorReaderNNZ(%tensor0) : (!TensorReader) -> index
     %dimSizes = memref.alloc(%rank) : memref<?xindex>
-    func.call @getSparseTensorReaderDimSizes(%tensor0,%dimSizes)
+    func.call @copySparseTensorReaderDimSizes(%tensor0, %dimSizes)
       : (!TensorReader, memref<?xindex>) -> ()
     call @outSparseTensorWriterMetaData(%tensor1, %rank, %nnz, %dimSizes)
       : (!TensorWriter, index, index, memref<?xindex>) -> ()


        


More information about the Mlir-commits mailing list