[Mlir-commits] [mlir] edca72f - [mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.

Peiming Liu llvmlistbot at llvm.org
Wed Sep 7 10:53:57 PDT 2022


Author: Peiming Liu
Date: 2022-09-07T17:53:48Z
New Revision: edca72f5bcb039840fda28e324af4614d4e46fde

URL: https://github.com/llvm/llvm-project/commit/edca72f5bcb039840fda28e324af4614d4e46fde
DIFF: https://github.com/llvm/llvm-project/commit/edca72f5bcb039840fda28e324af4614d4e46fde.diff

LOG: [mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133390

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
    mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5c56f16d71ee8..7035fb16d8e18 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -624,79 +624,4 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
   let hasVerifier = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Storage Operation. These operations are used internally by
-// sparse tensor codegen to progressively lower sparse tensors.
-//===----------------------------------------------------------------------===//
-
-def SparseTensor_StorageOp : SparseTensor_Op<"storage", []>,
-    Arguments<(ins Variadic<AnyType>:$inputs)>,
-    Results<(outs AnyTuple:$result)> {
-  let summary = "Pack a list of value into one sparse tensor storage value";
-  let description = [{
-     Pack a list of value into one sparse tensor storage value (represented as
-     a tuple) at the given index.
-
-     The result tuple elements' type should match the corresponding type in the
-     input array.
-
-     Example:
-
-     ```mlir
-     %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
-                                to tuple<memref<?xf64>, memref<?xf64>>
-     ```
-   }];
-
-  let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
-  let hasVerifier = 1;
-}
-
-def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
-    Arguments<(ins AnyTuple:$storage,
-                   IndexAttr:$idx)>,
-    Results<(outs AnyType:$result)> {
-  let summary = "Get the data stored in the sparse tensor storage at the given index";
-  let description = [{
-     Get the data stored in the sparse tensor storage (represented as a tuple)
-     at the given index.
-
-     The result type should match the corresponding element type in the tuple.
-
-     Example:
-
-     ```mlir
-     %0 = sparse_tensor.storage_get %arg0[0] : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
-     ```
-   }];
-
-  let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
-  let hasVerifier = 1;
-}
-
-def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
-    Arguments<(ins AnyTuple:$storage,
-                   AnyType:$value,
-                   IndexAttr:$idx)>,
-    Results<(outs AnyTuple:$result)> {
-  let summary = "Set the data stored in the sparse tensor storage at given index";
-  let description = [{
-     Set the data stored in the sparse tensor storage (represented as a tuple)
-     at the given index. Return a new SSA value with the corresponding element
-     updated (others remain unchanged).
-
-     The result type should match the original tuple type with only the updated
-     element type changed accordingly.
-
-     Example:
-
-     ```mlir
-     %0 = sparse_tensor.storage_set %arg0, %arg1 at 0 : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, f64>
-     ```
-   }];
-
-  let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)";
-  let hasVerifier = 1;
-}
-
 #endif // SPARSETENSOR_OPS

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 227b70a381192..fd885f646221a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -155,22 +155,6 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
 
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
-//===----------------------------------------------------------------------===//
-// The SparseTensorStorageExpansion pass.
-//===----------------------------------------------------------------------===//
-
-/// Sparse tensor storage type converter from compound to expanded form.
-class SparseTensorStorageTupleExpander : public TypeConverter {
-public:
-  SparseTensorStorageTupleExpander();
-};
-
-/// Sets up sparse tensor storage expansion rules.
-void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter,
-                                                  RewritePatternSet &patterns);
-
-std::unique_ptr<Pass> createSparseTensorStorageExpansionPass();
-
 //===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index cd6b77ea50eea..f7f4a39a95f23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -175,39 +175,4 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   ];
 }
 
