[Mlir-commits] [mlir] b24788a - [mlir][sparse] implement sparse tensor init operation

Aart Bik llvmlistbot at llvm.org
Fri Oct 15 09:33:30 PDT 2021


Author: Aart Bik
Date: 2021-10-15T09:33:16-07:00
New Revision: b24788abd8df02169ecbf6afa91836819c8a35fe

URL: https://github.com/llvm/llvm-project/commit/b24788abd8df02169ecbf6afa91836819c8a35fe
DIFF: https://github.com/llvm/llvm-project/commit/b24788abd8df02169ecbf6afa91836819c8a35fe.diff

LOG: [mlir][sparse] implement sparse tensor init operation

Next step towards supporting sparse tensor outputs.
Also some minor refactoring of enum constants as well
as replacing tensor arguments with proper buffer arguments
(the latter is required for more general sizes arguments for
the sparse_tensor.init operation, as well as more general
sparse_tensor.convert operations later)

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D111771

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index f19db7d072a3..e98b1fa26173 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -29,6 +29,16 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+/// New tensor storage action. Keep these values consistent with
+/// the sparse runtime support library.
+enum Action : uint32_t {
+  kEmpty = 0,
+  kFromFile = 1,
+  kFromCOO = 2,
+  kEmptyCOO = 3,
+  kToCOO = 4
+};
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -105,18 +115,10 @@ inline static Value constantI32(ConversionPatternRewriter &rewriter,
   return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
 }
 
-/// Returns integers of given width and values as a constant tensor.
-/// We cast the static shape into a dynamic shape to ensure that the
-/// method signature remains uniform across different tensor dimensions.
-static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
-                       Location loc, ArrayRef<APInt> values) {
-  Type etp = rewriter.getIntegerType(width);
-  unsigned sz = values.size();
-  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
-  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
-  auto elts = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(tt1, values));
-  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
+/// Generates a constant of `i8` type.
+inline static Value constantI8(ConversionPatternRewriter &rewriter,
+                               Location loc, int8_t i) {
+  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
 }
 
 /// Returns a function reference (first hit also inserts into module). Sets
@@ -142,43 +144,70 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
   return result;
 }
 
+/// Generates a temporary buffer of the given size and type.
+static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
+                       unsigned sz, Type tp) {
+  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
+  Value a = constantIndex(rewriter, loc, sz);
+  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
+}
+
+/// Fills a temporary buffer of the given type with arguments.
+static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
+                       ArrayRef<Value> values) {
+  unsigned sz = values.size();
+  assert(sz >= 1);
+  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
+  for (unsigned i = 0; i < sz; i++) {
+    Value idx = constantIndex(rewriter, loc, i);
+    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
+  }
+  return buffer;
+}
+
 /// Generates a call into the "swiss army knife" method of the sparse runtime
 /// support library for materializing sparse tensors into the computation. The
 /// method returns the call value and assigns the permutation to 'perm'.
 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                         SparseTensorEncodingAttr &enc, uint32_t action,
