[Mlir-commits] [mlir] f66e576 - [mlir][sparse] first version of "truly" dynamic sparse tensors as outputs of kernels

Aart Bik llvmlistbot at llvm.org
Mon Nov 15 15:33:40 PST 2021


Author: Aart Bik
Date: 2021-11-15T15:33:32-08:00
New Revision: f66e5769d41b436176f87a08279feec5163c32f3

URL: https://github.com/llvm/llvm-project/commit/f66e5769d41b436176f87a08279feec5163c32f3
DIFF: https://github.com/llvm/llvm-project/commit/f66e5769d41b436176f87a08279feec5163c32f3.diff

LOG: [mlir][sparse] first version of "truly" dynamic sparse tensors as outputs of kernels

This revision contains all "sparsification" ops and rewriting necessary to support sparse output tensors when the kernel has no reduction (viz. insertions occur in lexicographic order and are "injective"). This will later be generalized to allow reductions too. Also, this first revision only supports sparse 1-d tensors (viz. vectors) as output in the runtime support library; this will be generalized to n-d tensors shortly. But this way, the revision is kept to a manageable size.
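
For illustration, the kind of kernel that is now supported end-to-end looks as follows (condensed from the vector_scale kernel in the integration test added below, with the indexing trait written inline):

```mlir
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

// x(i) = a(i) * 2.0, computed into a "truly dynamic" sparse vector output.
func @vector_scale(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
  %s = arith.constant 2.0 : f64
  %c0 = arith.constant 0 : index
  %d = tensor.dim %arga, %c0 : tensor<?xf64, #SparseVector>
  %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
  %0 = linalg.generic {
         indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
         iterator_types = ["parallel"]
       }
       ins(%arga: tensor<?xf64, #SparseVector>)
      outs(%xv: tensor<?xf64, #SparseVector>) {
      ^bb(%a: f64, %x: f64):
        %1 = arith.mulf %a, %s : f64
        linalg.yield %1 : f64
  } -> tensor<?xf64, #SparseVector>
  return %0 : tensor<?xf64, #SparseVector>
}
```

The output vector materializes through `sparse_tensor.init` and is filled by the generated code through the new insertion operations described in the diff below.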

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D113705

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/dense.mlir
    mlir/test/Dialect/SparseTensor/fold.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_out.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c4682a16a0eff..4cbd3682de7a9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -25,10 +25,10 @@ class SparseTensor_Op<string mnemonic, list<OpTrait> traits = []>
 }
 
 //===----------------------------------------------------------------------===//
-// Operations.
+// Sparse Tensor Operations.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_NewOp : SparseTensor_Op<"new", []>,
+def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     Arguments<(ins AnyType:$source)>,
     Results<(outs TensorOf<[AnyType]>:$result)> {
   string summary = "Materializes a new sparse tensor from given source";
@@ -51,15 +51,15 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", []>,
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
 }
 
-def SparseTensor_InitOp : SparseTensor_Op<"init", []>,
+def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
     Arguments<(ins Variadic<Index>:$sizes)>,
     Results<(outs AnyTensor:$result)> {
-  string summary = "Materializes an empty sparse tensor";
+  string summary = "Materializes an unitialized sparse tensor";
   string description = [{
-    Materializes an empty sparse tensor with given shape (either static or dynamic).
-    The operation is provided as an anchor that materializes a properly typed sparse
-    tensor into the output clause of a subsequent operation that yields a sparse tensor
-    as the result.
+    Materializes an uninitialized sparse tensor with given shape (either static
+    or dynamic). The operation is provided as an anchor that materializes a
+    properly typed but uninitialized sparse tensor into the output clause of
+    a subsequent operation that yields a sparse tensor as the result.
 
     Example:
 
@@ -114,31 +114,12 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   let hasFolder = 1;
 }
 
-def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
-    Arguments<(ins AnyTensor:$tensor)> {
-  string description = [{
-    Releases the underlying sparse storage scheme for a tensor that
-    materialized earlier through a `new` operator, `init` operator, or a
-    non-trivial `convert` operator with an annotated tensor type as destination.
-    This operation should only be called once for any materialized tensor.
-    Also, after this operation, any subsequent `memref` querying operation
-    on the tensor returns undefined results.
-
-    Example:
-
-    ```mlir
-    sparse_tensor.release %tensor : tensor<1024x1024xf64, #CSR>
-    ```
-  }];
-  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
-}
-
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
-  let summary = "Extract pointers array at given dimension from a tensor";
+  let summary = "Extracts pointers array at given dimension from a tensor";
   let description = [{
-    Returns the pointers array of the sparse storage scheme at the
+    Returns the pointers array of the sparse storage format at the
     given dimension for the given sparse tensor. This is similar to the
     `memref.buffer_cast` operation in the sense that it provides a bridge
     between a tensor world view and a bufferized world view. Unlike the
@@ -160,9 +141,9 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
 def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
-  let summary = "Extract indices array at given dimension from a tensor";
+  let summary = "Extracts indices array at given dimension from a tensor";
   let description = [{
-    Returns the indices array of the sparse storage scheme at the
+    Returns the indices array of the sparse storage format at the
     given dimension for the given sparse tensor. This is similar to the
     `memref.buffer_cast` operation in the sense that it provides a bridge
     between a tensor world view and a bufferized world view. Unlike the
@@ -184,9 +165,9 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
-  let summary = "Extract numerical values array from a tensor";
+  let summary = "Extracts numerical values array from a tensor";
   let description = [{
-    Returns the values array of the sparse storage scheme for the given
+    Returns the values array of the sparse storage format for the given
     sparse tensor, independent of the actual dimension. This is similar to
     the `memref.buffer_cast` operation in the sense that it provides a bridge
     between a tensor world view and a bufferized world view. Unlike the
@@ -203,33 +184,94 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
 }
 
-def SparseTensor_ToTensorOp : SparseTensor_Op<"tensor", [NoSideEffect]>,
-    Arguments<(ins Variadic<AnyStridedMemRefOfRank<1>>:$memrefs)>,
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Management Operations. These operations are "impure" in the
+// sense that they do not properly operate on SSA values. Instead, the behavior
+// is solely defined by side-effects. These operations provide a bridge between
+// the code generator and the support library. The semantics of these operations
+// may be refined over time as our sparse abstractions evolve.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
+    Arguments<(ins AnyTensor:$tensor,
+               StridedMemRefRankOf<[Index], [1]>:$indices,
+               AnyType:$value)> {
+  string summary = "Inserts a value into given sparse tensor in lexicograph index order";
+  string description = [{
+    Inserts the given value at the given indices into the underlying sparse
+    storage format of the given tensor. This operation can only be applied
+    when the tensor materializes uninitialized through an `init` operation,
+    the insertions occur in strict lexicographic index order, and the final
+    tensor is constructed with a `load` operation that has the `hasInserts`
+    attribute set.
+
+    Note that this operation is "impure" in the sense that its behavior
+    is solely defined by side-effects and not SSA values. The semantics
+    may be refined over time as our sparse abstractions evolve.
+
+    ```mlir
+    sparse_tensor.lex_insert %tensor, %indices, %val
+      : tensor<1024x1024xf64, #CSR>, memref<?xindex>, f64
+    ```
+  }];
+  let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
+                       " type($tensor) `,` type($indices) `,` type($value)";
+}
+
+def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
+    Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
     Results<(outs AnyTensor:$result)> {
-  let summary = "Rematerializes tensor from arrays(s)";
+  let summary =
+    "Rematerializes tensor from underlying sparse storage format";
   let description = [{
-    Rematerializes the sparse tensor from the sparse storage scheme array(s).
-    This is similar to the `memref.load` operation in the sense that it
-    provides a bridge between a bufferized world view and a tensor world
-    view. Unlike the `memref.load` operation, however, this sparse operation
-    is used only temporarily to maintain a correctly typed intermediate
-    representation during progressive bufferization. Eventually the operation
-    is folded away.
-
-    The input arrays are defined unambigously by the sparsity annotations
-    (pointers and indices for overhead storage in every compressed dimension,
-    followed by one final values array).
+    Rematerializes a tensor from the underlying sparse storage format of the
+    given tensor. This is similar to the `memref.load` operation in the sense
+    that it provides a bridge between a bufferized world view and a tensor
+    world view. Unlike the `memref.load` operation, however, this sparse
+    operation is used only temporarily to maintain a correctly typed
+    intermediate representation during progressive bufferization.
+
+    The `hasInserts` attribute denotes whether insertions into the underlying
+    sparse storage format may have occurred, in which case that underlying
+    sparse storage format needs to be finalized. Otherwise, the operation
+    simply folds away.
+
+    Note that this operation is "impure" in the sense that its behavior
+    is solely defined by side-effects and not SSA values. The semantics
+    may be refined over time as our sparse abstractions evolve.
 
-    Examples:
+    Example:
 
     ```mlir
-    %1 = sparse_tensor.tensor %0 : memref<?xf64> to tensor<64x64xf64, #Dense>
+    %1 = sparse_tensor.load %0 : tensor<8xf64, #SV>
+    ```
+  }];
+  let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
+}
+
+def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
+    Arguments<(ins AnyTensor:$tensor)> {
+  string summary = "Releases underlying sparse storage format of given tensor";
+  string description = [{
+    Releases the underlying sparse storage format for a tensor that
+    materialized earlier through a `new` operator, `init` operator, or a
+    `convert` operator with an annotated tensor type as destination (unless
+    that convert is folded away since the source and destination types were
+    identical). This operation should only be called once for any materialized
+    tensor.  Also, after this operation, any subsequent `memref` querying
+    operation on the tensor returns undefined results.
+
+    Note that this operation is "impure" in the sense that its behavior
+    is solely defined by side-effects and not SSA values. The semantics
+    may be refined over time as our sparse abstractions evolve.
 
-    %3 = sparse_tensor.tensor %0, %1, %2 :
-       memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<10x10xf32, #CSR>
+    Example:
+
+    ```mlir
+    sparse_tensor.release %tensor : tensor<1024x1024xf64, #CSR>
     ```
   }];
-  let assemblyFormat = "$memrefs attr-dict `:` type($memrefs) `to` type($result)";
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
 }
 
 #endif // SPARSETENSOR_OPS
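
As a sketch of how these management operations compose (an illustrative hand-written example, not code produced or tested by this revision; the function and value names are made up), a loop that fills a sparse vector in strict lexicographic index order looks roughly like this:

```mlir
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

// Fills a sparse vector of length %d with the constant 1.0, inserting one
// entry per index in increasing (hence lexicographic) order.
func @fill(%d: index) -> tensor<?xf64, #SparseVector> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %f1 = arith.constant 1.0 : f64
  %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
  // One index slot per output dimension (here: rank 1).
  %lex = memref.alloca(%c1) : memref<?xindex>
  scf.for %i = %c0 to %d step %c1 {
    memref.store %i, %lex[%c0] : memref<?xindex>
    sparse_tensor.lex_insert %xv, %lex, %f1
      : tensor<?xf64, #SparseVector>, memref<?xindex>, f64
  }
  %result = sparse_tensor.load %xv hasInserts : tensor<?xf64, #SparseVector>
  return %result : tensor<?xf64, #SparseVector>
}
```

The sparsifier changes further down emit essentially this pattern for an admissible kernel: genBuffers allocates the lexIdx buffer, genLocals stores the current loop index into it, genTensorStore emits the `lex_insert`, and genResult rematerializes the result through `load` with `hasInserts` set.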

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index bd8018155bf8a..6cbb7f5dc7b21 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -263,12 +263,6 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
-static LogicalResult verify(ReleaseOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor to release");
-  return success();
-}
-
 static LogicalResult verify(ToPointersOp op) {
   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
     if (failed(isInBounds(op.dim(), op.tensor())))
@@ -301,9 +295,25 @@ static LogicalResult verify(ToValuesOp op) {
   return success();
 }
 
-static LogicalResult verify(ToTensorOp op) {
-  if (!getSparseTensorEncoding(op.result().getType()))
-    return op.emitError("expected a sparse tensor result");
+//===----------------------------------------------------------------------===//
+// TensorDialect Management Operations.
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(LexInsertOp op) {
+  if (!getSparseTensorEncoding(op.tensor().getType()))
+    return op.emitError("expected a sparse tensor for insertion");
+  return success();
+}
+
+static LogicalResult verify(LoadOp op) {
+  if (!getSparseTensorEncoding(op.tensor().getType()))
+    return op.emitError("expected a sparse tensor to materialize");
+  return success();
+}
+
+static LogicalResult verify(ReleaseOp op) {
+  if (!getSparseTensorEncoding(op.tensor().getType()))
+    return op.emitError("expected a sparse tensor to release");
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 77e0ff16ba28a..3633ff02d83fa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -839,32 +839,53 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
   }
 };
 
-/// Sparse conversion rule for tensor reconstruction.
-class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
+/// Sparse conversion rule for tensor rematerialization.
+class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  // Simply fold the operator into the pointer to the sparse storage scheme.
-  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Check that all arguments of the tensor reconstruction operators are calls
-    // into the support library that query exactly the same opaque pointer.
-    Value ptr;
-    for (Value op : adaptor.getOperands()) {
-      if (auto call = op.getDefiningOp<CallOp>()) {
-        Value arg = call.getOperand(0);
-        if (!arg.getType().isa<LLVM::LLVMPointerType>())
-          return failure();
-        if (!ptr)
-          ptr = arg;
-        else if (arg != ptr)
-          return failure();
-      }
+    if (op.hasInserts()) {
+      // Finalize any pending insertions.
+      StringRef name = "endInsert";
+      TypeRange noTp;
+      auto fn = getFunc(op, name, noTp, adaptor.getOperands());
+      rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
     }
-    // If a single opaque pointer is found, perform the folding.
-    if (!ptr)
-      return failure();
-    rewriter.replaceOp(op, ptr);
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Sparse conversion rule for inserting in lexicographic index order.
+class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = op.tensor().getType();
+    Type eltType = srcType.cast<ShapedType>().getElementType();
+    StringRef name;
+    if (eltType.isF64())
+      name = "lexInsertF64";
+    else if (eltType.isF32())
+      name = "lexInsertF32";
+    else if (eltType.isInteger(64))
+      name = "lexInsertI64";
+    else if (eltType.isInteger(32))
+      name = "lexInsertI32";
+    else if (eltType.isInteger(16))
+      name = "lexInsertI16";
+    else if (eltType.isInteger(8))
+      name = "lexInsertI8";
+    else
+      llvm_unreachable("Unknown element type");
+    TypeRange noTp;
+    auto fn =
+        getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
+    rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
     return success();
   }
 };
@@ -884,6 +905,6 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                SparseTensorInitConverter, SparseTensorConvertConverter,
                SparseTensorReleaseConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorToTensorConverter>(typeConverter,
-                                              patterns.getContext());
+               SparseTensorLoadConverter, SparseTensorLexInsertConverter>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 8dda6c991d5cc..cdcbfc2a54adc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -44,14 +44,16 @@ enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
 
 // Code generation.
 struct CodeGen {
-  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
+  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
+          OpOperand *op)
       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
         pointers(numTensors, std::vector<Value>(numLoops)),
         indices(numTensors, std::vector<Value>(numLoops)),
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
-        redKind(kNoReduc), curVecLength(1), curVecMask() {}
+        redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1),
+        curVecMask() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -76,6 +78,9 @@ struct CodeGen {
   unsigned redExp;
   Value redVal;
   Reduction redKind;
+  // Sparse tensor as output.
+  OpOperand *sparseOut;
+  Value lexIdx;
   // Current vector length and mask.
   unsigned curVecLength;
   Value curVecMask;
@@ -274,7 +279,7 @@ static bool isInPlace(Value val) {
   return false;
 }
 
-/// Returns true if tensor materializes into the computation.
+/// Returns true if tensor materializes uninitialized into the computation.
 static bool isMaterializing(Value val) {
   return val.getDefiningOp<linalg::InitTensorOp>() ||
          val.getDefiningOp<InitOp>();
@@ -283,8 +288,9 @@ static bool isMaterializing(Value val) {
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
 /// whether the output tensor in the tensor expression codegen is admissable.
+/// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs.
 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
-                                  unsigned exp) {
+                                  unsigned exp, OpOperand **sparseOut) {
   OpOperand *lhs = op.getOutputOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
@@ -307,10 +313,24 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
   // but not its nonzero structure, an operation called "simply dynamic" in
   // [Bik96,Ch9], is also admissable without special codegen, provided
   // the tensor's underlying sparse storage scheme can be modified in place.
-  if (merger.isConjunction(tensor, exp))
-    return isInPlace(lhs->get());
-  // Reject for now since this requires changes to the nonzero structure.
-  // TODO: implement "workspaces" [Kjolstad2019]
+  if (merger.isConjunction(tensor, exp) && isInPlace(lhs->get()))
+    return true;
+  // Accept "truly dynamic" if the output tensor materializes uninitialized
+  // into the computation and insertions occur in lexicographic index order.
+  if (isMaterializing(lhs->get())) {
+    // In this first sparse tensor output implementation, this is enforced by
+    // rejecting any reduction loops (since the sparse parallel loops give a
+    // lexicographically sorted and injective view into that tensor).
+    // TODO: generalize to include reductions
+    for (auto attr : op.iterator_types())
+      if (isReductionIterator(attr))
+        return false;
+    // TODO: generalize support lib beyond vectors
+    if (op.iterator_types().size() != 1)
+      return false;
+    *sparseOut = lhs;
+    return true;
+  }
   return false;
 }
 
@@ -517,6 +537,12 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
       else
         codegen.buffers[tensor] =
             genOutputBuffer(codegen, rewriter, op, denseTp, args);
+    } else if (t == codegen.sparseOut) {
+      // True sparse output needs a lexIdx array.
+      Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t));
+      auto dynShape = {ShapedType::kDynamicSize};
+      auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
+      codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
     } else {
       // Annotated sparse tensors.
       auto dynShape = {ShapedType::kDynamicSize};
@@ -691,22 +717,28 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
 static void genTensorStore(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Value rhs) {
+  Location loc = op.getLoc();
   // Test if this is a scalarized reduction.
   if (codegen.redVal) {
     if (codegen.curVecLength > 1)
-      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
+      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                       codegen.redVal);
     updateReduc(merger, codegen, rhs);
     return;
   }
+  // Insertion.
+  OpOperand *t = op.getOutputOperand(0);
+  if (t == codegen.sparseOut) {
+    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
+    return;
+  }
   // Actual store.
   SmallVector<Value, 4> args;
-  OpOperand *t = op.getOutputOperand(0);
   Value ptr = genSubscript(codegen, rewriter, op, t, args);
   if (codegen.curVecLength > 1)
     genVectorStore(codegen, rewriter, rhs, ptr, args);
   else
-    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
+    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
 }
 
 /// Generates a pointer/index load from the sparse storage scheme. Narrower
@@ -978,9 +1010,11 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
   auto iteratorTypes = op.iterator_types().getValue();
   bool isReduction = isReductionIterator(iteratorTypes[idx]);
   bool isSparse = merger.isDim(fb, Dim::kSparse);
-  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
+  bool isVector = !codegen.sparseOut &&
+                  isVectorFor(codegen, isInner, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
+      !codegen.sparseOut &&
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
   // Prepare vector length.
@@ -1162,6 +1196,13 @@ static void genLocals(Merger &merger, CodeGen &codegen,
           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
     }
   }
+
+  // Move the insertion indices in lexicographic index order.
+  if (codegen.sparseOut) {
+    Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at);
+    rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
+                                     pos);
+  }
 }
 
 /// Generates the induction structure for a while-loop.
@@ -1414,36 +1455,20 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Converts the result computed by the sparse kernel into the required form.
 static void genResult(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
-  Location loc = op.getLoc();
   OpOperand *lhs = op.getOutputOperand(0);
   Type resType = lhs->get().getType();
-  unsigned tensor = lhs->getOperandNumber();
-  auto map = op.getTiedIndexingMap(lhs);
-  auto enc = getSparseTensorEncoding(resType);
-  Value result = codegen.buffers.back(); // value array
-  if (enc) {
-    // The sparse annotation unambigiously defines the arrays needed
-    // to "reconstruct" the sparse tensor from the storage scheme
-    // (even though lowering should never need this eventually).
-    SmallVector<Value, 4> args;
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      AffineExpr a = map.getResult(perm(enc, d));
-      if (a.getKind() != AffineExprKind::DimId)
-        continue; // compound
-      unsigned idx = a.cast<AffineDimExpr>().getPosition();
-      if (merger.isDim(tensor, idx, Dim::kSparse)) {
-        args.push_back(codegen.pointers[tensor][idx]);
-        args.push_back(codegen.indices[tensor][idx]);
-      }
-    }
-    args.push_back(result);
-    result = rewriter.create<ToTensorOp>(loc, resType, args);
+  Value result;
+  if (getSparseTensorEncoding(resType)) {
+    // The sparse tensor rematerializes from the original sparse tensor's
+    // underlying sparse storage format.
+    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
+                                        codegen.sparseOut == lhs);
   } else {
-    // To "reconstruct" an non-annotated tensor, sipmly load it
+    // To rematerialize a non-annotated tensor, simply load it
     // from the bufferized value.
-    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
+    Value val = codegen.buffers.back(); // value array
+    rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, val);
   }
-  rewriter.replaceOp(op, result);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1489,11 +1514,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     unsigned exp = optExp.getValue();
 
     // Rejects an inadmissable tensor expression.
-    if (!isAdmissableTensorExp(merger, op, exp))
+    OpOperand *sparseOut = nullptr;
+    if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
       return failure();
 
     // Recursively generates code.
-    CodeGen codegen(options, numTensors, numLoops);
+    CodeGen codegen(options, numTensors, numLoops, sparseOut);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
     genResult(merger, codegen, rewriter, op);
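
For contrast with the admissibility rules in isAdmissableTensorExp above, a hypothetical matrix-times-vector kernel along the following lines (encodings, trait, and names invented purely for illustration) is still rejected in this first version, even though its output materializes through `init`: the `j` loop is a reduction, and the kernel has more than one loop, which the vector-only support library path cannot handle yet.

```mlir
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

#trait_matvec = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A (in)
    affine_map<(i,j) -> (j)>,    // b (in)
    affine_map<(i,j) -> (i)>     // x (out)
  ],
  iterator_types = ["parallel", "reduction"],
  doc = "x(i) += A(i,j) * b(j)"
}

// Rejected for now: the j-loop is a reduction, so insertions into %xv would
// not arrive in a strictly lexicographic and injective order.
func @matvec(%argA: tensor<?x?xf64, #SparseMatrix>,
             %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
  %c0 = arith.constant 0 : index
  %d = tensor.dim %argA, %c0 : tensor<?x?xf64, #SparseMatrix>
  %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
  %0 = linalg.generic #trait_matvec
     ins(%argA, %argb: tensor<?x?xf64, #SparseMatrix>, tensor<?xf64, #SparseVector>)
    outs(%xv: tensor<?xf64, #SparseVector>) {
    ^bb(%a: f64, %b: f64, %x: f64):
      %m = arith.mulf %a, %b : f64
      %r = arith.addf %x, %m : f64
      linalg.yield %r : f64
  } -> tensor<?xf64, #SparseVector>
  return %0 : tensor<?xf64, #SparseVector>
}
```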

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 52396d4ce6fcd..69e678ce3657b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -163,6 +163,7 @@ struct SparseTensorCOO {
 /// function overloading to implement "partial" method specialization.
 class SparseTensorStorageBase {
 public:
+  // Dimension size query.
   virtual uint64_t getDimSize(uint64_t) = 0;
 
   // Overhead storage.
@@ -183,6 +184,15 @@ class SparseTensorStorageBase {
   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
 
+  // Element-wise insertion in lexicographic index order.
+  virtual void lexInsert(uint64_t *, double) { fatal("insf64"); }
+  virtual void lexInsert(uint64_t *, float) { fatal("insf32"); }
+  virtual void lexInsert(uint64_t *, int64_t) { fatal("insi64"); }
+  virtual void lexInsert(uint64_t *, int32_t) { fatal("insi32"); }
+  virtual void lexInsert(uint64_t *, int16_t) { fatal("insi16"); }
+  virtual void lexInsert(uint64_t *, int8_t) { fatal("insi8"); }
+  virtual void endInsert() = 0;
+
   virtual ~SparseTensorStorageBase() {}
 
 private:
@@ -205,20 +215,25 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   /// permutation, and per-dimension dense/sparse annotations, using
   /// the coordinate scheme tensor for the initial contents if provided.
   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
-                      const DimLevelType *sparsity, SparseTensorCOO<V> *tensor)
-      : sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
+                      const DimLevelType *sparsity,
+                      SparseTensorCOO<V> *tensor = nullptr)
+      : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
+        indices(getRank()) {
     uint64_t rank = getRank();
     // Store "reverse" permutation.
     for (uint64_t r = 0; r < rank; r++)
       rev[perm[r]] = r;
     // Provide hints on capacity of pointers and indices.
     // TODO: needs fine-tuning based on sparsity
-    for (uint64_t r = 0, s = 1; r < rank; r++) {
-      s *= sizes[r];
+    bool allDense = true;
+    uint64_t sz = 1;
+    for (uint64_t r = 0; r < rank; r++) {
+      sz *= sizes[r];
       if (sparsity[r] == DimLevelType::kCompressed) {
-        pointers[r].reserve(s + 1);
-        indices[r].reserve(s);
-        s = 1;
+        pointers[r].reserve(sz + 1);
+        indices[r].reserve(sz);
+        sz = 1;
+        allDense = false;
       } else {
         assert(sparsity[r] == DimLevelType::kDense &&
                "singleton not yet supported");
@@ -233,6 +248,11 @@ class SparseTensorStorage : public SparseTensorStorageBase {
       uint64_t nnz = tensor->getElements().size();
       values.reserve(nnz);
       fromCOO(tensor, sparsity, 0, nnz, 0);
+    } else {
+      if (allDense)
+        values.resize(sz, 0);
+      for (uint64_t r = 0; r < rank; r++)
+        idx[r] = -1u;
     }
   }
 
@@ -247,7 +267,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     return sizes[d];
   }
 
-  // Partially specialize these three methods based on template types.
+  /// Partially specialize these getter methods based on template types.
   void getPointers(std::vector<P> **out, uint64_t d) override {
     assert(d < getRank());
     *out = &pointers[d];
@@ -258,6 +278,18 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   }
   void getValues(std::vector<V> **out) override { *out = &values; }
 
+  /// Partially specialize lexicographic insertions based on template types.
+  // TODO: 1-dim tensors only for now, generalize soon
+  void lexInsert(uint64_t *cursor, V val) override {
+    assert((idx[0] == -1u || idx[0] < cursor[0]) && "not lexicographic");
+    indices[0].push_back(cursor[0]);
+    values.push_back(val);
+    idx[0] = cursor[0];
+  }
+
+  /// Finalizes lexicographic insertions.
+  void endInsert() override { pointers[0].push_back(indices[0].size()); }
+
   /// Returns this sparse tensor storage scheme as a new memory-resident
   /// sparse tensor in coordinate scheme with the given dimension order.
   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
@@ -275,8 +307,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     std::vector<uint64_t> reord(rank);
     for (uint64_t r = 0; r < rank; r++)
       reord[r] = perm[rev[r]];
-    std::vector<uint64_t> idx(rank);
-    toCOO(tensor, reord, idx, 0, 0);
+    toCOO(tensor, reord, 0, 0);
     assert(tensor->getElements().size() == values.size());
     return tensor;
   }
@@ -302,7 +333,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
       std::vector<uint64_t> permsz(rank);
       for (uint64_t r = 0; r < rank; r++)
         permsz[perm[r]] = sizes[r];
-      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, tensor);
+      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
     }
     return n;
   }
@@ -315,29 +346,29 @@ class SparseTensorStorage : public SparseTensorStorageBase {
                uint64_t lo, uint64_t hi, uint64_t d) {
     const std::vector<Element<V>> &elements = tensor->getElements();
     // Once dimensions are exhausted, insert the numerical values.
+    assert(d <= getRank());
     if (d == getRank()) {
       assert(lo >= hi || lo < elements.size());
       values.push_back(lo < hi ? elements[lo].value : 0);
       return;
     }
-    assert(d < getRank());
     // Visit all elements in this interval.
     uint64_t full = 0;
     while (lo < hi) {
       assert(lo < elements.size() && hi <= elements.size());
       // Find segment in interval with same index elements in this dimension.
-      uint64_t idx = elements[lo].indices[d];
+      uint64_t i = elements[lo].indices[d];
       uint64_t seg = lo + 1;
-      while (seg < hi && elements[seg].indices[d] == idx)
+      while (seg < hi && elements[seg].indices[d] == i)
         seg++;
       // Handle segment in interval for sparse or dense dimension.
       if (sparsity[d] == DimLevelType::kCompressed) {
-        indices[d].push_back(idx);
+        indices[d].push_back(i);
       } else {
         // For dense storage we must fill in all the zero values between
         // the previous element (when last we ran this for-loop) and the
         // current element.
-        for (; full < idx; full++)
+        for (; full < i; full++)
           fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
         full++;
       }
@@ -359,7 +390,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   /// Stores the sparse tensor storage scheme into a memory-resident sparse
   /// tensor in coordinate scheme.
   void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
-             std::vector<uint64_t> &idx, uint64_t pos, uint64_t d) {
+             uint64_t pos, uint64_t d) {
     assert(d <= getRank());
     if (d == getRank()) {
       assert(pos < values.size());
@@ -368,13 +399,13 @@ class SparseTensorStorage : public SparseTensorStorageBase {
       // Dense dimension.
       for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
         idx[reord[d]] = i;
-        toCOO(tensor, reord, idx, off + i, d + 1);
+        toCOO(tensor, reord, off + i, d + 1);
       }
     } else {
       // Sparse dimension.
       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
         idx[reord[d]] = indices[d][ii];
-        toCOO(tensor, reord, idx, ii, d + 1);
+        toCOO(tensor, reord, ii, d + 1);
       }
     }
   }
@@ -382,6 +413,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
 private:
   std::vector<uint64_t> sizes; // per-dimension sizes
   std::vector<uint64_t> rev;   // "reverse" permutation
+  std::vector<uint64_t> idx;   // index cursor
   std::vector<std::vector<P>> pointers;
   std::vector<std::vector<I>> indices;
   std::vector<V> values;
@@ -498,7 +530,7 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
   //  Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
-    uint64_t idx = -1;
+    uint64_t idx = -1u;
     for (uint64_t r = 0; r < rank; r++) {
       if (fscanf(file, "%" PRIu64, &idx) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
@@ -635,6 +667,15 @@ typedef uint64_t index_t;
     return true;                                                               \
   }
 
+#define IMPL_LEXINSERT(NAME, V)                                                \
+  void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
+                           V val) {                                            \
+    assert(cref->strides[0] == 1);                                             \
+    uint64_t *cursor = cref->data + cref->offset;                              \
+    assert(cursor);                                                            \
+    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
+  }
+
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
@@ -786,11 +827,20 @@ IMPL_GETNEXT(getNextI32, int32_t)
 IMPL_GETNEXT(getNextI16, int16_t)
 IMPL_GETNEXT(getNextI8, int8_t)
 
+/// Helper to insert elements in lexicographic index order, one per value type.
+IMPL_LEXINSERT(lexInsertF64, double)
+IMPL_LEXINSERT(lexInsertF32, float)
+IMPL_LEXINSERT(lexInsertI64, int64_t)
+IMPL_LEXINSERT(lexInsertI32, int32_t)
+IMPL_LEXINSERT(lexInsertI16, int16_t)
+IMPL_LEXINSERT(lexInsertI8, int8_t)
+
 #undef CASE
 #undef IMPL_SPARSEVALUES
 #undef IMPL_GETOVERHEAD
 #undef IMPL_ADDELT
 #undef IMPL_GETNEXT
+#undef IMPL_LEXINSERT
 
 //===----------------------------------------------------------------------===//
 //
@@ -815,6 +865,11 @@ index_t sparseDimSize(void *tensor, index_t d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
+/// Finalizes lexicographic insertions.
+void endInsert(void *tensor) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
+}
+
 /// Releases sparse tensor storage.
 void delSparseTensor(void *tensor) {
   delete static_cast<SparseTensorStorageBase *>(tensor);

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2d74f1db4e495..45f806b6f83af 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,9 +1,5 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
-#DenseVector = #sparse_tensor.encoding<{
-  dimLevelType = ["dense"]
-}>
-
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"]
 }>
@@ -415,23 +411,33 @@ func @sparse_valuesi8(%arg0: tensor<128xi8, #SparseVector>) -> memref<?xi8> {
   return %0 : memref<?xi8>
 }
 
-// CHECK-LABEL: func @sparse_reconstruct_1(
+// CHECK-LABEL: func @sparse_reconstruct(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
-func @sparse_reconstruct_1(%arg0: tensor<128xf32, #DenseVector> {linalg.inplaceable = true}) -> tensor<128xf32, #DenseVector> {
-  %0 = sparse_tensor.values %arg0 : tensor<128xf32, #DenseVector> to memref<?xf32>
-  %1 = sparse_tensor.tensor %0 : memref<?xf32> to tensor<128xf32, #DenseVector>
-  return %1 : tensor<128xf32, #DenseVector>
+func @sparse_reconstruct(%arg0: tensor<128xf32, #SparseVector>) -> tensor<128xf32, #SparseVector> {
+  %0 = sparse_tensor.load %arg0 : tensor<128xf32, #SparseVector>
+  return %0 : tensor<128xf32, #SparseVector>
 }
 
-// CHECK-LABEL: func @sparse_reconstruct_n(
+// CHECK-LABEL: func @sparse_reconstruct_ins(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
+//       CHECK: call @endInsert(%[[A]]) : (!llvm.ptr<i8>) -> ()
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
-func @sparse_reconstruct_n(%arg0: tensor<128xf32, #SparseVector> {linalg.inplaceable = true}) -> tensor<128xf32, #SparseVector> {
-  %c = arith.constant 0 : index
-  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf32, #SparseVector> to memref<?xindex>
-  %1 = sparse_tensor.indices %arg0, %c : tensor<128xf32, #SparseVector> to memref<?xindex>
-  %2 = sparse_tensor.values %arg0 : tensor<128xf32, #SparseVector> to memref<?xf32>
-  %3 = sparse_tensor.tensor %0, %1, %2 : memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<128xf32, #SparseVector>
-  return %3 : tensor<128xf32, #SparseVector>
+func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tensor<128xf32, #SparseVector> {
+  %0 = sparse_tensor.load %arg0 hasInserts : tensor<128xf32, #SparseVector>
+  return %0 : tensor<128xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_insert(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[C:.*]]: f32) {
+//       CHECK: call @lexInsertF32(%[[A]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, memref<?xindex>, f32) -> ()
+//       CHECK: return
+func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
+                    %arg1: memref<?xindex>,
+                    %arg2: f32) {
+  sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref<?xindex>, f32
+  return
 }
+

diff  --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 0ddfb8561de9a..f4aa51e592cc9 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -117,13 +117,13 @@ func @dense2(%arga: tensor<32x16xf32, #DenseMatrix>,
 // The rewriting would fail if argx was not in-placeable.
 //
 // CHECK-LABEL:   func @dense3(
-// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK:           %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16xf32>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
@@ -135,7 +135,7 @@ func @dense2(%arga: tensor<32x16xf32, #DenseMatrix>,
 // CHECK:               memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.tensor %[[VAL_8]] : memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_15]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
 func @dense3(%arga: tensor<32x16xf32>,
@@ -161,13 +161,13 @@ func @dense3(%arga: tensor<32x16xf32>,
 // for by scalarizing the reduction operation for the output tensor.
 //
 // CHECK-LABEL:   func @dense4(
-// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 8 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}}>> to memref<?xf32>
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
@@ -183,7 +183,7 @@ func @dense3(%arga: tensor<32x16xf32>,
 // CHECK:               memref.store %[[VAL_19:.*]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.tensor %[[VAL_8]] : memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_20]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
 func @dense4(%arga: tensor<32x16x8xf32>,

diff  --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 6e3c5ca90ecce..41189eee4271b 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s  --canonicalize --cse | FileCheck %s
 
-#DenseVector  = #sparse_tensor.encoding<{dimLevelType = ["dense"]}>
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_nop_convert(
-//  CHECK-SAME: %[[A:.*]]: tensor<64xf32, #{{.*}}>)
-//       CHECK: return %[[A]] : tensor<64xf32, #{{.*}}>
+//  CHECK-SAME: %[[A:.*]]: tensor<64xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+//   CHECK-NOT: sparse_tensor.convert
+//       CHECK: return %[[A]] : tensor<64xf32, #sparse_tensor.encoding<{{{.*}}}>>
 func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<64xf32, #SparseVector> to tensor<64xf32, #SparseVector>
   return %0 : tensor<64xf32, #SparseVector>
@@ -33,14 +33,3 @@ func @sparse_dce_getters(%arg0: tensor<64xf32, #SparseVector>) {
   %2 = sparse_tensor.values %arg0 : tensor<64xf32, #SparseVector> to memref<?xf32>
   return
 }
-
-// CHECK-LABEL: func @sparse_dce_reconstruct(
-//  CHECK-SAME: %[[A:.*]]: tensor<64xf32, #sparse_tensor.encoding<{{{.*}}}>>)
-//   CHECK-NOT: sparse_tensor.values
-//   CHECK-NOT: sparse_tensor.tensor
-//       CHECK: return
-func @sparse_dce_reconstruct(%arg0: tensor<64xf32, #DenseVector>) {
-  %0 = sparse_tensor.values %arg0 : tensor<64xf32, #DenseVector> to memref<?xf32>
-  %1 = sparse_tensor.tensor %0 : memref<?xf32> to tensor<64xf32, #DenseVector>
-  return
-}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 399cf3fba8d50..55c7cc490d741 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -144,14 +144,22 @@ func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xf32
 
 // -----
 
-func @sparse_to_unannotated_tensor(%arg0: memref<?xf64>) -> tensor<16x32xf64> {
-  // expected-error at +1 {{expected a sparse tensor result}}
-  %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64>
+func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
+  // expected-error at +1 {{expected a sparse tensor to materialize}}
+  %0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
 
 // -----
 
+func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
+  // expected-error at +1 {{expected a sparse tensor for insertion}}
+  sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
+  return
+}
+
+// -----
+
 func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
   // expected-error at +1 {{unexpected type in convert}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 507eb4e1de520..ad6b90b7918d4 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -113,11 +113,39 @@ func @sparse_values(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xf64> {
 
 #DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}>
 
-// CHECK-LABEL: func @sparse_to_tensor(
-//  CHECK-SAME: %[[A:.*]]: memref<?xf64>)
-//       CHECK: %[[T:.*]] = sparse_tensor.tensor %[[A]] : memref<?xf64> to tensor<16x32xf64, #{{.*}}>
+// CHECK-LABEL: func @sparse_load(
+//  CHECK-SAME: %[[A:.*]]: tensor<16x32xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.load %[[A]] : tensor<16x32xf64, #{{.*}}>
 //       CHECK: return %[[T]] : tensor<16x32xf64, #{{.*}}>
-func @sparse_to_tensor(%arg0: memref<?xf64>) -> tensor<16x32xf64, #DenseMatrix> {
-  %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64, #DenseMatrix>
+func @sparse_load(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x32xf64, #DenseMatrix> {
+  %0 = sparse_tensor.load %arg0 : tensor<16x32xf64, #DenseMatrix>
   return %0 : tensor<16x32xf64, #DenseMatrix>
 }
+
+// -----
+
+#DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}>
+
+// CHECK-LABEL: func @sparse_load_ins(
+//  CHECK-SAME: %[[A:.*]]: tensor<16x32xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.load %[[A]] hasInserts : tensor<16x32xf64, #{{.*}}>
+//       CHECK: return %[[T]] : tensor<16x32xf64, #{{.*}}>
+func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x32xf64, #DenseMatrix> {
+  %0 = sparse_tensor.load %arg0 hasInserts : tensor<16x32xf64, #DenseMatrix>
+  return %0 : tensor<16x32xf64, #DenseMatrix>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+// CHECK-LABEL: func @sparse_insert(
+//  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[C:.*]]: f64) {
+//       CHECK: sparse_tensor.lex_insert %[[A]], %[[B]], %[[C]] : tensor<128xf64, #{{.*}}>, memref<?xindex>, f64
+//       CHECK: return
+func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: memref<?xindex>, %arg2: f64) {
+  sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref<?xindex>, f64
+  return
+}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 5b2d5f3c632db..e17e3e89bef10 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -20,14 +20,12 @@
 }
 
 // CHECK-LABEL:   func @sparse_simply_dynamic1(
-// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK-DAG:           %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -41,7 +39,7 @@
 // CHECK:               memref.store %[[VAL_17]], %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.tensor %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_18]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
 func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
@@ -65,10 +63,10 @@ func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable
 }
 
 // CHECK-LABEL:   func @sparse_simply_dynamic2(
-// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
@@ -117,7 +115,7 @@ func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable
 // CHECK:               scf.yield %[[VAL_42]], %[[VAL_45]] : index, index
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_46:.*]] = sparse_tensor.tensor %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_46:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_46]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
 func @sparse_simply_dynamic2(%arga: tensor<32x16xf32, #CSR>,

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
new file mode 100644
index 0000000000000..c6b912b84f47f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --linalg-bufferize --convert-linalg-to-loops \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#DenseVector = #sparse_tensor.encoding<{dimLevelType = ["dense"]}>
+
+//
+// Traits for 1-d tensor (aka vector) operations.
+//
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * 2.0"
+}
+#trait_scale_inpl = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) *= 2.0"
+}
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+#trait_dot = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> ()>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) += a(i) * b(i)"
+}
+
+module {
+  // Scales a sparse vector into a new sparse vector.
+  func @vector_scale(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %s = arith.constant 2.0 : f64
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_scale
+       ins(%arga: tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %x: f64):
+          %1 = arith.mulf %a, %s : f64
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Scales a sparse vector in place.
+  func @vector_scale_inplace(%argx: tensor<?xf64, #SparseVector>
+                             {linalg.inplaceable = true}) -> tensor<?xf64, #SparseVector> {
+    %s = arith.constant 2.0 : f64
+    %0 = linalg.generic #trait_scale_inpl
+      outs(%argx: tensor<?xf64, #SparseVector>) {
+        ^bb(%x: f64):
+          %1 = arith.mulf %x, %s : f64
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Adds two sparse vectors into a new sparse vector.
+  func @vector_add(%arga: tensor<?xf64, #SparseVector>,
+                   %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = arith.addf %a, %b : f64
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Multiplies two sparse vectors into a new sparse vector.
+  func @vector_mul(%arga: tensor<?xf64, #SparseVector>,
+                   %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = arith.mulf %a, %b : f64
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  // Multiplies two sparse vectors into a new "annotated" dense vector.
+  func @vector_mul_d(%arga: tensor<?xf64, #SparseVector>,
+                     %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #DenseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #DenseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xf64, #DenseVector>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = arith.mulf %a, %b : f64
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #DenseVector>
+    return %0 : tensor<?xf64, #DenseVector>
+  }
+
+  // Sum-reduces the dot product of two sparse vectors.
+  func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
+                       %argb: tensor<?xf64, #SparseVector>,
+                       %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
+    %0 = linalg.generic #trait_dot
+       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+        outs(%argx: tensor<f64>) {
+        ^bb(%a: f64, %b: f64, %x: f64):
+          %1 = arith.mulf %a, %b : f64
+          %2 = arith.addf %x, %1 : f64
+          linalg.yield %2 : f64
+    } -> tensor<f64>
+    return %0 : tensor<f64>
+  }
+
+  // Dumps the values array and a dense copy of the sparse vector.
+  func @dump(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = memref.buffer_cast %dv : memref<?xf64>
+    %3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
+    vector.print %3 : vector<32xf64>
+    memref.dealloc %2 : memref<?xf64>
+    return
+  }
+
+  // Driver method to call and verify vector kernels.
+  func @entry() {
+    %c0 = arith.constant 0 : index
+    %d1 = arith.constant 1.1 : f64
+
+    // Set up the sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [3], [11], [17], [20], [21], [28], [29], [31] ],
+         [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+    > : tensor<32xf64>
+    %v2 = arith.constant sparse<
+       [ [1], [3], [4], [10], [16], [18], [21], [28], [29], [31] ],
+         [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0 ]
+    > : tensor<32xf64>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+
+    // Set up memory for a single reduction scalar.
+    %xdata = memref.alloc() : memref<f64>
+    memref.store %d1, %xdata[] : memref<f64>
+    %x = memref.tensor_load %xdata : memref<f64>
+
+    // Call sparse vector kernels.
+    %0 = call @vector_scale(%sv1)
+       : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %1 = call @vector_scale_inplace(%sv1)
+       : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %2 = call @vector_add(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %3 = call @vector_mul(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %4 = call @vector_mul_d(%sv1, %sv2)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #DenseVector>
+    %5 = call @vector_dotprod(%sv1, %sv2, %x)
+       : (tensor<?xf64, #SparseVector>,
+          tensor<?xf64, #SparseVector>, tensor<f64>) -> tensor<f64>
+
+    //
+    // Verify the results.
+    //
+    // CHECK:      ( 2, 4, 6, 8, 10, 12, 14, 16, 18, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 8, 0, 0, 10, 12, 0, 0, 0, 0, 0, 0, 14, 16, 0, 18 )
+    // CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 )
+    // CHECK-NEXT: ( 2, 4, 6, 8, 10, 12, 14, 16, 18, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 8, 0, 0, 10, 12, 0, 0, 0, 0, 0, 0, 14, 16, 0, 18 )
+    // CHECK-NEXT: ( 2, 4, 6, 8, 10, 12, 14, 16, 18, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 8, 0, 0, 10, 12, 0, 0, 0, 0, 0, 0, 14, 16, 0, 18 )
+    // CHECK-NEXT: ( 2, 11, 16, 13, 14, 6, 15, 8, 16, 10, 29, 32, 35, 38, -1, -1 )
+    // CHECK-NEXT: ( 2, 11, 0, 16, 13, 0, 0, 0, 0, 0, 14, 6, 0, 0, 0, 0, 15, 8, 16, 0, 10, 29, 0, 0, 0, 0, 0, 0, 32, 35, 0, 38 )
+    // CHECK-NEXT: ( 48, 204, 252, 304, 360, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 0, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 204, 0, 0, 0, 0, 0, 0, 252, 304, 0, 360 )
+    // CHECK-NEXT: ( 0, 0, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 204, 0, 0, 0, 0, 0, 0, 252, 304, 0, 360 )
+    // CHECK-NEXT: 1169.1
+    //
+    call @dump(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump(%sv2) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump(%0) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump(%1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump(%2) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump(%3) : (tensor<?xf64, #SparseVector>) -> ()
+    %m4 = sparse_tensor.values %4 : tensor<?xf64, #DenseVector> to memref<?xf64>
+    %v4 = vector.load %m4[%c0]: memref<?xf64>, vector<32xf64>
+    vector.print %v4 : vector<32xf64>
+    %m5 = memref.buffer_cast %5 : memref<f64>
+    %v5 = memref.load %m5[] : memref<f64>
+    vector.print %v5 : f64
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %sv2 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %2 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %4 : tensor<?xf64, #DenseVector>
+    memref.dealloc %xdata : memref<f64>
+    return
+  }
+}
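
A quick hand check of the expected output above (a side computation, not part of the test): the printed values arrays are padded with the -1.0 value passed to vector.transfer_read, so only the stored entries appear before the -1 tail. For the dot product, @vector_scale_inplace has already doubled %sv1 by the time @vector_dotprod runs, so the two inputs overlap at indices 3, 21, 28, 29, and 31, contributing 4*12 + 12*17 + 14*18 + 16*19 + 18*20 = 48 + 204 + 252 + 304 + 360 = 1168; adding the initial 1.1 stored in %x gives the expected 1169.1.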