-def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion", "ModuleOp"> {
-  let summary = "Expand compounded sparse tensor storage into individual SSA values";
-  let description = [{
-    A pass that expands sparse tensor storage (aggregated by tuple) into
-    individual SSA values. It also lowers sparse tensor storage operations,
-    e.g., sparse_tensor.storage_get and sparse_tensor.storage_set.
-
-    Example of the conversion:
-
-    ```mlir
-    Before:
-      func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>,
-                                                 memref<?xf64>,
-                                                 f64>)
-                                        -> tuple<memref<?xf64>,
-                                                 memref<?xf64>,
-                                                 f64> {
-        return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-      }
-    After:
-      func.func @sparse_storage_set(%arg0: memref<?xf64>,
-                                    %arg1: memref<?xf64>,
-                                    %arg2: f64)
-                                    -> (memref<?xf64>, memref<?xf64>, f64) {
-        return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
-      }
-    ```
-  }];
-  let constructor = "mlir::createSparseTensorStorageExpansionPass()";
-  let dependentDialects = [
-    "sparse_tensor::SparseTensorDialect",
-  ];
-}
-
-
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ef32dea8efff2..8691b94351f9f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -482,65 +482,6 @@ LogicalResult YieldOp::verify() {
       "expected parent op to be sparse_tensor unary, binary, or reduce");
 }
 
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Storage Operation.
-//===----------------------------------------------------------------------===//
-
-LogicalResult StorageOp::verify() {
-  auto retTypes = getResult().getType().getTypes();
-  if (retTypes.size() != getInputs().size())
-    return emitError("The number of inputs is inconsistent with output tuple");
-
-  for (auto pair : llvm::zip(getInputs(), retTypes)) {
-    auto input = std::get<0>(pair);
-    auto retTy = std::get<1>(pair);
-
-    if (input.getType() != retTy)
-      return emitError(llvm::formatv("Type mismatch between input (type={0}) "
-                                     "and output tuple element (type={1})",
-                                     input.getType(), retTy));
-  }
-  return success();
-}
-
-LogicalResult StorageGetOp::verify() {
-  uint64_t extractIdx = getIdx().getZExtValue();
-  auto innerTypeArray = getStorage().getType().getTypes();
-  if (extractIdx >= innerTypeArray.size())
-    return emitError(llvm::formatv(
-        "Out-of-bound access with index={0} on tuple with length={1}",
-        extractIdx, innerTypeArray.size()));
-
-  auto expectedTy = getStorage().getType().getType(extractIdx);
-  auto returnTy = getResult().getType();
-  if (expectedTy != returnTy)
-    return emitError(llvm::formatv(
-        "Type mismatch between the returning type (type={0}) and the "
-        "corresponding element type at index {1} (type={2})",
-        expectedTy, extractIdx, returnTy));
-  return success();
-}
-
-LogicalResult StorageSetOp::verify() {
-  uint64_t setIdx = getIdx().getZExtValue();
-  SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
-  if (setIdx >= expectedElemTy.size())
-    return emitError(llvm::formatv(
-        "Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
-        expectedElemTy.size()));
-
-  // Updates the element type after storage_set.
-  expectedElemTy[setIdx] = getValue().getType();
-  auto expectedTy = TupleType::get(getContext(), expectedElemTy);
-  auto returnTy = getResult().getType();
-  if (expectedTy != returnTy)
-    return emitError(
-        llvm::formatv("Type mismatch between the returning type "
-                      "(type={0}) and the expected type (type={1})",
-                      returnTy, expectedTy));
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 39b633a6c7f6a..640ee67302b1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -7,7 +7,6 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
-  SparseTensorStorageExpansion.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 022c4be443a0a..3c9caf71512b8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -54,8 +54,30 @@ static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   return i;
 }
 
+/// Flatten a list of operands that may contain sparse tensors.
+static void flattenOperands(ValueRange operands,
+                            SmallVectorImpl<Value> &flattened) {
+  // In case of
+  // sparse_tensor, c, sparse_tensor
+  // ==>
+  // memref ..., c, memref ...
+  for (auto operand : operands) {
+    if (auto cast =
+            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+        cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+      // An unrealized_conversion_cast is inserted by the type converter to
+      // bridge the 1:N conversion between a sparse tensor and its fields. In
+      // this case, take the operands of the cast and substitute them for the
+      // sparse tensor value, yielding the flattened list of fields.
+      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+    else
+      flattened.push_back(operand);
+  }
+}
+
 /// Maps a sparse tensor type to the appropriate compounded buffers.
-static Optional<Type> convertSparseTensorType(Type type) {
+static Optional<LogicalResult>
+convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   auto enc = getSparseTensorEncoding(type);
   if (!enc)
     return llvm::None;
@@ -86,7 +108,6 @@ static Optional<Type> convertSparseTensorType(Type type) {
   // };
   //
   unsigned rank = rType.getShape().size();
-  SmallVector<Type, 8> fields;
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
@@ -115,10 +136,7 @@ static Optional<Type> convertSparseTensorType(Type type) {
   }
   // The values array.
   fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
-  // Sparse tensor storage (temporarily) lives in a tuple. This allows a
-  // simple 1:1 type conversion during codegen. A subsequent pass uses
-  // a 1:N type conversion to expand the tuple into its fields.
-  return TupleType::get(context, fields);
+  return success();
 }
 
 // Returns field index of sparse tensor type for pointers/indices, when set.
@@ -158,25 +176,6 @@ static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
   return -1;
 }
 
-/// Returns field type in tuple at given index.
-static Type getFieldType(Value tuple, unsigned field) {
-  return tuple.getType().cast<TupleType>().getType(field);
-}
-
-/// Creates tuple get operation at given index.
-static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
-                            unsigned field) {
-  Type indexType = builder.getIndexType();
-  return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
-                                      builder.getIntegerAttr(indexType, field));
-}
-
-/// Creates tuple.
-static Value createTupleMake(OpBuilder &builder, Location loc, Type type,
-                             ValueRange values) {
-  return builder.create<StorageOp>(loc, type, values);
-}
-
 /// Create allocation operation.
 static Value createAllocation(OpBuilder &builder, Location loc, Type type,
                               Value sz) {
@@ -184,14 +183,15 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
   return builder.create<memref::AllocOp>(loc, memType, sz);
 }
 
-/// Creates allocation tuple for sparse tensor type.
+/// Creates allocation for each field in sparse tensor type.
 ///
 /// TODO: for efficiency, we will need heuristis to make educated guesses
 ///       on the required final sizes; also, we will need an improved
 ///       memory allocation scheme with capacity and reallocation
 ///