-                        Value &perm, Value ptr = Value()) {
+                        Value &perm, ValueRange szs, Value ptr = Value()) {
   Location loc = op->getLoc();
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
   SmallVector<Value, 8> params;
   // Sparsity annotations in tensor constant form.
-  SmallVector<APInt, 4> attrs;
-  unsigned sz = enc.getDimLevelType().size();
+  SmallVector<Value, 4> attrs;
+  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
+  unsigned sz = dlt.size();
   for (unsigned i = 0; i < sz; i++)
-    attrs.push_back(
-        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
-  params.push_back(getTensor(rewriter, 8, loc, attrs));
+    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
+  params.push_back(genBuffer(rewriter, loc, attrs));
   // Dimension sizes array of the enveloping *dense* tensor. Useful for either
   // verification of external data, or for construction of internal data.
   auto shape = resType.getShape();
-  SmallVector<APInt, 4> sizes;
-  for (unsigned i = 0; i < sz; i++) {
-    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
-    sizes.push_back(APInt(64, s));
+  SmallVector<Value, 4> sizes;
+  if (szs.size() > 0) {
+    for (Value s : szs)
+      sizes.push_back(
+          rewriter.create<arith::IndexCastOp>(loc, s, rewriter.getI64Type()));
+  } else {
+    for (unsigned i = 0; i < sz; i++) {
+      uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
+      sizes.push_back(constantI64(rewriter, loc, s));
+    }
   }
-  params.push_back(getTensor(rewriter, 64, loc, sizes));
+  params.push_back(genBuffer(rewriter, loc, sizes));
   // Dimension order permutation array. This is the "identity" permutation by
   // default, or otherwise the "reverse" permutation of a given ordering, so
   // that indices can be mapped quickly to the right position.
-  SmallVector<APInt, 4> rev(sz);
+  SmallVector<Value, 4> rev(sz);
   if (AffineMap p = enc.getDimOrdering()) {
     for (unsigned i = 0; i < sz; i++)
-      rev[p.getDimPosition(i)] = APInt(64, i);
+      rev[p.getDimPosition(i)] = constantI64(rewriter, loc, i);
   } else {
     for (unsigned i = 0; i < sz; i++)
-      rev[i] = APInt(64, i);
+      rev[i] = constantI64(rewriter, loc, i);
   }
-  perm = getTensor(rewriter, 64, loc, rev);
+  perm = genBuffer(rewriter, loc, rev);
   params.push_back(perm);
   // Secondary and primary types encoding.
   unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
@@ -309,18 +338,6 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
 }
 
-/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
-/// is the given `rank`.  This array is intended to serve as a reusable
-/// buffer for storing the indices of a single tensor element, to avoid
-/// allocation in the body of loops.
-static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
-                           int64_t rank) {
-  auto indexTp = rewriter.getIndexType();
-  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
-  Value arg = constantIndex(rewriter, loc, rank);
-  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
-}
-
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -378,8 +395,25 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     if (!enc)
       return failure();
     Value perm;
+    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, kFromFile, perm, {},
+                                      adaptor.getOperands()[0]));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for the init operator.
+class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resType = op.getType();
+    auto enc = getSparseTensorEncoding(resType);
+    if (!enc)
+      return failure();
+    Value perm;
     rewriter.replaceOp(
-        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
+        op, genNewCall(rewriter, op, enc, kEmpty, perm, adaptor.getOperands()));
     return success();
   }
 };
@@ -402,8 +436,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       // yield the fastest conversion but avoids the need for a full
       // O(N^2) conversion matrix.
       Value perm;
-      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
-      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
+      Value coo = genNewCall(rewriter, op, encDst, kToCOO, perm, {}, src);
+      rewriter.replaceOp(
+          op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, coo));
       return success();
     }
     if (!encDst || encSrc) {
@@ -439,8 +474,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
     Value perm;
-    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value ind = allocaIndices(rewriter, loc, shape.getRank());
+    Value ptr = genNewCall(rewriter, op, encDst, kEmptyCOO, perm, {});
+    Value ind =
+        genAlloca(rewriter, loc, shape.getRank(), rewriter.getIndexType());
     SmallVector<Value> lo;
     SmallVector<Value> hi;
     SmallVector<Value> st;
@@ -478,7 +514,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
           return {};
         });
-    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
+    rewriter.replaceOp(
+        op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, ptr));
     return success();
   }
 };
@@ -637,9 +674,9 @@ class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
-               SparseTensorNewConverter, SparseTensorConvertConverter,
-               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
-               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorToTensorConverter>(typeConverter,
-                                              patterns.getContext());
+               SparseTensorNewConverter, SparseTensorInitConverter,
+               SparseTensorConvertConverter, SparseTensorReleaseConverter,
+               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
+               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index fdf506e8d810..9875a5c58ba7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -97,8 +97,8 @@ struct SparseTensorConversionPass
     RewritePatternSet patterns(ctx);
     SparseTensorTypeConverter converter;
     ConversionTarget target(*ctx);