-static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
-                              ValueRange dynSizes) {
+static void createAllocFields(OpBuilder &builder, Location loc, Type type,
+                              ValueRange dynSizes,
+                              SmallVectorImpl<Value> &fields) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
   // Construct the basic types.
@@ -202,10 +202,8 @@ static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
   Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
   Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
   Type eltType = rType.getElementType();
-  // Build the allocation tuple, using heuristics for pre-allocation.
   auto shape = rType.getShape();
   unsigned rank = shape.size();
-  SmallVector<Value, 8> fields;
   bool allDense = true;
   Value one = constantIndex(builder, loc, 1);
   Value linear = one;
@@ -254,9 +252,6 @@ static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
   // In all other case, we resort to the heuristical initial value.
   Value valuesSz = allDense ? linear : heuristic;
   fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
-  // Construct tuple allocation.
-  Type tupleType = *convertSparseTensorType(type);
-  return createTupleMake(builder, loc, tupleType, fields);
 }
 
 /// Returns integral constant, if defined.
@@ -270,14 +265,80 @@ static Optional<int64_t> getConstantInt(Value val) {
 // Codegen rules.
 //===----------------------------------------------------------------------===//
 
-/// Sparse codegen rule for returns.
+/// Sparse tensor storage conversion rule for returns.
 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
+    SmallVector<Value, 8> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    // Create a return with the flattened value extracted from sparse tensors.
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    return success();
+  }
+};
+
+/// Sparse tensor storage conversion rule for calls.
+class SparseCallConverter : public OpConversionPattern<func::CallOp> {
+public:
+  // The default CallOp converter cannot handle 1:N type conversion.
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // In case of:
+    //  sparse_tensor, f, sparse_tensor = call @foo(...)
+    // ==>
+    //  memref..., f, memref = call @foo(...) replace with
+    //  cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
+    SmallVector<Type, 8> finalRetTy;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+      return failure();
+
+    // (1) Generate a new call with flattened operands and return values.
+    SmallVector<Value, 8> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+                                                 finalRetTy, flattened);
+    // (2) Create cast operation for sparse tensor returns.
+    SmallVector<Value, 4> castedRet;
+    // Tracks the offset of the current return value (of the original call)
+    // relative to the new call (after sparse tensor flattening).
+    unsigned retOffset = 0;
+    // Temporary buffer to hold the flattened list of types for
+    // a sparse tensor.
+    SmallVector<Type, 8> sparseFlat;
+    for (auto ret : op.getResults()) {
+      assert(retOffset < newCall.getNumResults());
+      auto retType = ret.getType();
+      if (failed(typeConverter->convertType(retType, sparseFlat)))
+        // This should never happen.
+        llvm_unreachable("Failed to convert type in sparse tensor codegen");
+
+      // Converted types cannot be empty when the type conversion succeeds.
+      assert(!sparseFlat.empty());
+      if (sparseFlat.size() > 1) {
+        auto flatSize = sparseFlat.size();
+        ValueRange sparseElem(iterator_range<ResultRange::iterator>(
+            newCall.result_begin() + retOffset,
+            newCall.result_begin() + retOffset + flatSize));
+        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+            loc, TypeRange({retType}), sparseElem);
+        castedRet.push_back(castOp.getResult(0));
+        retOffset += flatSize;
+      } else {
+        // If this is a 1:1 conversion, no need for casting.
+        castedRet.push_back(newCall.getResult(retOffset));
+        retOffset++;
+      }
+      sparseFlat.clear();
+    }
+
+    assert(castedRet.size() == op.getNumResults());
+    rewriter.replaceOp(op, castedRet);
     return success();
   }
 };
@@ -306,10 +367,11 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
     }
     // Any other query can consult the dimSizes array at field 0 using,
     // accounting for the reordering applied to the sparse storage.
-    Value tuple = adaptor.getSource();
-    Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getSource().getDefiningOp());
     rewriter.replaceOpWithNewOp<memref::LoadOp>(
-        op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+        op, tuple.getInputs().front(),
+        constantIndex(rewriter, loc, toStored(enc, *index)));
     return success();
   }
 };
@@ -345,10 +407,13 @@ class SparseTensorAllocConverter
       return failure();
     if (op.getCopy())
       return rewriter.notifyMatchFailure(op, "tensor copy not implemented");
-    // Construct allocation tuple.
-    Value tuple = createAllocTuple(rewriter, op->getLoc(), resType,
-                                   adaptor.getOperands());
-    rewriter.replaceOp(op, tuple);
+
+    // Construct allocation for each field.
+    Location loc = op.getLoc();
+    SmallVector<Value, 8> fields;
+    createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{resType}, fields);
     return success();
   }
 };
@@ -364,86 +429,101 @@ class SparseTensorDeallocConverter
     auto enc = getSparseTensorEncoding(op.getTensor().getType());
     if (!enc)
       return failure();