-    target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp,
-                        ToTensorOp>();
+    target.addIllegalOp<ConvertOp, NewOp, ToIndicesOp, ToPointersOp, ToTensorOp,
+                        ToValuesOp>();
     // All dynamic rules below accept new function, call, return, and dimop
     // operations as legal output of the rewriting provided that all sparse
     // tensor types have been fully rewritten.
@@ -114,11 +114,10 @@ struct SparseTensorConversionPass
     });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<arith::ConstantOp, ConstantOp, arith::IndexCastOp,
-                      tensor::CastOp, tensor::ExtractOp, arith::CmpFOp,
-                      arith::CmpIOp>();
-    target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
-                           memref::MemRefDialect>();
+    target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
+                      arith::IndexCastOp, tensor::CastOp, tensor::ExtractOp>();
+    target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect,
+                           scf::SCFDialect>();
     // Populate with rules and apply rewriting rules.
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);

diff  --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index b30804e69433..37c9b1314525 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -37,9 +37,9 @@
 // (a) A coordinate scheme for temporarily storing and lexicographically
 //     sorting a sparse tensor by index (SparseTensorCOO).
 //
-// (b) A "one-size-fits-all" sparse tensor storage scheme defined by per-rank
-//     sparse/dense annnotations together with a dimension ordering to be
-//     used by MLIR compiler-generated code (SparseTensorStorage).
+// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
+//     per-dimension sparse/dense annnotations together with a dimension
+//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
 //
 // The following external formats are supported:
 //
@@ -93,8 +93,9 @@ struct SparseTensorCOO {
   }
   /// Adds element as indices and value.
   void add(const std::vector<uint64_t> &ind, V val) {
-    assert(getRank() == ind.size());
-    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
+    uint64_t rank = getRank();
+    assert(rank == ind.size());
+    for (uint64_t r = 0; r < rank; r++)
       assert(ind[r] < sizes[r]); // within bounds
     elements.emplace_back(ind, val);
   }
@@ -111,12 +112,12 @@ struct SparseTensorCOO {
   /// the given ordering and expects subsequent add() calls to honor
   /// that same ordering for the given indices. The result is a
   /// fully permuted coordinate scheme.
-  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t size,
+  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                 const uint64_t *sizes,
                                                 const uint64_t *perm,
                                                 uint64_t capacity = 0) {
-    std::vector<uint64_t> permsz(size);
-    for (uint64_t r = 0; r < size; r++)
+    std::vector<uint64_t> permsz(rank);
+    for (uint64_t r = 0; r < rank; r++)
       permsz[perm[r]] = sizes[r];
     return new SparseTensorCOO<V>(permsz, capacity);
   }
@@ -124,15 +125,16 @@ struct SparseTensorCOO {
 private:
   /// Returns true if indices of e1 < indices of e2.
   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
-    assert(e1.indices.size() == e2.indices.size());
-    for (uint64_t r = 0, rank = e1.indices.size(); r < rank; r++) {
+    uint64_t rank = e1.indices.size();
+    assert(rank == e2.indices.size());
+    for (uint64_t r = 0; r < rank; r++) {
       if (e1.indices[r] == e2.indices[r])
         continue;
       return e1.indices[r] < e2.indices[r];
     }
     return false;
   }
-  std::vector<uint64_t> sizes; // per-rank dimension sizes
+  std::vector<uint64_t> sizes; // per-dimension sizes
   std::vector<Element<V>> elements;
 };
 
@@ -171,45 +173,47 @@ class SparseTensorStorageBase {
   }
 };
 