-    // Replace the tuple deallocation with field deallocations.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
-         i++) {
-      Value mem = createTupleGet(rewriter, loc, tuple, i);
-      rewriter.create<memref::DeallocOp>(loc, mem);
-    }
+
+    // Replace the sparse tensor deallocation with field deallocations.
+    Location loc = op.getLoc();
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getTensor().getDefiningOp());
+    for (auto input : tuple.getInputs())
+      // Deallocate every buffer used to store the sparse tensor.
+      rewriter.create<memref::DeallocOp>(loc, input);
+
     rewriter.eraseOp(op);
     return success();
   }
 };
 
-/// Sparse codegen rule for pointer accesses.
-class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+/// Sparse codegen rule for tensor rematerialization.
+class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
-    if (!index)
-      return failure();
-    // Replace the requested pointer access with corresponding field.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
-    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    if (op.getHasInserts()) {
+      // Finalize any pending insertions.
+      // TODO: implement
+    }
+    rewriter.replaceOp(op, adaptor.getOperands());
     return success();
   }
 };
 
-/// Sparse codegen rule for index accesses.
-class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+/// Base class for getter-like operations, e.g., to_indices, to_pointers.
+template <typename SourceOp, typename Base>
+class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
-    if (!index)
+    // Replace the requested getter access with the corresponding field.
+    // The cast op is inserted by the type converter to bridge the 1:N type
+    // conversion.
+    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+        adaptor.getTensor().getDefiningOp());
+    auto idx = Base::getIndexForOp(tuple, op);
+    if (!idx)
+      // Failed to get the index.
       return failure();
-    // Replace the requested indices access with corresponding field.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
-    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    auto fields = tuple.getInputs();
+    assert(*idx < fields.size());
+    rewriter.replaceOp(op, fields[*idx]);
     return success();
   }
 };
 
-/// Sparse codegen rule for value accesses.
-class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+/// Sparse codegen rule for pointer accesses.
+class SparseToPointersConverter
+    : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
 public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Replace the requested values access with corresponding field.
-    Location loc = op->getLoc();
-    Value tuple = adaptor.getTensor();
-    unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
-    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
-    return success();
+  using SparseGetterOpConverter::SparseGetterOpConverter;
+  // Callback for SparseGetterOpConverter.
+  static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+                                          ToPointersOp op) {
+    Optional<int64_t> dim = getConstantInt(op.getDim());
+    if (!dim)
+      return llvm::None; // variable dim
+    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*dim, -1);
   }
 };
 
-/// Sparse codegen rule for tensor rematerialization.
-class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
+/// Sparse codegen rule for index accesses.
+class SparseToIndicesConverter
+    : public SparseGetterOpConverter<ToIndicesOp, SparseToIndicesConverter> {
 public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (op.getHasInserts()) {
-      // Finalize any pending insertions.
-      // TODO: implement
-    }
-    rewriter.replaceOp(op, adaptor.getOperands());
-    return success();
+  using SparseGetterOpConverter::SparseGetterOpConverter;
+  // Callback for SparseGetterOpConverter.
+  static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+                                          ToIndicesOp op) {
+    Optional<int64_t> dim = getConstantInt(op.getDim());
+    if (!dim)
+      return llvm::None; // variable dim
+    return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*dim);
+  }
+};
+
+/// Sparse codegen rule for value accesses.
+class SparseToValuesConverter
+    : public SparseGetterOpConverter<ToValuesOp, SparseToValuesConverter> {
+public:
+  using SparseGetterOpConverter::SparseGetterOpConverter;
+  // Callback for SparseGetterOpConverter.
+  static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp tuple,
+                                          ToValuesOp /*op*/) {
+    // The last field holds the value buffer.
+    return tuple.getInputs().size() - 1;
   }
 };
 
@@ -466,9 +546,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorAllocConverter, SparseTensorDeallocConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter, SparseTensorLoadConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseCastConverter, SparseTensorAllocConverter,
+               SparseTensorDeallocConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter,
+               SparseTensorLoadConverter>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 505ae79e26fac..fee4222cb53d3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,7 +24,6 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
-#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -154,9 +153,8 @@ struct SparseTensorCodegenPass
     RewritePatternSet patterns(ctx);
     SparseTensorTypeToBufferConverter converter;
     ConversionTarget target(*ctx);
-    // Almost everything in the sparse dialect must go!
+    // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    target.addLegalOp<StorageGetOp, StorageSetOp, StorageOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
     // provided that all sparse tensor types have been fully rewritten.
@@ -181,53 +179,13 @@ struct SparseTensorCodegenPass
     target.addLegalDialect<arith::ArithmeticDialect,
                            bufferization::BufferizationDialect,
                            memref::MemRefDialect, scf::SCFDialect>();
-    // Populate with rules and apply rewriting rules.
-    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
-                                                                   converter);
-    populateCallOpTypeConversionPattern(patterns, converter);
-    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
-                                                         target);
-    populateSparseTensorCodegenPatterns(converter, patterns);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
-  }
-};
-
-struct SparseTensorStorageExpansionPass
-    : public impl::SparseTensorStorageExpansionBase<
-          SparseTensorStorageExpansionPass> {
-
-  SparseTensorStorageExpansionPass() = default;
-  SparseTensorStorageExpansionPass(
-      const SparseTensorStorageExpansionPass &pass) = default;
-
-  void runOnOperation() override {
-    auto *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
-    SparseTensorStorageTupleExpander converter;
-    ConversionTarget target(*ctx);
-    // Now, everything in the sparse dialect must go!
-    target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return.
-    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
-      return converter.isSignatureLegal(op.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
-      return converter.isSignatureLegal(op.getCalleeType());
-    });
-    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
-      return converter.isLegal(op.getOperandTypes());
-    });
-    // We generate UnrealizedConversionCastOp to intermix tuples and a
-    // list of types.
     target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                    converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