-/// A memory-resident sparse tensor using a storage scheme based on per-rank
-/// annotations on dense/sparse. This data structure provides a bufferized
-/// form of a sparse tensor type. In contrast to generating setup methods for
-/// each differently annotated sparse tensor, this method provides a convenient
-/// "one-size-fits-all" solution that simply takes an input tensor and
-/// annotations to implement all required setup in a general manner.
+/// A memory-resident sparse tensor using a storage scheme based on
+/// per-dimension sparse/dense annotations. This data structure provides a
+/// bufferized form of a sparse tensor type. In contrast to generating setup
+/// methods for each differently annotated sparse tensor, this method provides
+/// a convenient "one-size-fits-all" solution that simply takes an input tensor
+/// and annotations to implement all required setup in a general manner.
 template <typename P, typename I, typename V>
 class SparseTensorStorage : public SparseTensorStorageBase {
 public:
-  /// Constructs a sparse tensor storage scheme from the given sparse
-  /// tensor in coordinate scheme following the given per-rank dimension
-  /// dense/sparse annotations.
-  SparseTensorStorage(SparseTensorCOO<V> *tensor, const uint8_t *sparsity,
-                      const uint64_t *perm)
-      : sizes(tensor->getSizes()), rev(getRank()), pointers(getRank()),
-        indices(getRank()) {
+  /// Constructs a sparse tensor storage scheme with the given dimensions,
+  /// permutation, and per-dimension dense/sparse annotations, using
+  /// the coordinate scheme tensor for the initial contents if provided.
+  SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
+                      const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
+      : sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
+    uint64_t rank = getRank();
     // Store "reverse" permutation.
-    for (uint64_t d = 0, rank = getRank(); d < rank; d++)
-      rev[perm[d]] = d;
-    // Provide hints on capacity.
+    for (uint64_t r = 0; r < rank; r++)
+      rev[perm[r]] = r;
+    // Provide hints on capacity of pointers and indices.
     // TODO: needs fine-tuning based on sparsity
-    uint64_t nnz = tensor->getElements().size();
-    values.reserve(nnz);
-    for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
-      s *= sizes[d];
-      if (sparsity[d] == kCompressed) {
-        pointers[d].reserve(s + 1);
-        indices[d].reserve(s);
+    for (uint64_t r = 0, s = 1; r < rank; r++) {
+      s *= sizes[r];
+      if (sparsity[r] == kCompressed) {
+        pointers[r].reserve(s + 1);
+        indices[r].reserve(s);
         s = 1;
       } else {
-        assert(sparsity[d] == kDense && "singleton not yet supported");
+        assert(sparsity[r] == kDense && "singleton not yet supported");
       }
     }
     // Prepare sparse pointer structures for all dimensions.
-    for (uint64_t d = 0, rank = getRank(); d < rank; d++)
-      if (sparsity[d] == kCompressed)
-        pointers[d].push_back(0);
-    // Then setup the tensor.
-    fromCOO(tensor, sparsity, 0, nnz, 0);
+    for (uint64_t r = 0; r < rank; r++)
+      if (sparsity[r] == kCompressed)
+        pointers[r].push_back(0);
+    // Then assign contents from coordinate scheme tensor if provided.
+    if (tensor) {
+      uint64_t nnz = tensor->getElements().size();
+      values.reserve(nnz);
+      fromCOO(tensor, sparsity, 0, nnz, 0);
+    }
   }
 
   virtual ~SparseTensorStorage() {}
@@ -239,40 +243,54 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
     // Restore original order of the dimension sizes and allocate coordinate
     // scheme with desired new ordering specified in perm.
-    uint64_t size = getRank();
-    std::vector<uint64_t> orgsz(size);
-    for (uint64_t r = 0; r < size; r++)
+    uint64_t rank = getRank();
+    std::vector<uint64_t> orgsz(rank);
+    for (uint64_t r = 0; r < rank; r++)
       orgsz[rev[r]] = sizes[r];
     SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
-        size, orgsz.data(), perm, values.size());
+        rank, orgsz.data(), perm, values.size());
     // Populate coordinate scheme restored from old ordering and changed with
     // new ordering. Rather than applying both reorderings during the recursion,
     // we compute the combine permutation in advance.
-    std::vector<uint64_t> reord(size);
-    for (uint64_t r = 0; r < size; r++)
+    std::vector<uint64_t> reord(rank);
+    for (uint64_t r = 0; r < rank; r++)
       reord[r] = perm[rev[r]];