-    populateSparseTensorStorageExpansionPatterns(converter, patterns);
+    populateSparseTensorCodegenPatterns(converter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -277,7 +235,3 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
-
-std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
-  return std::make_unique<SparseTensorStorageExpansionPass>();
-}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
deleted file mode 100644
index 1f7afa1d77804..0000000000000
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ /dev/null
@@ -1,218 +0,0 @@
-//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// The sparse tensor storage expansion pass expands the compound storage for
-// sparse tensors (using tuple) to flattened SSA values.
-//
-//===----------------------------------------------------------------------===//
-
-#include "CodegenUtils.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
-#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::sparse_tensor;
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Helper methods.
-//===----------------------------------------------------------------------===//
-
-/// Expands sparse tensor storage tuple.
-static Optional<LogicalResult>
-convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
-  if (auto tuple = t.dyn_cast<TupleType>()) {
-    // Note that it does not handle nest tuples, but it is fine
-    // for sparse compiler as they will not be generated.
-    result.append(tuple.getTypes().begin(), tuple.getTypes().end());
-    return success();
-  }
-  return llvm::None;
-}
-
-/// Flatten a list of operands that may contain tuples.
-static void flattenOperands(ValueRange operands,
-                            SmallVectorImpl<Value> &flattened) {
-  // In case of
-  // tuple<a, b>, c, tuple<d, e>
-  // ==>
-  // a, b, c, d, e
-  for (auto operand : operands) {
-    if (auto cast =
-            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-        cast && cast->getResultTypes()[0].isa<TupleType>())
-      // An unrealized_conversion_cast will be inserted by type converter to
-      // inter-mix the gap between 1:N conversion between tuple and types.
-      // In this case, take the operands in the cast and replace the tuple
-      // output with the flattened type array.
-      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
-    else
-      flattened.push_back(operand);
-  }
-}
-//===----------------------------------------------------------------------===//
-// Conversion rules.
-//===----------------------------------------------------------------------===//
-
-/// Sparse tensor storage conversion rule for sparse_tensor::storage.
-class SparseStorageConversion : public OpConversionPattern<StorageOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(StorageOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Simply convert it to a unrealize_conversion_cast.
-    // We should guarantee that all uses of sparse_tensor.storage op will
-    // be eventually eliminated by accessing the flattened SSA values directly.
-    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-        op, TypeRange{op.getType()}, adaptor.getInputs());
-    return success();
-  }
-};
-
-/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
-class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto castOp =
-        cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
-    uint64_t idx = op.getIdx().getZExtValue();
-    assert(idx < castOp.getOperands().size());
-
-    rewriter.replaceOp(op, castOp.getOperand(idx));
-    return success();
-  }
-};
-
-/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
-class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto castOp =
-        cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
-    uint64_t idx = op.getIdx().getZExtValue();
-
-    SmallVector<Value, 8> values(castOp.getOperands());
-    assert(idx < values.size());
-
-    // Updates the corresponding element.
-    values[idx] = adaptor.getValue();
-    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-        op, TypeRange{op.getType()}, values);
-    return success();
-  }
-};
-
-/// Sparse tensor storage conversion rule for returns.
-class SparseStorageReturnConverter
-    : public OpConversionPattern<func::ReturnOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value, 8> flattened;
-    flattenOperands(adaptor.getOperands(), flattened);
-    // Create a return with the flattened value extracted from tuple.
-    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
-    return success();
-  }
-};
-
-/// Sparse tensor storage conversion rule for calls.
-class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
-public:
-  // The default CallOp converter can not handle 1:N type conversion properly
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // In case of:
-    //  tuple(a, b), f, tuple(c, d) = call @foo(...)
-    // ==>
-    //  a, b, f, c, d = call @foo(...)
-    //  cast(a, b)->tuple, f, cast(c,d)->tuple
-    SmallVector<Type, 8> finalRetTy;
-    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
-      return failure();
-
-    // (1) Genereates new call with flattened return value.
-    SmallVector<Value, 8> flattened;
-    flattenOperands(adaptor.getOperands(), flattened);
-    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
-                                                 finalRetTy, flattened);
-
-    // (2) Create cast operation for tuple returns.
-    SmallVector<Value, 4> castedRet;
-    // Tracks the offset of current return value (of the orignal call)
-    // relative to the new call (after tuple flattening);
-    unsigned retOffset = 0;
-    for (auto ret : op.getResults()) {
-      assert(retOffset < newCall.getNumResults());
-      auto tupleRet = ret.getType().dyn_cast<TupleType>();
-      if (tupleRet) {
-        auto tupleSize = tupleRet.size();
-        // NOTE: The range is computed under the assumption of non-recursive
-        // tuple type.
-        ValueRange tupleElem(iterator_range<ResultRange::iterator>(
-            newCall.result_begin() + retOffset,
-            newCall.result_begin() + retOffset + tupleSize));
-        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
-            loc, TypeRange({tupleRet}), tupleElem);
-        castedRet.push_back(castOp.getResult(0));
-        retOffset += tupleSize;
-      } else {
-        // If this not a tuple, simply add it into returned values.
-        castedRet.push_back(ret);
-        retOffset++;
-      }
-    }
-
-    assert(castedRet.size() == op.getNumResults());
-    rewriter.replaceOp(op, castedRet);
-    return success();
-  }
-};
-
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Sparse tensor storage expansion
-//===----------------------------------------------------------------------===//
-
-mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
-  addConversion([](Type type) { return type; });
-  addConversion(convertSparseTensorStorageTuple);
-}
-
-//===----------------------------------------------------------------------===//
-// Public method for populating conversion rules.
-//===----------------------------------------------------------------------===//
-
-/// Populates the given patterns list with conversion rules required
-/// to expand compounded sparse tensor tuples.
-void mlir::populateSparseTensorStorageExpansionPatterns(
-    TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageConversion, SparseStorageGetConverter,
-               SparseStorageSetConverter, SparseStorageReturnConverter,
-               SparseStorageCallConverter>(typeConverter,
-                                           patterns.getContext());
-}

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 8c7968022e6f6..89fb8a9129fa5 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen  --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN
-// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-STORAGE
+// RUN: mlir-opt %s --sparse-tensor-codegen  --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
@@ -41,96 +40,114 @@
   dimOrdering = affine_map<(i, j, k) -> (k, i, j)>
 }>
 
-// CHECK-CODEGEN-LABEL: func @sparse_nop(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-//       CHECK-CODEGEN: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
-//
-// CHECK-STORAGE-LABEL: func @sparse_nop(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>)
-//       CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-LABEL: func @sparse_nop(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_nop_cast(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>)
-//       CHECK-CODEGEN: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>
+// CHECK-LABEL: func @sparse_nop_multi_ret(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<1xindex>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xf64>) ->
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]]
+func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
+                                %arg1: tensor<?xf64, #SparseVector>) ->
+                                (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
+  return %arg0, %arg1 : tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_nop_call(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<1xindex>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xf64>) 
+//       CHECK: %[[T0:.*]]:8 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]]) 
+//       CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7 
+func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
+                           %arg1: tensor<?xf64, #SparseVector>) ->
+                           (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
+  %1, %2 = call @sparse_nop_multi_ret(%arg0, %arg1) :
+                           (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) ->
+                           (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+  return %1, %2: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>
+}
+
 //
-// CHECK-STORAGE-LABEL: func @sparse_nop_cast(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf32>)
-//       CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-LABEL: func @sparse_nop_cast(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf32>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
 func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_nop_cast_3d(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf32>>)
-//       CHECK-CODEGEN: return %[[A]] : tuple<memref<3xindex>, memref<?xf32>>
 //
-// CHECK-STORAGE-LABEL: func @sparse_nop_cast_3d(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf32>)
-//       CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf32>
+// CHECK-LABEL: func @sparse_nop_cast_3d(
+//  CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf32>)
+//       CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf32>
 func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
   %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
   return %0 : tensor<?x?x?xf32, #Dense3D>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_dense_2d(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
 //
-// CHECK-STORAGE-LABEL: func @sparse_dense_2d(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>) {
-//       CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_dense_2d(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>) {
+//       CHECK: return
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
   return
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_row(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 //
-// CHECK-STORAGE-LABEL: func @sparse_row(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>) {
-//       CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_row(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+//       CHECK: return
 func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
   return
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_csr(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 //
-// CHECK-STORAGE-LABEL: func @sparse_csr(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>) {
-//       CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_csr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+//       CHECK: return
 func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_dcsr(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 //
-// CHECK-STORAGE-LABEL: func @sparse_dcsr(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>) {
-//       CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_dcsr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>) {
+//       CHECK: return
 func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
   return
 }
@@ -139,16 +156,12 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
 // Querying for dimension 1 in the tensor type can immediately
 // fold using the original static dimension sizes.
 //
-// CHECK-CODEGEN-LABEL: func @sparse_dense_3d(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
-//       CHECK-CODEGEN: %[[C:.*]] = arith.constant 20 : index
-//       CHECK-CODEGEN: return %[[C]] : index
 //
-// CHECK-STORAGE-LABEL: func @sparse_dense_3d(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
-//       CHECK-STORAGE: %[[C:.*]] = arith.constant 20 : index
-//       CHECK-STORAGE: return %[[C]] : index
+// CHECK-LABEL: func @sparse_dense_3d(
+//  CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>)
+//       CHECK: %[[C:.*]] = arith.constant 20 : index
+//       CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
   %c = arith.constant 1 : index
   %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
@@ -160,103 +173,74 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
 // into querying for dimension 2 in the stored sparse tensor scheme,
 // since the latter honors the dimOrdering.
 //
-// CHECK-CODEGEN-LABEL: func @sparse_dense_3d_dyn(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
-//       CHECK-CODEGEN: %[[C:.*]] = arith.constant 2 : index
-//       CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
-//       CHECK-CODEGEN: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
-//       CHECK-CODEGEN: return %[[L]] : index
 //
-// CHECK-STORAGE-LABEL: func @sparse_dense_3d_dyn(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
-//       CHECK-STORAGE: %[[C:.*]] = arith.constant 2 : index
-//       CHECK-STORAGE: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
-//       CHECK-STORAGE: return %[[L]] : index
+// CHECK-LABEL: func @sparse_dense_3d_dyn(
+//  CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>)
+//       CHECK: %[[C:.*]] = arith.constant 2 : index
+//       CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
+//       CHECK: return %[[L]] : index
 func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
   %c = arith.constant 1 : index
   %0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
   return %0 : index
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_pointers_dcsr(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-//       CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
-//       CHECK-CODEGEN: return %[[F]] : memref<?xi32>
 //
-// CHECK-STORAGE-LABEL: func @sparse_pointers_dcsr(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
-//       CHECK-STORAGE: return %[[A3]] : memref<?xi32>
+// CHECK-LABEL: func @sparse_pointers_dcsr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+//       CHECK: return %[[A3]] : memref<?xi32>
 func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
   %c = arith.constant 1 : index
   %0 = sparse_tensor.pointers %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi32>
   return %0 : memref<?xi32>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_indices_dcsr(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-//       CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][4] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
-//       CHECK-CODEGEN: return %[[F]] : memref<?xi64>
 //
-// CHECK-STORAGE-LABEL: func @sparse_indices_dcsr(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
-//       CHECK-STORAGE: return %[[A4]] : memref<?xi64>
+// CHECK-LABEL: func @sparse_indices_dcsr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+//       CHECK: return %[[A4]] : memref<?xi64>
 func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
   %c = arith.constant 1 : index
   %0 = sparse_tensor.indices %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi64>
   return %0 : memref<?xi64>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_values_dcsr(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-//       CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][5] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
-//       CHECK-CODEGEN: return %[[F]] : memref<?xf64>
 //
-// CHECK-STORAGE-LABEL: func @sparse_values_dcsr(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
-//       CHECK-STORAGE: return %[[A5]] : memref<?xf64>
+// CHECK-LABEL: func @sparse_values_dcsr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+//       CHECK: return %[[A5]] : memref<?xf64>
 func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
   return %0 : memref<?xf64>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_dealloc_csr(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-//       CHECK-CODEGEN: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
-//       CHECK-CODEGEN: memref.dealloc %[[F0]] : memref<2xindex>
-//       CHECK-CODEGEN: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
-//       CHECK-CODEGEN: memref.dealloc %[[F1]] : memref<?xi32>
-//       CHECK-CODEGEN: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
-//       CHECK-CODEGEN: memref.dealloc %[[F2]] : memref<?xi64>
-//       CHECK-CODEGEN: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
-//       CHECK-CODEGEN: memref.dealloc %[[F3]] : memref<?xf64>
-//       CHECK-CODEGEN: return
 //
-// CHECK-STORAGE-LABEL: func @sparse_dealloc_csr(
-//  CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-//  CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-//  CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>) {
-//       CHECK-STORAGE: memref.dealloc %[[A0]] : memref<2xindex>
-//       CHECK-STORAGE: memref.dealloc %[[A1]] : memref<?xi32>
-//       CHECK-STORAGE: memref.dealloc %[[A2]] : memref<?xi64>
-//       CHECK-STORAGE: memref.dealloc %[[A3]] : memref<?xf64>
-//       CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_dealloc_csr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+//       CHECK: memref.dealloc %[[A0]] : memref<2xindex>
+//       CHECK: memref.dealloc %[[A1]] : memref<?xi32>
+//       CHECK: memref.dealloc %[[A2]] : memref<?xi64>
+//       CHECK: memref.dealloc %[[A3]] : memref<?xf64>
+//       CHECK: return
 func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
   bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
   return
@@ -264,8 +248,7 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
 
 //        CHECK-LABEL: func @sparse_alloc_csc(
 //         CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-CODEGEN-SAME: tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
-// CHECK-STORAGE-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+//         CHECK-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 //          CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //          CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //          CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
@@ -278,9 +261,7 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
 //              CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
 //              CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64>
 //              CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
-//      CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
-//      CHECK-CODEGEN: return %[[T]]
-//      CHECK-STORAGE: return %[[T0]], %[[T2]], %[[T4]], %[[T6]] 
+//              CHECK: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
   %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
@@ -288,8 +269,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 }
 
 //        CHECK-LABEL: func @sparse_alloc_3d() ->
-// CHECK-CODEGEN-SAME: tuple<memref<3xindex>, memref<?xf64>>
-// CHECK-STORAGE-SAME: memref<3xindex>, memref<?xf64>
+//         CHECK-SAME: memref<3xindex>, memref<?xf64>
 //          CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //          CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //          CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -302,9 +282,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 //              CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
 //              CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
 //              CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-//      CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
-//      CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
-//      CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
+//              CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b9555e8861a25..ce495e0c7f227 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -442,63 +442,3 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
          tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
   return %0 : tensor<9x4xf64, #DC>
 }
-
-// -----
-
-func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
-                               tuple<memref<?xf64>, memref<?xf64>> {
-  // expected-error@+1{{The number of inputs is inconsistent with output}}
-  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
-       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
-  return %0 : tuple<memref<?xf64>, memref<?xf64>>
-}
-
-// -----
-
-func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
-                               tuple<memref<?xi64>, memref<?xf64>, f64> {
-  // expected-error@+1{{Type mismatch between}}
-  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
-       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xi64>, memref<?xf64>, f64>
-  return %0 : tuple<memref<?xi64>, memref<?xf64>, f64>
-}
-
-// -----
-
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
-  // expected-error@+1{{Out-of-bound access}}
-  %0 = sparse_tensor.storage_get %arg0[3]
-       : tuple<memref<?xf64>, memref<?xf64>, f64> to
-         memref<?xf64>
-  return %0 : memref<?xf64>
-}
-
-// -----
-
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
-  // expected-error@+1{{Type mismatch}}
-  %0 = sparse_tensor.storage_get %arg0[2]
-       : tuple<memref<?xf64>, memref<?xf64>, f64> to
-         memref<?xf64>
-  return %0 : memref<?xf64>
-}
-
-// -----
-
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
-  // expected-error@+1{{Out-of-bound access}}
-  %0 = sparse_tensor.storage_set %arg0[3], %arg1
-       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
-         tuple<memref<?xf64>, memref<?xf64>, f64>
-  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// -----
-
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
-  // expected-error@+1{{Type mismatch}}
-  %0 = sparse_tensor.storage_set %arg0[2], %arg1
-       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
-         tuple<memref<?xf64>, memref<?xf64>, f64>
-  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index c37b4e7b53ac8..5edc977de7c00 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -314,50 +314,3 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
          tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
   return %0 : tensor<9x4xf64, #SparseMatrix>
 }
-
-// -----
-
-
-// CHECK: func @sparse_storage_new(
-//  CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
-//  CHECK-SAME: %[[A2:.*]]: f64
-//       CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
-//       CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
-func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
-                               tuple<memref<?xf64>, memref<?xf64>, f64> {
-  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
-       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
-  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_storage_get(
-//  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
-//       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
-//  CHECK-SAME:     tuple<memref<?xf64>, memref<?xf64>, f64>
-//  CHECK-SAME:     to memref<?xf64>
-//       CHECK:   return %[[TMP0]] : memref<?xf64>
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
-  %0 = sparse_tensor.storage_get %arg0[0]
-       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
-  return %0 : memref<?xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_storage_set(
-//  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
-//  CHECK-SAME:   %[[A1:.*]]: memref<?xf64>
-//       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
-//  CHECK-SAME:     tuple<memref<?xf64>, memref<?xf64>, f64>,
-//  CHECK-SAME:     memref<?xf64>
-//  CHECK-SAME:     to tuple<memref<?xf64>, memref<?xf64>, f64>
-//       CHECK:   return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
-  %0 = sparse_tensor.storage_set %arg0[0], %arg1
-       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
-         tuple<memref<?xf64>, memref<?xf64>, f64>
-  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
deleted file mode 100644
index d2d4769353a3c..0000000000000
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ /dev/null
@@ -1,60 +0,0 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s
-
-// CHECK-LABEL:  func @sparse_storage_expand(
-// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg2:.*]]: f64
-// CHECK           return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
-func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
-                                     -> tuple<memref<?xf64>, memref<?xf64>, f64> {
-  return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// CHECK-LABEL:  func @call_sparse_storage_expand(
-// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg2:.*]]: f64)
-// CHECK:          %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
-// CHECK:          return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
-func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
-                                          -> tuple<memref<?xf64>, memref<?xf64>, f64> {
-  %1 = call @sparse_storage_expand(%arg0) : (tuple<memref<?xf64>, memref<?xf64>, f64>) ->
-                                             tuple<memref<?xf64>, memref<?xf64>, f64>
-  return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// CHECK-LABEL: func @sparse_storage(
-// CHECK-SAME:    %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME:    %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:    %[[TMP_arg2:.*2]]: memref<?xf64>)
-// CHECK:         return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
-func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
-                        -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
-  %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref<?xf64>, memref<?xf64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
-  return %1 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
-}
-
-// CHECK-LABEL:  func @sparse_storage_get(
-// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg2:.*]]: f64)
-// CHECK:          return %[[TMP_arg0]] : memref<?xf64>
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
-  %0 = sparse_tensor.storage_get %arg0[0]
-       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
-  return %0 : memref<?xf64>
-}
-
-// CHECK-LABEL:  func @sparse_storage_set(
-// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg2:.*]]: f64,
-// CHECK-SAME:     %[[TMP_arg3:.*]]: memref<?xf64>)
-// CHECK:          return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
-                              %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
-  %0 = sparse_tensor.storage_set %arg0[0], %arg1
-       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
-         tuple<memref<?xf64>, memref<?xf64>, f64>
-  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}


        


More information about the Mlir-commits mailing list