-    std::vector<uint64_t> idx(size);
+    std::vector<uint64_t> idx(rank);
     toCOO(tensor, reord, idx, 0, 0);
     assert(tensor->getElements().size() == values.size());
     return tensor;
   }
 
-  /// Factory method. Expects a coordinate scheme that respects the same
-  /// permutation as is desired for the new sparse storage scheme.
-  static SparseTensorStorage<P, I, V> *newSparseTensor(SparseTensorCOO<V> *t,
-                                                       const uint8_t *sparsity,
-                                                       const uint64_t *perm) {
-    t->sort(); // sort lexicographically
-    SparseTensorStorage<P, I, V> *n =
-        new SparseTensorStorage<P, I, V>(t, sparsity, perm);
-    delete t;
+  /// Factory method. Constructs a sparse tensor storage scheme with the given
+  /// dimensions, permutation, and per-dimension dense/sparse annotations,
+  /// using the coordinate scheme tensor for the initial contents if provided.
+  /// In the latter case, the coordinate scheme must respect the same
+  /// permutation as is desired for the new sparse tensor storage.
+  static SparseTensorStorage<P, I, V> *
+  newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
+                  const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
+    SparseTensorStorage<P, I, V> *n = nullptr;
+    if (tensor) {
+      assert(tensor->getRank() == rank);
+      for (uint64_t r = 0; r < rank; r++)
+        assert(tensor->getSizes()[perm[r]] == sizes[r] || sizes[r] == 0);
+      tensor->sort(); // sort lexicographically
+      n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
+                                           tensor);
+      delete tensor;
+    } else {
+      std::vector<uint64_t> permsz(rank);
+      for (uint64_t r = 0; r < rank; r++)
+        permsz[perm[r]] = sizes[r];
+      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, tensor);
+    }
     return n;
   }
 
 private:
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
-  /// tensor in coordinate scheme. This method prepares the pointers and indices
-  /// arrays under the given per-rank dimension dense/sparse annotations.
+  /// tensor in coordinate scheme. This method prepares the pointers and
+  /// indices arrays under the given per-dimension dense/sparse annotations.
   void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
                uint64_t hi, uint64_t d) {
     const std::vector<Element<V>> &elements = tensor->getElements();
@@ -342,7 +360,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   }
 
 private:
-  std::vector<uint64_t> sizes; // per-rank dimension sizes
+  std::vector<uint64_t> sizes; // per-dimension sizes
   std::vector<uint64_t> rev;   // "reverse" permutation
   std::vector<std::vector<P>> pointers;
   std::vector<std::vector<I>> indices;
@@ -429,7 +447,7 @@ static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
 /// Reads a sparse tensor with the given filename into a memory-resident
 /// sparse tensor in coordinate scheme.
 template <typename V>
-static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t size,
+static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm) {
   // Open the file.
@@ -448,20 +466,20 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t size,
     fprintf(stderr, "Unknown format %s\n", filename);
     exit(1);
   }
-  // Prepare sparse tensor object with per-rank dimension sizes
+  // Prepare sparse tensor object with per-dimension sizes
   // and the number of nonzeros as initial capacity.
-  assert(size == idata[0] && "rank mismatch");
+  assert(rank == idata[0] && "rank mismatch");
   uint64_t nnz = idata[1];
-  for (uint64_t r = 0; r < size; r++)
+  for (uint64_t r = 0; r < rank; r++)
     assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
            "dimension size mismatch");
   SparseTensorCOO<V> *tensor =
-      SparseTensorCOO<V>::newSparseTensorCOO(size, idata + 2, perm, nnz);
+      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
   //  Read all nonzero elements.
-  std::vector<uint64_t> indices(size);
+  std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
     uint64_t idx = -1;
-    for (uint64_t r = 0; r < size; r++) {
+    for (uint64_t r = 0; r < rank; r++) {
       if (fscanf(file, "%" PRIu64, &idx) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
         exit(1);
@@ -518,24 +536,30 @@ enum PrimaryTypeEnum : uint64_t {
   kI8 = 6
 };
 
-enum Action : uint32_t { kFromFile = 0, kFromCOO = 1, kNewCOO = 2, kToCOO = 3 };
+enum Action : uint32_t {
+  kEmpty = 0,
+  kFromFile = 1,
+  kFromCOO = 2,
+  kEmptyCOO = 3,
+  kToCOO = 4
+};
 
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
     SparseTensorCOO<V> *tensor = nullptr;                                      \
     if (action == kFromFile)                                                   \
       tensor =                                                                 \
-          openSparseTensorCOO<V>(static_cast<char *>(ptr), size, sizes, perm); \
+          openSparseTensorCOO<V>(static_cast<char *>(ptr), rank, sizes, perm); \
     else if (action == kFromCOO)                                               \
       tensor = static_cast<SparseTensorCOO<V> *>(ptr);                         \
-    else if (action == kNewCOO)                                                \
-      return SparseTensorCOO<V>::newSparseTensorCOO(size, sizes, perm);        \
+    else if (action == kEmptyCOO)                                              \
+      return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
     else if (action == kToCOO)                                                 \
       return static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);    \
     else                                                                       \
-      assert(0);                                                               \
-    return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity,     \
-                                                         perm);                \
+      assert(action == kEmpty);                                                \
+    return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,    \
+                                                         sparsity, tensor);    \
   }
 
 #define IMPL1(NAME, TYPE, LIB)                                                 \
@@ -586,9 +610,10 @@ enum Action : uint32_t { kFromFile = 0, kFromCOO = 1, kNewCOO = 2, kToCOO = 3 };
 /// method for materializing sparse tensors into the computation.
 ///
 /// action:
-/// kFromFile = ptr contains filename to read into storage
-/// kFromCOO = ptr contains coordinate scheme to assign to new storage
-/// kNewCOO = returns empty coordinate scheme to fill and use with kFromCOO
+/// kEmpty = returns empty storage to fill later
+/// kFromFile = returns storage, where ptr contains filename to read
+/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
+/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
 void *
 _mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
@@ -603,7 +628,7 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
   const uint8_t *sparsity = aref->data + aref->offset;
   const uint64_t *sizes = sref->data + sref->offset;
   const uint64_t *perm = pref->data + pref->offset;
-  uint64_t size = aref->sizes[0];
+  uint64_t rank = aref->sizes[0];
 
   // Double matrices with all combinations of overhead storage.
   CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
@@ -743,14 +768,14 @@ void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
       rank, shape, perm.data(), nse);
   std::vector<uint64_t> idx(rank);
   for (uint64_t i = 0, base = 0; i < nse; i++) {
-    for (uint64_t j = 0; j < rank; j++)
-      idx[j] = indices[base + j];
+    for (uint64_t r = 0; r < rank; r++)
+      idx[r] = indices[base + r];
     tensor->add(idx, values[i]);
     base += rank;
   }
   // Return sparse tensor storage format as opaque pointer.
   return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
-      tensor, sparse.data(), perm.data());
+      rank, shape, perm.data(), sparse.data(), tensor);
 }
 
 } // extern "C"

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 771aedb6dc1f..d6e43079d8c0 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -69,12 +69,12 @@ func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> index {
 
 // CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//   CHECK-DAG: %[[U:.*]] = arith.constant dense<1> : tensor<1xi8>
-//   CHECK-DAG: %[[V:.*]] = arith.constant dense<128> : tensor<1xi64>
-//   CHECK-DAG: %[[W:.*]] = arith.constant dense<0> : tensor<1xi64>
-//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<1xi8> to tensor<?xi8>
-//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<1xi64> to tensor<?xi64>
-//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<1xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xi64> to memref<?xi64>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
@@ -84,12 +84,12 @@ func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
 
 // CHECK-LABEL: func @sparse_new2d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//   CHECK-DAG: %[[U:.*]] = arith.constant dense<[0, 1]> : tensor<2xi8>
-//   CHECK-DAG: %[[V:.*]] = arith.constant dense<0> : tensor<2xi64>
-//   CHECK-DAG: %[[W:.*]] = arith.constant dense<[0, 1]> : tensor<2xi64>
-//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
-//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
-//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
@@ -99,12 +99,12 @@ func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
 
 // CHECK-LABEL: func @sparse_new3d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//   CHECK-DAG: %[[U:.*]] = arith.constant dense<[0, 1, 1]> : tensor<3xi8>
-//   CHECK-DAG: %[[V:.*]] = arith.constant dense<0> : tensor<3xi64>
-//   CHECK-DAG: %[[W:.*]] = arith.constant dense<[1, 2, 0]> : tensor<3xi64>
-//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
-//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
-//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<3xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<3xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<3xi64> to memref<?xi64>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
@@ -112,6 +112,29 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   return %0 : tensor<?x?x?xf32, #SparseTensor>
 }
 
+// CHECK-LABEL: func @sparse_init(
+//  CHECK-SAME: %[[I:.*]]: index,
+//  CHECK-SAME: %[[J:.*]]: index) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
+//   CHECK-DAG: %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+//   CHECK-DAG: %[[JJ:.*]] = arith.index_cast %[[J]] : index to i64
+//   CHECK-DAG: memref.store %[[II]], %[[Q]][%[[C0]]] : memref<2xi64>
+//   CHECK-DAG: memref.store %[[JJ]], %[[Q]][%[[C1]]] : memref<2xi64>
+//       CHECK: %[[A:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
+  %0 = sparse_tensor.init [%arg0, %arg1] : tensor<?x?xf64, #SparseMatrix>
+  return %0 : tensor<?x?xf64, #SparseMatrix>
+}
+
 // CHECK-LABEL: func @sparse_release(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: call @delSparseTensor(%[[A]]) : (!llvm.ptr<i8>) -> ()
@@ -133,20 +156,22 @@ func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32,
 //  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[D0:.*]] = arith.constant dense<0> : tensor<1xi64>
-//   CHECK-DAG: %[[D1:.*]] = arith.constant dense<1> : tensor<1xi8>
-//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[D1]] : tensor<1xi8> to tensor<?xi8>
-//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[D0]] : tensor<1xi64> to tensor<?xi64>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xi64> to memref<?xi64>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
 //       CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
 //       CHECK:   memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
-//       CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Y]])
+//       CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
@@ -167,12 +192,12 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 //  CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[U:.*]] = arith.constant dense<[0, 1]> : tensor<2xi8>
-//   CHECK-DAG: %[[V:.*]] = arith.constant dense<[2, 4]> : tensor<2xi64>
-//   CHECK-DAG: %[[W:.*]] = arith.constant dense<[0, 1]> : tensor<2xi64>
-//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
-//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
-//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
 //       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
@@ -184,50 +209,40 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 //       CHECK:     call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
 //       CHECK:   }
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
   return %0 : tensor<2x4xf64, #SparseMatrix>
 }
 
-#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
-
-// CHECK-LABEL:   func @entry() -> !llvm.ptr<i8> {
-// CHECK:           %[[C1:.*]] = arith.constant 1 : i32
-// CHECK:           %[[Offset:.*]] = arith.constant dense<[0, 1]> : tensor<2xi64>
-// CHECK:           %[[Dims:.*]] = arith.constant dense<[8, 7]> : tensor<2xi64>
-// CHECK:           %[[Base:.*]] = arith.constant dense<[0, 1]> : tensor<2xi8>
-// CHECK:           %[[I2:.*]] = arith.constant 2 : index
-// CHECK:           %[[SparseV:.*]] = arith.constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
-// CHECK:           %[[SparseI:.*]] = arith.constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
-// CHECK:           %[[I1:.*]] = arith.constant 1 : index
-// CHECK:           %[[I0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C2:.*]] = arith.constant 2 : i32
-// CHECK:           %[[BaseD:.*]] = tensor.cast %[[Base]] : tensor<2xi8> to tensor<?xi8>
-// CHECK:           %[[DimsD:.*]] = tensor.cast %[[Dims]] : tensor<2xi64> to tensor<?xi64>
-// CHECK:           %[[OffsetD:.*]] = tensor.cast %[[Offset]] : tensor<2xi64> to tensor<?xi64>
-// CHECK:           %[[TCOO:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[C2]], %{{.}})
-// CHECK:           %[[Index:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:           %[[IndexD:.*]] = memref.cast %[[Index]] : memref<2xindex> to memref<?xindex>
-// CHECK:           scf.for %[[IV:.*]] = %[[I0]] to %[[I2]] step %[[I1]] {
-// CHECK:             %[[VAL0:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I0]]] : tensor<2x2xi64>
-// CHECK:             %[[VAL1:.*]] = arith.index_cast %[[VAL0]] : i64 to index
-// CHECK:             memref.store %[[VAL1]], %[[Index]]{{\[}}%[[I0]]] : memref<2xindex>
-// CHECK:             %[[VAL2:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I1]]] : tensor<2x2xi64>
-// CHECK:             %[[VAL3:.*]] = arith.index_cast %[[VAL2]] : i64 to index
-// CHECK:             memref.store %[[VAL3]], %[[Index]]{{\[}}%[[I1]]] : memref<2xindex>
-// CHECK:             %[[VAL4:.*]] = tensor.extract %[[SparseV]]{{\[}}%[[IV]]] : tensor<2xf32>
-// CHECK:             call @addEltF32(%[[TCOO]], %[[VAL4]], %[[IndexD]], %[[OffsetD]])
-// CHECK:           }
-// CHECK:           %[[T:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %[[C1]], %{{.*}})
-// CHECK:           return %[[T]] : !llvm.ptr<i8>
-func @entry() -> tensor<8x7xf32, #CSR>{
+// CHECK-LABEL: func @sparse_constant() -> !llvm.ptr<i8> {
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
+//       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
+//       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+//       CHECK:   memref.store %{{.*}}, %[[M]][%[[C0]]] : memref<2xindex>
+//       CHECK:   memref.store %{{.*}}, %[[M]][%[[C1]]] : memref<2xindex>
+//       CHECK:   %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<2xf32>
+//       CHECK:   call @addEltF32(%{{.*}}, %[[V]], %[[N]], %{{.*}})
+//       CHECK: }
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
   // Initialize a tensor.
   %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
   // Convert the tensor to a sparse tensor.
-  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #CSR>
-  return %1 : tensor<8x7xf32, #CSR>
+  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #SparseMatrix>
+  return %1 : tensor<8x7xf32, #SparseMatrix>
 }
 
 // CHECK-LABEL: func @sparse_convert_3d(
@@ -235,15 +250,15 @@ func @entry() -> tensor<8x7xf32, #CSR>{
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[U:.*]] = arith.constant dense<[0, 1, 1]> : tensor<3xi8>
-//   CHECK-DAG: %[[V:.*]] = arith.constant dense<0> : tensor<3xi64>
-//   CHECK-DAG: %[[W:.*]] = arith.constant dense<[1, 2, 0]> : tensor<3xi64>
-//   CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
-//   CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
-//   CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<3xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<3xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<3xi64> to memref<?xi64>
 //       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
-//       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
 //       CHECK: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
 //       CHECK: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
@@ -254,11 +269,11 @@ func @entry() -> tensor<8x7xf32, #CSR>{
 //       CHECK:       memref.store %[[I]], %[[M]][%[[C0]]] : memref<3xindex>
 //       CHECK:       memref.store %[[J]], %[[M]][%[[C1]]] : memref<3xindex>
 //       CHECK:       memref.store %[[K]], %[[M]][%[[C2]]] : memref<3xindex>
-//       CHECK:       call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:       call @addEltF64(%[[C]], %[[E]], %[[N]], %[[Z]])
 //       CHECK:     }
 //       CHECK:   }
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>


        


More information about the Mlir-commits mailing list