[Mlir-commits] [mlir] c66303c - [mlir][sparse] Switch to One-Shot Bufferize

Matthias Springer llvmlistbot at llvm.org
Thu Jul 14 00:53:40 PDT 2022


Author: Matthias Springer
Date: 2022-07-14T09:52:48+02:00
New Revision: c66303c2870c9a77a0f2a8aa16fd0ea87b0358e6

URL: https://github.com/llvm/llvm-project/commit/c66303c2870c9a77a0f2a8aa16fd0ea87b0358e6
DIFF: https://github.com/llvm/llvm-project/commit/c66303c2870c9a77a0f2a8aa16fd0ea87b0358e6.diff

LOG: [mlir][sparse] Switch to One-Shot Bufferize

This change removes the partial bufferization passes from the sparse compilation pipeline and replaces them with One-Shot Bufferize. One-Shot Analysis (and TensorCopyInsertion) is used to resolve all out-of-place bufferizations, dense and sparse. Dense ops are then bufferized with BufferizableOpInterface. Sparse ops are still bufferized in the Sparsification pass.

Details:
* Dense allocations are automatically deallocated, unless they are yielded from a block. (In that case the alloc would leak.) All test cases are modified accordingly. E.g., some funcs now have an "out" tensor argument that is returned from the function. (That way, the allocation happens at the call site.)
* Sparse allocations are *not* automatically deallocated. They must be "released" manually. (This matches the previous behavior; automatic deallocation of sparse allocations will be addressed in a future change.)
* Sparse tensor copies are not supported yet. (Future change)
* Sparsification no longer has to consider inplaceability. If necessary, allocations and/or copies are inserted during TensorCopyInsertion. All tensors are inplaceable by the time Sparsification runs. Instead of marking a tensor as "not inplaceable", it can be marked as "not writable", which triggers an allocation and/or copy during TensorCopyInsertion. (A sketch follows below.)
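
For illustration, a minimal sketch of the last point, taken from the updated
one_shot_bufferize_tensor_copy_insertion.mlir test below: the function
argument %argb is also returned from the function, so it is not writable, and
TensorCopyInsertion gives the linalg.generic op a fresh, inplaceable copy to
write into.

    // Before TensorCopyInsertion:
    %0 = linalg.generic #trait
           ins(%arga : tensor<10xf32, #SV>)
          outs(%argb : tensor<10xf32>) { ... } -> tensor<10xf32>

    // After TensorCopyInsertion:
    %copy = bufferization.alloc_tensor() copy(%argb)
              {bufferization.escape = [false]} : tensor<10xf32>
    %0 = linalg.generic #trait
           ins(%arga : tensor<10xf32, #SV>)
          outs(%copy : tensor<10xf32>) { ... } -> tensor<10xf32>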

Differential Revision: https://reviews.llvm.org/D129356

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/DenseBufferizationPass.cpp

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/dense.mlir
    mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
    mlir/test/Dialect/SparseTensor/sparse_1d.mlir
    mlir/test/Dialect/SparseTensor/sparse_2d.mlir
    mlir/test/Dialect/SparseTensor/sparse_3d.mlir
    mlir/test/Dialect/SparseTensor/sparse_affine.mlir
    mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
    mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
    mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
    mlir/test/Dialect/SparseTensor/sparse_nd.mlir
    mlir/test/Dialect/SparseTensor/sparse_out.mlir
    mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
    mlir/test/Dialect/SparseTensor/sparse_perm.mlir
    mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
    mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_bf16.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_f16.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tanh.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2cc84c99d2040..c5f6c129feead 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -471,6 +471,10 @@ allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                              bool escape, const BufferizationOptions &options,
                              bool copy = true);
 
+/// Return `true` if the allocation of the given op is guaranteed to not escape
+/// the containing block.
+bool allocationDoesNotEscape(OpResult opResult);
+
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
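
The new helper inspects the "bufferization.escape" attribute that
TensorCopyInsertion places on allocating ops (one boolean per op result);
when the attribute is absent, it conservatively returns false. A minimal
sketch:

    // allocationDoesNotEscape(...) returns true for result #0 here: the
    // escape flag is false, i.e., the allocation stays in its block.
    %t = bufferization.alloc_tensor() {bufferization.escape = [false]}
           : tensor<20x40xf32>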

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index a7064a2508312..43abadd5a7627 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -44,6 +44,9 @@ struct SparseCompilerOptions
   PassOptions::Option<bool> enableVLAVectorization{
       *this, "enable-vla-vectorization",
       desc("Enable vector length agnostic vectorization"), init(false)};
+  PassOptions::Option<bool> testBufferizationAnalysisOnly{
+      *this, "test-bufferization-analysis-only",
+      desc("Run only the inplacability analysis"), init(false)};
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
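
With the new option set, the pipeline stops after TensorCopyInsertion and only
reports the analysis results. A sketch of the annotated output, assuming the
"__inplace_operands_attr__" markers that One-Shot Analysis attaches in test
mode:

    // "true" marks a tensor operand that bufferizes in place, "false" one
    // that would require a copy; no IR is actually bufferized.
    %0 = linalg.generic {__inplace_operands_attr__ = ["true", "false"]}
           ins(%arga : tensor<10xf32, #SV>)
          outs(%argb : tensor<10xf32>) { ... } -> tensor<10xf32>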

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 20a322e97dfc0..b7300dda22fae 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -23,6 +23,9 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+namespace bufferization {
+struct OneShotBufferizationOptions;
+} // namespace bufferization
 
 // Forward.
 class TypeConverter;
@@ -131,6 +134,8 @@ void populateSparseTensorConversionPatterns(
     const SparseTensorConversionOptions &options =
         SparseTensorConversionOptions());
 
+std::unique_ptr<Pass> createDenseBufferizationPass(
+    const bufferization::OneShotBufferizationOptions &options);
 std::unique_ptr<Pass> createSparseTensorConversionPass();
 std::unique_ptr<Pass>
 createSparseTensorConversionPass(const SparseTensorConversionOptions &options);

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 5263695d02145..83d6fcc2e1f74 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -50,6 +50,22 @@ static Operation *getOwnerOfValue(Value value) {
   return value.cast<BlockArgument>().getOwner()->getParentOp();
 }
 
+bool bufferization::allocationDoesNotEscape(OpResult opResult) {
+#ifndef NDEBUG
+  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
+  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
+         "expected op that bufferizes to an allocation");
+#endif // NDEBUG
+
+  Operation *op = opResult.getDefiningOp();
+  // If there is no 'escape' attribute, we cannot say for sure.
+  if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
+    return false;
+  auto attr =
+      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
+  return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
+}
+
 /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
 /// shaped value is copied. Otherwise, a tensor with undefined contents is
 /// allocated.

diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 04898102433b8..59e62209c2bd1 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -9,20 +9,41 @@
 #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
 
 #include "mlir/Conversion/Passes.h"
-#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Dialect/Tensor/Transforms/Passes.h"
-#include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+/// Return configuration options for One-Shot Bufferize.
+static bufferization::OneShotBufferizationOptions
+getBufferizationOptions(bool analysisOnly) {
+  using namespace bufferization;
+  OneShotBufferizationOptions options;
+  options.bufferizeFunctionBoundaries = true;
+  // TODO(springerm): To spot memory leaks more easily, returning dense allocs
+  // should be disallowed.
+  options.allowReturnAllocs = true;
+  options.functionBoundaryTypeConversion =
+      BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+  options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
+                                      const BufferizationOptions &options) {
+    return getMemRefTypeWithStaticIdentityLayout(
+        value.getType().cast<TensorType>(), memorySpace);
+  };
+  if (analysisOnly) {
+    options.testAnalysisOnly = true;
+    options.printConflicts = true;
+  }
+  return options;
+}
+
 //===----------------------------------------------------------------------===//
 // Pipeline implementation.
 //===----------------------------------------------------------------------===//
@@ -31,20 +52,28 @@ void mlir::sparse_tensor::buildSparseCompiler(
     OpPassManager &pm, const SparseCompilerOptions &options) {
   // TODO(wrengr): ensure the original `pm` is for ModuleOp
   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
-  pm.addPass(createLinalgElementwiseOpFusionPass());
+  // TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
+  // fit well with bufferization because it replaces unused "out" operands of
+  // LinalgOps with InitTensorOps. This would result in additional buffer
+  // allocations during bufferization.
+  // pm.addPass(createLinalgElementwiseOpFusionPass());
+  pm.addPass(
+      bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
+          /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
+  if (options.testBufferizationAnalysisOnly)
+    return;
   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
   pm.addPass(createSparseTensorConversionPass(
       options.sparseTensorConversionOptions()));
-  pm.addNestedPass<func::FuncOp>(createLinalgBufferizePass());
-  pm.addNestedPass<func::FuncOp>(vector::createVectorBufferizePass());
+  pm.addPass(createDenseBufferizationPass(
+      getBufferizationOptions(/*analysisOnly=*/false)));
+  pm.addNestedPass<func::FuncOp>(
+      mlir::bufferization::createFinalizingBufferizePass());
+  // TODO(springerm): Add sparse support to the BufferDeallocation pass and add
+  // it to this pipeline.
   pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
   pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
   pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
-  pm.addPass(func::createFuncBufferizePass());
-  pm.addPass(arith::createConstantBufferizePass());
-  pm.addNestedPass<func::FuncOp>(createTensorBufferizePass());
-  pm.addNestedPass<func::FuncOp>(
-      mlir::bufferization::createFinalizingBufferizePass());
   pm.addPass(createLowerAffinePass());
   pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
   pm.addPass(createMemRefToLLVMPass());

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 50c16594213a9..76bd31691dfd9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRSparseTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   CodegenUtils.cpp
+  DenseBufferizationPass.cpp
   Sparsification.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/DenseBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/DenseBufferizationPass.cpp
new file mode 100644
index 0000000000000..51e18be6ff21a
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/DenseBufferizationPass.cpp
@@ -0,0 +1,74 @@
+//===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+using namespace mlir;
+using namespace mlir::func;
+
+namespace mlir {
+namespace sparse_tensor {
+
+/// Return `true` if one of the given types is a sparse tensor type.
+static bool containsSparseTensor(TypeRange types) {
+  for (Type t : types)
+    if (getSparseTensorEncoding(t))
+      return true;
+  return false;
+}
+
+/// A pass that bufferizes only dense tensor ops and ignores all sparse tensor
+/// ops. No buffer copies are inserted. All tensor OpOperands must be
+/// inplacable.
+class BufferizeDenseOpsPass
+    : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> {
+public:
+  BufferizeDenseOpsPass(
+      const bufferization::OneShotBufferizationOptions &options)
+      : PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>>(),
+        options(options) {}
+
+  void runOnOperation() override {
+    // Disallow all sparse tensor ops, so that only dense tensor ops are
+    // bufferized.
+    bufferization::OpFilter opFilter;
+    opFilter.allowOperation([&](Operation *op) {
+      if (containsSparseTensor(TypeRange(op->getResults())) ||
+          containsSparseTensor(TypeRange(op->getOperands())))
+        return false;
+      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
+        FunctionType funcType = funcOp.getFunctionType();
+        if (containsSparseTensor(funcType.getInputs()) ||
+            containsSparseTensor(funcType.getResults()))
+          return false;
+      }
+      return true;
+    });
+
+    if (failed(bufferization::bufferizeOp(getOperation(), options,
+                                          /*copyBeforeWrite=*/false,
+                                          &opFilter)))
+      signalPassFailure();
+  }
+
+private:
+  bufferization::OneShotBufferizationOptions options;
+};
+} // namespace sparse_tensor
+} // namespace mlir
+
+std::unique_ptr<Pass> mlir::createDenseBufferizationPass(
+    const bufferization::OneShotBufferizationOptions &options) {
+  return std::make_unique<mlir::sparse_tensor::BufferizeDenseOpsPass>(options);
+}
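
A sketch of the op filter's effect (value names hypothetical): any op with a
sparse tensor among its operands or results is left untouched, while purely
dense tensor ops are bufferized as usual.

    // Skipped by this pass: the operand is a sparse tensor.
    %v = sparse_tensor.values %sp : tensor<32xf32, #SV> to memref<?xf32>
    // Bufferized by this pass: a dense tensor op, lowered to a memref.load.
    %e = tensor.extract %d[%i] : tensor<32xf32>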

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 282af7aed2df5..1fad15dfb00bb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -16,6 +16,7 @@
 
 #include "CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -30,6 +31,8 @@
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
+using mlir::bufferization::BufferizableOpInterface;
+using mlir::bufferization::BufferizationDialect;
 
 namespace {
 
@@ -320,8 +323,8 @@ static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
   return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
 }
 
-/// Generates code to allocate a tensor of the given type, and zero
-/// initialize it.  If the tensor type has any dynamic sizes, then the
+/// Generates code to allocate a buffer of the given type, and zero
+/// initialize it.  If the buffer type has any dynamic sizes, then the
 /// `sizes` parameter should be as filled by sizesFromPtr(); that way
 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
 static Value allocDenseTensor(OpBuilder &builder, Location loc,
@@ -340,6 +343,11 @@ static Value allocDenseTensor(OpBuilder &builder, Location loc,
   return mem;
 }
 
+/// Generates code to deallocate a dense buffer.
+static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
+  builder.create<memref::DeallocOp>(loc, buffer);
+}
+
 /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
 /// the tensor created by allocDenseTensor().  The `rank` is the rank
 /// of the `tensor` and the length of `ind`.
@@ -618,6 +626,9 @@ class SparseTensorAllocConverter
   LogicalResult
   matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getCopy())
+      return rewriter.notifyMatchFailure(op,
+                                         "sparse tensor copy not implemented");
     RankedTensorType resType = op.getType();
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
@@ -743,6 +754,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       Value iter = genNewCall(rewriter, op, params);
       Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+      Block *insertionBlock = rewriter.getInsertionBlock();
+      // TODO: Dense buffers should be allocated/deallocated via the callback
+      // in BufferizationOptions.
       Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
       SmallVector<Value> noArgs;
       SmallVector<Type> noTypes;
@@ -758,6 +772,11 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       rewriter.setInsertionPointAfter(whileOp);
       genDelCOOCall(rewriter, op, elemTp, iter);
       rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
+      // Deallocate the buffer.
+      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
+        rewriter.setInsertionPoint(insertionBlock->getTerminator());
+        deallocDenseTensor(rewriter, loc, dst);
+      }
       return success();
     }
     if (!encDst && !encSrc) {
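
A sketch of the resulting lowering of a sparse-to-dense convert (value names
hypothetical; runtime calls and the loop body elided): the destination buffer
is allocated up front and, if the converted tensor does not escape its block,
deallocated again right before the block terminator.

    %dst = memref.alloc(%d0, %d1) : memref<?x?xf32>  // allocDenseTensor
    scf.while : () -> () { ... }                     // insert all elements
    %res = bufferization.to_tensor %dst : memref<?x?xf32>
    ...
    memref.dealloc %dst : memref<?x?xf32>            // deallocDenseTensor,
                                                     // only if no escape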

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 1f157eab3c57d..9d94e5b72e933 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -127,11 +127,17 @@ struct SparseTensorConversionPass
         });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
+    target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
+                      complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                       linalg::YieldOp, tensor::ExtractOp>();
     target.addLegalDialect<
         arith::ArithmeticDialect, bufferization::BufferizationDialect,
         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
+    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
+        [&](bufferization::AllocTensorOp op) {
+          // Dense tensors are legal, sparse tensors are not.
+          return !static_cast<bool>(op.getType().getEncoding());
+        });
     // Translate strategy flags to strategy options.
     SparseTensorConversionOptions options(
         sparseToSparseConversionStrategy(sparseToSparse));
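
A sketch of the new dynamic legality rule, with #DCSR standing in for any
sparse encoding:

    %d = bufferization.alloc_tensor() : tensor<20x40xf32>         // legal:
         // dense, bufferized later by the dense bufferization pass
    %s = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>  // illegal:
         // sparse, rewritten here into runtime library calls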

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 8170e49043a1b..53182243ab84c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -308,17 +308,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   return true;
 }
 
-/// Returns true if tensor has an in-place annotation.
-static bool isInPlace(Value val) {
-  if (auto arg = val.dyn_cast<BlockArgument>())
-    if (auto funcOp = dyn_cast<func::FuncOp>(arg.getOwner()->getParentOp()))
-      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
-              arg.getArgNumber(),
-              bufferization::BufferizableOpInterface::kInplaceableAttrName))
-        return attr.getValue();
-  return false;
-}
-
 /// Returns true if tensor materializes uninitialized into the computation.
 static bool isMaterializing(Value val) {
   return val.getDefiningOp<linalg::InitTensorOp>() ||
@@ -355,9 +344,8 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
     return true;
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
-  // [Bik96,Ch9], is also admissable without special codegen, provided
-  // the tensor's underlying sparse storage scheme can be modified in-place.
-  if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
+  // [Bik96,Ch9], is also admissable without special codegen.
+  if (merger.isSingleCondition(tensor, exp))
     return true;
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
@@ -486,37 +474,19 @@ static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
   OpOperand *lhs = op.getOutputOperand(0);
   Value tensor = lhs->get();
   bool isInit = op.isInitTensor(lhs);
-  // An output tensor that is in-place can simply materialize from the buffer
-  // of the tensor that appears in the outs() clause. For updates, this has
-  // the advantage that only the nonzero value are involved in the computation,
-  // keeping the operation O(nnz). In all other cases, we are forced to zero
-  // out the buffer to enforce the assumption above, which may negatively
-  // impact running complexity (viz. O(n^2 + nnz) vs. O(nnz) for matrices).
+  // An output tensor can simply materialize from the buffer of the tensor that
+  // appears in the outs() clause. For updates, this has the advantage that only
+  // the nonzero value are involved in the computation, keeping the operation
+  // O(nnz). In all other cases, we are forced to zero out the buffer to enforce
+  // the assumption above, which may negatively impact running complexity
+  // (viz. O(n^2 + nnz) vs. O(nnz) for matrices).
   // TODO: use better analysis to avoid zeroing out the buffer?
-  if (isInPlace(tensor)) {
-    Value init =
-        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
-    if (!isInit) {
-      Value zero = constantZero(builder, loc, denseTp.getElementType());
-      builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
-    }
-    return init;
-  }
-  // By default, a new buffer is allocated which is either set to zero (when
-  // no updates occur or the tensor materializes into this computation) or
-  // initialized to the value of the tensor defined in the outs() clause.
-  // This is always correct (since it enforces all assumptions above) but
-  // may negatively impact running complexity as explained above.
-  Value alloc = builder.create<memref::AllocOp>(loc, denseTp, args);
-  if (!isInit || isMaterializing(tensor)) {
+  Value init = builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+  if (!isInit) {
     Value zero = constantZero(builder, loc, denseTp.getElementType());
-    builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc});
-  } else {
-    Value init =
-        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
-    builder.create<memref::CopyOp>(loc, init, alloc);
+    builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
   }
-  return alloc;
+  return init;
 }
 
 /// Local bufferization of all dense and sparse data structures.
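
Since all output tensors are inplaceable by construction, genOutputBuffer now
reduces to a sketch like the following (value names hypothetical): the buffer
always materializes from the outs() tensor and is only zeroed out when that
tensor does not act as an init tensor.

    %init = bufferization.to_memref %argx : memref<32x16xf32>
    // Emitted only when the output is not an init tensor:
    linalg.fill ins(%zero : f32) outs(%init : memref<32x16xf32>)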

diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index dff9b20f16fbc..33441c9dddf36 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -29,55 +29,11 @@
 
 //
 // Test with an all-dense-annotated "sparse" matrix as input and
-// a non-annotated dense matrix as output that is not inplacable.
-// This results in an explicit allocation to facilitate output.
+// a non-annotated dense matrix as output.
 //
 // CHECK-LABEL:   func @dense1(
 // CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32> {linalg.inplaceable = false}) -> tensor<32x16xf32> {
-// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<32x16xf32>
-// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_9]] : memref<32x16xf32>)
-// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
-// CHECK:               %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xf32>
-// CHECK:               %[[VAL_15:.*]] = arith.addf %[[VAL_14]], %[[VAL_2]] : f32
-// CHECK:               memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_11]]] : memref<32x16xf32>
-// CHECK:             }
-// CHECK:           }
-// CHECK:           %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<32x16xf32>
-// CHECK:           return %[[VAL_16]] : tensor<32x16xf32>
-// CHECK:         }
-func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
-             %argx: tensor<32x16xf32> {linalg.inplaceable = false})
-	     -> tensor<32x16xf32> {
-  %c = arith.constant 1.0 : f32
-  %0 = linalg.generic #trait_2d
-     ins(%arga: tensor<32x16xf32, #DenseMatrix>)
-    outs(%argx: tensor<32x16xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %1 = arith.addf %a, %c : f32
-        linalg.yield %1 : f32
-  } -> tensor<32x16xf32>
-  return %0 : tensor<32x16xf32>
-}
-
-//
-// Test with an all-dense-annotated "sparse" matrix as input and
-// a non-annotated dense matrix as output that is inplacable.
-// This allows updating the dense output in place.
-//
-// CHECK-LABEL:   func @dense2(
-// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
@@ -97,8 +53,8 @@ func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
 // CHECK:           %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_15]] : tensor<32x16xf32>
 // CHECK:         }
-func.func @dense2(%arga: tensor<32x16xf32, #DenseMatrix>,
-             %argx: tensor<32x16xf32> {linalg.inplaceable = true})
+func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
+                  %argx: tensor<32x16xf32>)
 	     -> tensor<32x16xf32> {
   %c = arith.constant 1.0 : f32
   %0 = linalg.generic #trait_2d
@@ -114,11 +70,10 @@ func.func @dense2(%arga: tensor<32x16xf32, #DenseMatrix>,
 //
 // Test with a non-annotated dense matrix as input and
 // an all-dense annotated "sparse" matrix as output.
-// The rewriting would fail if argx was not in-placeable.
 //
-// CHECK-LABEL:   func @dense3(
+// CHECK-LABEL:   func @dense2(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
@@ -138,8 +93,8 @@ func.func @dense2(%arga: tensor<32x16xf32, #DenseMatrix>,
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_15]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
-func.func @dense3(%arga: tensor<32x16xf32>,
-             %argx: tensor<32x16xf32, #DenseMatrix> {linalg.inplaceable = true})
+func.func @dense2(%arga: tensor<32x16xf32>,
+                  %argx: tensor<32x16xf32, #DenseMatrix>)
 	     -> tensor<32x16xf32, #DenseMatrix> {
   %c = arith.constant 1.0 : f32
   %0 = linalg.generic #trait_2d
@@ -156,13 +111,12 @@ func.func @dense3(%arga: tensor<32x16xf32>,
 //
 // Test with a non-annotated dense matrix as input and
 // an all-dense annotated "sparse" matrix as output.
-// The rewriting would fail if argx was not in-placeable.
 // The missing innermost "k" index (due to a reduction) is accounted
 // for by scalarizing the reduction operation for the output tensor.
 //
-// CHECK-LABEL:   func @dense4(
+// CHECK-LABEL:   func @dense3(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
@@ -186,8 +140,8 @@ func.func @dense3(%arga: tensor<32x16xf32>,
 // CHECK:           %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_20]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
-func.func @dense4(%arga: tensor<32x16x8xf32>,
-             %argx: tensor<32x16xf32, #DenseMatrix> {linalg.inplaceable = true})
+func.func @dense3(%arga: tensor<32x16x8xf32>,
+                  %argx: tensor<32x16xf32, #DenseMatrix>)
 	     -> tensor<32x16xf32, #DenseMatrix> {
   %0 = linalg.generic #trait_3d
      ins(%arga: tensor<32x16x8xf32>)

diff --git a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
index 76ddfaea2978f..85c4688171576 100644
--- a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
@@ -40,3 +40,34 @@ func.func @sparse_tensor_convert() -> tensor<20x40xf32> {
   %2 = sparse_tensor.convert %1 : tensor<20x40xf32, #DCSR> to tensor<20x40xf32>
   return %2 : tensor<20x40xf32>
 }
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // A (in)
+    affine_map<(i) -> (i)>   // X (out)
+  ],
+  iterator_types = ["parallel"]
+}
+
+// CHECK-LABEL: func @update_notinplace(
+//  CHECK-SAME:    %[[argb:.*]]: tensor<10xf32> 
+// CHECK-FUNC-LABEL: func @update_notinplace(
+//  CHECK-FUNC-SAME:    %[[argb:.*]]: tensor<10xf32> 
+func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
+  -> (tensor<10xf32>, tensor<10xf32>)
+{
+  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[argb]]) {bufferization.escape = [false]} : tensor<10xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc]]
+  // CHECK-FUNC: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[argb]]) {bufferization.escape = [true]} : tensor<10xf32>
+  // CHECK-FUNC: linalg.generic {{.*}} outs(%[[alloc]]
+  %0 = linalg.generic #trait
+  ins(%arga: tensor<10xf32, #SV>)
+  outs(%argb: tensor<10xf32>) {
+    ^bb(%a: f32, %x : f32):
+      %up = arith.addf %a, %x : f32
+      linalg.yield %up : f32
+  } -> tensor<10xf32>
+  return %0, %argb : tensor<10xf32>, tensor<10xf32>
+}

diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 6a9d26c65b957..f82af4b9205c7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -21,7 +21,7 @@
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_8:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_8]] : memref<32xf32>)
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xf32>
@@ -49,8 +49,9 @@ func.func @add_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>)
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_INITTENSOR:.*]] = linalg.init_tensor [32] : tensor<32xf32>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<32xf32>
+// CHECK:           %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_INITTENSOR]] : memref<32xf32>
 // CHECK:           linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_7]] : memref<32xf32>)
 // CHECK:           scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_8]]] : memref<?xf32>
@@ -80,7 +81,7 @@ func.func @add_d_init(%arga: tensor<32xf32, #DV>, %argb: f32) -> tensor<32xf32>
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_8:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_8]] : memref<32xf32>)
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xf32>
@@ -112,7 +113,7 @@ func.func @mul_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>)
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32xf32>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
@@ -164,7 +165,7 @@ func.func @add_s(%arga: tensor<32xf32, #SV>, %argb: f32, %argx: tensor<32xf32>)
 // CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_8:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_8]] : memref<32xf32>)
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -204,7 +205,7 @@ func.func @repeated_add_s(%arga: tensor<32xf32, #SV>, %argx: tensor<32xf32>) ->
 // CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_9]] : memref<32xf32>)
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -247,7 +248,7 @@ func.func @mul_s(%arga: tensor<32xf32, #SV>, %argb: f32, %argx: tensor<32xf32>)
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf32>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_9]] : memref<32xf32>)
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
@@ -278,7 +279,7 @@ func.func @add_dd(%arga: tensor<32xf32, #DV>, %argb: tensor<32xf32>, %argx: tens
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf32>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_9]] : memref<32xf32>)
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
@@ -312,7 +313,7 @@ func.func @mul_dd(%arga: tensor<32xf32, #DV>, %argb: tensor<32xf32>, %argx: tens
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
@@ -369,7 +370,7 @@ func.func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32, #SV>, %argx: tens
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32xf32>)
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -406,7 +407,7 @@ func.func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32, #SV>, %argx: tens
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
@@ -463,7 +464,7 @@ func.func @add_sd(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32>, %argx: tens
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32xf32>)
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -500,7 +501,7 @@ func.func @mul_sd(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32>, %argx: tens
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -583,7 +584,7 @@ func.func @add_ss(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32, #SV>, %argx:
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -645,7 +646,7 @@ func.func @mul_ss(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32, #SV>, %argx:
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<16xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -738,7 +739,7 @@ func.func @two_way_inv(%arga: tensor<16xf32, #SV>, %argb: tensor<16xf32, #SV>, %
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]]
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<16xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -834,18 +835,16 @@ func.func @two_way_inv_alt(%arga: tensor<16xf32, #SV>,
 // CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-DAG:       %[[VAL_7:.*]] = memref.alloc() : memref<f32>
-// CHECK:           memref.copy %[[VAL_6]], %[[VAL_7]] : memref<f32> to memref<f32>
 // CHECK-DAG:       %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.load %[[VAL_7]][] : memref<f32>
+// CHECK-DAG:       %[[VAL_10:.*]] = memref.load %[[VAL_6]][] : memref<f32>
 // CHECK:           %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
 // CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_12]]] : memref<?xf32>
 // CHECK:             %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
 // CHECK:             scf.yield %[[VAL_15]] : f32
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_7]][] : memref<f32>
-// CHECK:           %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_7]] : memref<f32>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_6]][] : memref<f32>
+// CHECK:           %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<f32>
 // CHECK:           return %[[VAL_17]] : tensor<f32>
 // CHECK:         }
 func.func @sum_reduction(%arga: tensor<?xf32, #SV>, %argx: tensor<f32>) -> tensor<f32> {
@@ -882,9 +881,7 @@ func.func @sum_reduction(%arga: tensor<?xf32, #SV>, %argx: tensor<f32>) -> tenso
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<f32>
-// CHECK:           memref.copy %[[VAL_11]], %[[VAL_12]] : memref<f32> to memref<f32>
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_11]][] : memref<f32>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -946,8 +943,8 @@ func.func @sum_reduction(%arga: tensor<?xf32, #SV>, %argx: tensor<f32>) -> tenso
 // CHECK:             %[[VAL_69:.*]] = arith.addf %[[VAL_66]], %[[VAL_68]] : f32
 // CHECK:             scf.yield %[[VAL_69]] : f32
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_70:.*]], %[[VAL_12]][] : memref<f32>
-// CHECK:           %[[VAL_71:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<f32>
+// CHECK:           memref.store %[[VAL_70:.*]], %[[VAL_11]][] : memref<f32>
+// CHECK:           %[[VAL_71:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<f32>
 // CHECK:           return %[[VAL_71]] : tensor<f32>
 // CHECK:         }
 func.func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
@@ -992,9 +989,7 @@ func.func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]] : memref<f32>
-// CHECK-DAG:       %[[VAL_14:.*]] = memref.alloc() : memref<f32>
-// CHECK:           memref.copy %[[VAL_13]], %[[VAL_14]] : memref<f32> to memref<f32>
-// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_13]][] : memref<f32>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_9]][] : memref<f32>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -1060,8 +1055,8 @@ func.func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
 // CHECK:             %[[VAL_75:.*]] = arith.addf %[[VAL_72]], %[[VAL_74]] : f32
 // CHECK:             scf.yield %[[VAL_75]] : f32
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_76:.*]], %[[VAL_14]][] : memref<f32>
-// CHECK:           %[[VAL_77:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<f32>
+// CHECK:           memref.store %[[VAL_76:.*]], %[[VAL_13]][] : memref<f32>
+// CHECK:           %[[VAL_77:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<f32>
 // CHECK:           return %[[VAL_77]] : tensor<f32>
 // CHECK:         }
 func.func @sum_reduction_inv(%arga: tensor<16xf32, #SV>,
@@ -1112,7 +1107,7 @@ func.func @sum_reduction_inv(%arga: tensor<16xf32, #SV>,
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_3]], %[[VAL_5]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_3]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK-DAG:       %[[VAL_16:.*]] = tensor.dim %[[VAL_4]], %[[VAL_5]] : tensor<?xf64>
-// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloc(%[[VAL_16]]) : memref<?xf64>
+// CHECK-DAG:       %[[VAL_18:.*]] = bufferization.to_memref %[[VAL_4]]
 // CHECK:           linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_18]] : memref<?xf64>)
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
@@ -1289,9 +1284,7 @@ func.func @four_tensors_op(%arga: tensor<?xf64>,
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_3]] : memref<f64>
-// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc() : memref<f64>
-// CHECK:           memref.copy %[[VAL_15]], %[[VAL_16]] : memref<f64> to memref<f64>
-// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_16]][] : memref<f64>
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_15]][] : memref<f64>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -1557,8 +1550,8 @@ func.func @four_tensors_op(%arga: tensor<?xf64>,
 // CHECK:             %[[VAL_250:.*]] = arith.addf %[[VAL_247]], %[[VAL_249]] : f64
 // CHECK:             scf.yield %[[VAL_250]] : f64
 // CHECK:           }
-// CHECK:           memref.store %[[VAL_251:.*]], %[[VAL_16]][] : memref<f64>
-// CHECK:           %[[VAL_252:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<f64>
+// CHECK:           memref.store %[[VAL_251:.*]], %[[VAL_15]][] : memref<f64>
+// CHECK:           %[[VAL_252:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
 // CHECK:           return %[[VAL_252]] : tensor<f64>
 // CHECK:         }
 func.func @red3s(%arga: tensor<?xf64, #SV>,

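Note on the scalar-reduction hunks above: they all follow one pattern. The temporary buffer that used to shadow the output tensor (memref.alloc followed by memref.copy) disappears, and the loads and stores target the bufferization.to_memref result of the output argument directly. A minimal sketch of the resulting IR, assuming a writable out-tensor argument (function and value names are invented for illustration, not taken from the tests):

    func.func @sum_inplace(%vals: memref<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %n = memref.dim %vals, %c0 : memref<?xf32>
      // Load the running total straight from the output buffer ...
      %m = bufferization.to_memref %argx : memref<f32>
      %init = memref.load %m[] : memref<f32>
      %red = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %init) -> (f32) {
        %v = memref.load %vals[%i] : memref<?xf32>
        %s = arith.addf %acc, %v : f32
        scf.yield %s : f32
      }
      // ... and store the result back in place; no alloc/copy pair.
      memref.store %red, %m[] : memref<f32>
      %t = bufferization.to_tensor %m : memref<f32>
      return %t : tensor<f32>
    }
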
diff  --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index d6ed917517a32..896b11080a083 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -26,7 +26,7 @@
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32x16xf32>)
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
@@ -62,7 +62,7 @@ func.func @add_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32x16xf32>)
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
@@ -101,7 +101,7 @@ func.func @mul_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
@@ -162,7 +162,7 @@ func.func @add_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32x16xf32>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
@@ -203,7 +203,7 @@ func.func @mul_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
@@ -269,7 +269,7 @@ func.func @add_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_11]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -313,7 +313,7 @@ func.func @mul_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_15:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_15]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
@@ -404,7 +404,7 @@ func.func @add_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_12]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -451,7 +451,7 @@ func.func @mul_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %arg
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_16]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -615,7 +615,7 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_16]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -710,7 +710,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_15:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_15]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
@@ -814,7 +814,7 @@ func.func @add_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_13]] : memref<32x16xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -868,9 +868,7 @@ func.func @mul_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #T
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<16xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<16xf32>
-// CHECK:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<16xf32> to memref<16xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<16xf32>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK-DAG:         %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
@@ -918,9 +916,7 @@ func.func @matvec(%argA: tensor<16x32xf32, #Tds>, %argb: tensor<32xf32>, %argx:
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-DAG:       %[[VAL_8:.*]] = memref.alloc() : memref<f32>
-// CHECK:           memref.copy %[[VAL_7]], %[[VAL_8]] : memref<f32> to memref<f32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK:           %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) {
 // CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
@@ -967,8 +963,7 @@ func.func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) ->
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf64>
-// CHECK-DAG:       %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf64>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc(%[[VAL_8]], %[[VAL_9]]) : memref<?x?xf64>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf64>
 // CHECK:           linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_11]] : memref<?x?xf64>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
 // CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_12]]] : memref<?xindex>
@@ -1022,11 +1017,7 @@ func.func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
-// CHECK-DAG:       %[[VAL_14:.*]] = tensor.dim %[[VAL_3]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK-DAG:       %[[VAL_15:.*]] = tensor.dim %[[VAL_3]], %[[VAL_5]] : tensor<?x?xf32>
-// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
-// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloc(%[[VAL_14]], %[[VAL_15]]) : memref<?x?xf32>
-// CHECK:           memref.copy %[[VAL_16]], %[[VAL_17]] : memref<?x?xf32> to memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
@@ -1105,9 +1096,7 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 // CHECK-DAG:       %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?xf32>
 // CHECK-DAG:       %[[VAL_21:.*]] = bufferization.to_memref %[[VAL_4]] : memref<f32>
 // CHECK-DAG:       %[[VAL_22:.*]] = tensor.dim %[[VAL_5]], %[[VAL_6]] : tensor<?xf32>
-// CHECK-DAG:       %[[VAL_23:.*]] = bufferization.to_memref %[[VAL_5]] : memref<?xf32>
-// CHECK-DAG:       %[[VAL_24:.*]] = memref.alloc(%[[VAL_22]]) : memref<?xf32>
-// CHECK:           memref.copy %[[VAL_23]], %[[VAL_24]] : memref<?xf32> to memref<?xf32>
+// CHECK-DAG:       %[[VAL_24:.*]] = bufferization.to_memref %[[VAL_5]] : memref<?xf32>
 // CHECK:           %[[VAL_25:.*]] = memref.load %[[VAL_21]][] : memref<f32>
 // CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_27:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>

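The sparse_2d.mlir hunks above repeat the same move for dense outputs that require zero-initialization: linalg.fill now writes into the memref of the output argument itself instead of into a freshly allocated buffer. A hedged sketch of the new shape (names invented for illustration):

    func.func @fill_out(%argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
      %zero = arith.constant 0.0 : f32
      %out = bufferization.to_memref %argx : memref<32x16xf32>
      // Initialize the output in place; no memref.alloc is needed once
      // the analysis has proven %argx writable.
      linalg.fill ins(%zero : f32) outs(%out : memref<32x16xf32>)
      %t = bufferization.to_tensor %out : memref<32x16xf32>
      return %t : tensor<32x16xf32>
    }
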
diff  --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index bdc74e7d40be7..d3f36ae90c98c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -34,7 +34,7 @@
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
@@ -76,7 +76,7 @@ func.func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
@@ -122,7 +122,7 @@ func.func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_15:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_9]] {
 // CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_9]] {
@@ -190,7 +190,7 @@ func.func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
@@ -237,7 +237,7 @@ func.func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_8]] {
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
@@ -308,7 +308,7 @@ func.func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
@@ -358,7 +358,7 @@ func.func @mul_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_9]] {
 // CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
@@ -455,7 +455,7 @@ func.func @add_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
@@ -504,7 +504,7 @@ func.func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
@@ -580,7 +580,7 @@ func.func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
@@ -631,7 +631,7 @@ func.func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
@@ -733,7 +733,7 @@ func.func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
@@ -785,7 +785,7 @@ func.func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_16]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
@@ -890,7 +890,7 @@ func.func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -945,7 +945,7 @@ func.func @mul_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_19:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_19]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
 // CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
@@ -1076,7 +1076,7 @@ func.func @add_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
-// CHECK-DAG:       %[[VAL_15:.*]] = memref.alloc() : memref<32x16x8xf32>
+// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -1140,9 +1140,7 @@ func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_13:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
-// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc(%[[VAL_13]], %[[VAL_14]]) : memref<?x?xf32>
-// CHECK:           memref.copy %[[VAL_15]], %[[VAL_16]] : memref<?x?xf32> to memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
 // CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {
 // CHECK:             scf.for %[[VAL_18:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_6]] {
 // CHECK:               %[[VAL_19:.*]] = arith.muli %[[VAL_10]], %[[VAL_17]] : index
@@ -1203,9 +1201,7 @@ func.func @kernel_3d(%arga: tensor<?x?xf32>,
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<f32>
-// CHECK:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -1263,9 +1259,7 @@ func.func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>)
 // CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?x?xf32>
 // CHECK-DAG:       %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
-// CHECK-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<f32>
-// CHECK:           memref.copy %[[VAL_11]], %[[VAL_12]] : memref<f32> to memref<f32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_12]][] : memref<f32>
 // CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_9]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
 // CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<?xf32>
@@ -1323,7 +1317,7 @@ func.func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20xf32>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<30xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<10x20x30xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]] : memref<10x20x30xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<10x20x30xf32>)
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xf32>

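For the dynamically shaped kernels in sparse_3d.mlir, note that the tensor.dim operations whose only purpose was sizing the temporary allocation vanish along with it; only the dims still used as loop bounds survive. A before/after fragment (illustrative only, not parseable on its own):

    // Before: shape queries feed a same-sized scratch buffer.
    %d0  = tensor.dim %argx, %c0 : tensor<?x?xf64>
    %d1  = tensor.dim %argx, %c1 : tensor<?x?xf64>
    %src = bufferization.to_memref %argx : memref<?x?xf64>
    %buf = memref.alloc(%d0, %d1) : memref<?x?xf64>
    memref.copy %src, %buf : memref<?x?xf64> to memref<?x?xf64>

    // After: the output buffer is used as-is.
    %buf = bufferization.to_memref %argx : memref<?x?xf64>
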
diff  --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index 1520a0e0275db..948095f75daa1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -25,9 +25,7 @@
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<4xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32xf32>
-// CHECK:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf32> to memref<32xf32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<4xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
@@ -78,7 +76,7 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<34xi32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32xi32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xi32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : i32) outs(%[[VAL_11]] : memref<32xi32>)
 // CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -129,9 +127,7 @@ func.func @and_affine_dense1d(%arga: tensor<32xi32, #SpVec>,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<34x19xf64>
-// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf64>
-// CHECK:           memref.copy %[[VAL_12]], %[[VAL_13]] : memref<32x16xf64> to memref<32x16xf64>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] {
 // CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
 // CHECK:             %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_3]] : index

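The sparse_fp_ops.mlir diff below drops the linalg.inplaceable unit attributes that the old partial pipeline consulted. With One-Shot Bufferize the per-argument knob is the bufferization.writable attribute instead; the sketch below is an assumption about its use (attribute name and default semantics should be checked against the bufferization tests, not taken from this commit):

    // Assumed: marking the output argument non-writable forces
    // TensorCopyInsertion to materialize a copy rather than update
    // %argx in place.
    func.func @abs(%arga: tensor<32xf64, #SV>,
                   %argx: tensor<32xf64> {bufferization.writable = false})
        -> tensor<32xf64>
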
diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index 1fe0905aad613..edcaeb5093101 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -33,7 +33,7 @@
 
 // CHECK-LABEL: func @abs(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:         %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -52,7 +52,7 @@
 // CHECK:         return %[[VAL_14]] : tensor<32xf64>
 // CHECK:       }
 func.func @abs(%arga: tensor<32xf64, #SV>,
-          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+               %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait1
      ins(%arga: tensor<32xf64, #SV>)
     outs(%argx: tensor<32xf64>) {
@@ -65,7 +65,7 @@ func.func @abs(%arga: tensor<32xf64, #SV>,
 
 // CHECK-LABEL: func @ceil(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:         %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -84,7 +84,7 @@ func.func @abs(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_14]] : tensor<32xf64>
 // CHECK:       }
 func.func @ceil(%arga: tensor<32xf64, #SV>,
-           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+                %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait1
      ins(%arga: tensor<32xf64, #SV>)
     outs(%argx: tensor<32xf64>) {
@@ -97,7 +97,7 @@ func.func @ceil(%arga: tensor<32xf64, #SV>,
 
 // CHECK-LABEL: func @floor(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:         %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -116,7 +116,7 @@ func.func @ceil(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_14]] : tensor<32xf64>
 // CHECK:       }
 func.func @floor(%arga: tensor<32xf64, #SV>,
-            %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+                 %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait1
      ins(%arga: tensor<32xf64, #SV>)
     outs(%argx: tensor<32xf64>) {
@@ -129,7 +129,7 @@ func.func @floor(%arga: tensor<32xf64, #SV>,
 
 // CHECK-LABEL: func @neg(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:         %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -148,7 +148,7 @@ func.func @floor(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_14]] : tensor<32xf64>
 // CHECK:       }
 func.func @neg(%arga: tensor<32xf64, #SV>,
-          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+               %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait1
      ins(%arga: tensor<32xf64, #SV>)
     outs(%argx: tensor<32xf64>) {
@@ -162,7 +162,7 @@ func.func @neg(%arga: tensor<32xf64, #SV>,
 // CHECK-LABEL: func @add(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>,
-// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_5:.*]] = arith.constant true
@@ -207,8 +207,8 @@ func.func @neg(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_33]] : tensor<32xf64>
 // CHECK:       }
 func.func @add(%arga: tensor<32xf64, #SV>,
-          %argb: tensor<32xf64>,
-          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+               %argb: tensor<32xf64>,
+               %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
     outs(%argx: tensor<32xf64>) {
@@ -222,7 +222,7 @@ func.func @add(%arga: tensor<32xf64, #SV>,
 // CHECK-LABEL: func @sub(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>,
-// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_5:.*]] = arith.constant true
@@ -269,8 +269,8 @@ func.func @add(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_35]] : tensor<32xf64>
 // CHECK:       }
 func.func @sub(%arga: tensor<32xf64, #SV>,
-          %argb: tensor<32xf64>,
-          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+               %argb: tensor<32xf64>,
+               %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
     outs(%argx: tensor<32xf64>) {
@@ -284,7 +284,7 @@ func.func @sub(%arga: tensor<32xf64, #SV>,
 // CHECK-LABEL: func @mul(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>,
-// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:         %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
@@ -305,8 +305,8 @@ func.func @sub(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_17]] : tensor<32xf64>
 // CHECK:       }
 func.func @mul(%arga: tensor<32xf64, #SV>,
-          %argb: tensor<32xf64>,
-          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+               %argb: tensor<32xf64>,
+               %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
     outs(%argx: tensor<32xf64>) {
@@ -319,7 +319,7 @@ func.func @mul(%arga: tensor<32xf64, #SV>,
 
 // CHECK-LABEL: func @divbyc(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 2.000000e+00 : f64
 // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 1 : index
@@ -339,7 +339,7 @@ func.func @mul(%arga: tensor<32xf64, #SV>,
 // CHECK:         return %[[VAL_15]] : tensor<32xf64>
 // CHECK:       }
 func.func @divbyc(%arga: tensor<32xf64, #SV>,
-           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+                  %argx: tensor<32xf64>) -> tensor<32xf64> {
   %c = arith.constant 2.0 : f64
   %0 = linalg.generic #traitc
      ins(%arga: tensor<32xf64, #SV>)

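For context on the recurring signature change: with One-Shot Bufferize, {linalg.inplaceable = ...} disappears from these signatures, and an out argument that must not be overwritten is marked "not writable" instead, leaving TensorCopyInsertion to materialize the copy. A minimal sketch, assuming the bufferization.writable spelling of that attribute (the @scale kernel itself is hypothetical, not part of this diff):

  #id = affine_map<(d0) -> (d0)>
  func.func @scale(%argx: tensor<32xf64> {bufferization.writable = false})
      -> tensor<32xf64> {
    %c = arith.constant 2.0 : f64
    // One-Shot Analysis sees a write to a non-writable argument, so
    // TensorCopyInsertion would insert an allocation and copy before this op.
    %0 = linalg.generic {indexing_maps = [#id], iterator_types = ["parallel"]}
        outs(%argx : tensor<32xf64>) {
      ^bb0(%x: f64):
        %1 = arith.mulf %x, %c : f64
        linalg.yield %1 : f64
    } -> tensor<32xf64>
    return %0 : tensor<32xf64>
  }
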
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index dd615706016fa..024311f76b676 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -25,7 +25,7 @@
 // CHECK-LABEL:   func @add(
 // CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
-// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
@@ -70,8 +70,8 @@
 // CHECK:           return %[[VAL_33]] : tensor<32xi64>
 // CHECK:         }
 func.func @add(%arga: tensor<32xi64, #SV>,
-          %argb: tensor<32xi64>,
-          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+               %argb: tensor<32xi64>,
+               %argx: tensor<32xi64>) -> tensor<32xi64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
     outs(%argx: tensor<32xi64>) {
@@ -85,7 +85,7 @@ func.func @add(%arga: tensor<32xi64, #SV>,
 // CHECK-LABEL:   func @sub(
 // CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
-// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
@@ -133,8 +133,8 @@ func.func @add(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_36]] : tensor<32xi64>
 // CHECK:         }
 func.func @sub(%arga: tensor<32xi64, #SV>,
-          %argb: tensor<32xi64>,
-          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+               %argb: tensor<32xi64>,
+               %argx: tensor<32xi64>) -> tensor<32xi64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
     outs(%argx: tensor<32xi64>) {
@@ -148,7 +148,7 @@ func.func @sub(%arga: tensor<32xi64, #SV>,
 // CHECK-LABEL:   func @mul(
 // CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
-// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
@@ -169,8 +169,8 @@ func.func @sub(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_17]] : tensor<32xi64>
 // CHECK:         }
 func.func @mul(%arga: tensor<32xi64, #SV>,
-          %argb: tensor<32xi64>,
-          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+               %argb: tensor<32xi64>,
+               %argx: tensor<32xi64>) -> tensor<32xi64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
     outs(%argx: tensor<32xi64>) {
@@ -183,7 +183,7 @@ func.func @mul(%arga: tensor<32xi64, #SV>,
 
 // CHECK-LABEL:   func @divsbyc(
 // CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
@@ -203,7 +203,7 @@ func.func @mul(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_15]] : tensor<32xi64>
 // CHECK:         }
 func.func @divsbyc(%arga: tensor<32xi64, #SV>,
-              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+                   %argx: tensor<32xi64>) -> tensor<32xi64> {
   %c = arith.constant 2 : i64
   %0 = linalg.generic #traitc
      ins(%arga: tensor<32xi64, #SV>)
@@ -217,7 +217,7 @@ func.func @divsbyc(%arga: tensor<32xi64, #SV>,
 
 // CHECK-LABEL:   func @divubyc(
 // CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
@@ -237,7 +237,7 @@ func.func @divsbyc(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_15]] : tensor<32xi64>
 // CHECK:         }
 func.func @divubyc(%arga: tensor<32xi64, #SV>,
-              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+                   %argx: tensor<32xi64>) -> tensor<32xi64> {
   %c = arith.constant 2 : i64
   %0 = linalg.generic #traitc
      ins(%arga: tensor<32xi64, #SV>)
@@ -252,7 +252,7 @@ func.func @divubyc(%arga: tensor<32xi64, #SV>,
 // CHECK-LABEL:   func @and(
 // CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
-// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -273,8 +273,8 @@ func.func @divubyc(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_17]] : tensor<32xi64>
 // CHECK:         }
 func.func @and(%arga: tensor<32xi64, #SV>,
-          %argb: tensor<32xi64>,
-          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+               %argb: tensor<32xi64>,
+               %argx: tensor<32xi64>) -> tensor<32xi64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
     outs(%argx: tensor<32xi64>) {
@@ -288,7 +288,7 @@ func.func @and(%arga: tensor<32xi64, #SV>,
 // CHECK-LABEL:   func @or(
 // CHECK-SAME:             %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:             %[[VAL_1:.*]]: tensor<32xi64>,
-// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
@@ -333,8 +333,8 @@ func.func @and(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_33]] : tensor<32xi64>
 // CHECK:         }
 func.func @or(%arga: tensor<32xi64, #SV>,
-         %argb: tensor<32xi64>,
-         %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+              %argb: tensor<32xi64>,
+              %argx: tensor<32xi64>) -> tensor<32xi64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
     outs(%argx: tensor<32xi64>) {
@@ -348,7 +348,7 @@ func.func @or(%arga: tensor<32xi64, #SV>,
 // CHECK-LABEL:   func @xor(
 // CHECK-SAME:             %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:             %[[VAL_1:.*]]: tensor<32xi64>,
-// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_5:.*]] = arith.constant true
@@ -393,8 +393,8 @@ func.func @or(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_33]] : tensor<32xi64>
 // CHECK:         }
 func.func @xor(%arga: tensor<32xi64, #SV>,
-          %argb: tensor<32xi64>,
-          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+               %argb: tensor<32xi64>,
+               %argx: tensor<32xi64>) -> tensor<32xi64> {
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
     outs(%argx: tensor<32xi64>) {
@@ -407,7 +407,7 @@ func.func @xor(%arga: tensor<32xi64, #SV>,
 
 // CHECK-LABEL:   func @ashrbyc(
 // CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
@@ -427,7 +427,7 @@ func.func @xor(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_15]] : tensor<32xi64>
 // CHECK:         }
 func.func @ashrbyc(%arga: tensor<32xi64, #SV>,
-              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+                   %argx: tensor<32xi64>) -> tensor<32xi64> {
   %c = arith.constant 2 : i64
   %0 = linalg.generic #traitc
      ins(%arga: tensor<32xi64, #SV>)
@@ -441,7 +441,7 @@ func.func @ashrbyc(%arga: tensor<32xi64, #SV>,
 
 // CHECK-LABEL:   func @lsrbyc(
 // CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
@@ -461,7 +461,7 @@ func.func @ashrbyc(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_15]] : tensor<32xi64>
 // CHECK:         }
 func.func @lsrbyc(%arga: tensor<32xi64, #SV>,
-             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+                  %argx: tensor<32xi64>) -> tensor<32xi64> {
   %c = arith.constant 2 : i64
   %0 = linalg.generic #traitc
      ins(%arga: tensor<32xi64, #SV>)
@@ -475,7 +475,7 @@ func.func @lsrbyc(%arga: tensor<32xi64, #SV>,
 
 // CHECK-LABEL:   func @lslbyc(
 // CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xi64>) -> tensor<32xi64> {
 // CHECK-DAG:           %[[VAL_2:.*]] = arith.constant 2 : i64
 // CHECK-DAG:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:           %[[VAL_4:.*]] = arith.constant 1 : index
@@ -495,7 +495,7 @@ func.func @lsrbyc(%arga: tensor<32xi64, #SV>,
 // CHECK:           return %[[VAL_15]] : tensor<32xi64>
 // CHECK:         }
 func.func @lslbyc(%arga: tensor<32xi64, #SV>,
-             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+                  %argx: tensor<32xi64>) -> tensor<32xi64> {
   %c = arith.constant 2 : i64
   %0 = linalg.generic #traitc
      ins(%arga: tensor<32xi64, #SV>)

diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index af99f00a4807b..289bdcb87bf9d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -20,9 +20,7 @@
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30xf32>
-// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<10x30xf32>
-// CHECK:           memref.copy %[[VAL_12]], %[[VAL_13]] : memref<10x30xf32> to memref<10x30xf32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_4]] {
@@ -166,9 +164,7 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
-// CHECK-DAG:       %[[VAL_13:.*]] = memref.alloc() : memref<6x6xi32>
-// CHECK:           memref.copy %[[VAL_12]], %[[VAL_13]] : memref<6x6xi32> to memref<6x6xi32>
+// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_4]] {
@@ -218,9 +214,7 @@ func.func @conv2d(%input:  tensor<8x8xi32>,
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_5]] : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_5]] : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
-// CHECK-DAG:       %[[VAL_14:.*]] = memref.alloc() : memref<5x6xi64>
-// CHECK:           memref.copy %[[VAL_13]], %[[VAL_14]] : memref<5x6xi64> to memref<5x6xi64>
+// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_5]] {
@@ -266,9 +260,7 @@ func.func @quantized_matmul(%input1: tensor<5x3xi8>,
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<f32>
-// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2:.*]] : memref<f32>
-// CHECK-DAG:       memref.copy %[[VAL_12]], %[[VAL_11]] : memref<f32> to memref<f32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2:.*]] : memref<f32>
 // CHECK-DAG:       %[[VAL_13:.*]] = memref.load %[[VAL_11]][] : memref<f32>
 // CHECK-DAG:       %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-DAG:       %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>

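The alloc/copy deletions in the kernels above all follow one pattern: by the time Sparsification runs, every out tensor is already inplaceable, so the lowered code updates the bufferized out argument directly. A condensed hand-written sketch of the post-change shape (the @add_one kernel is hypothetical):

  func.func @add_one(%argx: tensor<4xf64>) -> tensor<4xf64> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %f1 = arith.constant 1.0 : f64
    // The out tensor doubles as the computation buffer; no memref.alloc
    // and no memref.copy are emitted anymore.
    %m = bufferization.to_memref %argx : memref<4xf64>
    scf.for %i = %c0 to %c4 step %c1 {
      %x = memref.load %m[%i] : memref<4xf64>
      %y = arith.addf %x, %f1 : f64
      memref.store %y, %m[%i] : memref<4xf64>
    }
    %t = bufferization.to_tensor %m : memref<4xf64>
    return %t : tensor<4xf64>
  }
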
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index f20fcbe3c19db..502b89c610898 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -31,9 +31,7 @@
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
-// CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
-// CHECK-HIR-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
-// CHECK-HIR:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf64> to memref<32xf64>
+// CHECK-HIR-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-HIR-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK-HIR-DAG:         %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
@@ -64,9 +62,7 @@
 // CHECK-MIR-DAG:       %[[VAL_7:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-MIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
-// CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
-// CHECK-MIR-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
-// CHECK-MIR:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf64> to memref<32xf64>
+// CHECK-MIR-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-MIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
 // CHECK-MIR-DAG:         %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_5]] : index
@@ -96,13 +92,11 @@
 // CHECK-LIR-DAG:       %[[VAL_6:.*]] = call @sparsePointers0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR-DAG:       %[[VAL_7:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
-// CHECK-LIR-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<32xf64>
-// CHECK-LIR:           memref.copy %[[VAL_2]], %[[VAL_9]] : memref<32xf64> to memref<32xf64>
 // CHECK-LIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-LIR-DAG:         %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK-LIR-DAG:         %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
 // CHECK-LIR-DAG:         %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK-LIR-DAG:         %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-LIR-DAG:         %[[VAL_16:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-LIR:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
 // CHECK-LIR:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK-LIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
@@ -111,9 +105,9 @@
 // CHECK-LIR:               %[[VAL_24:.*]] = arith.addf %[[VAL_19]], %[[VAL_23]] : f64
 // CHECK-LIR:               scf.yield %[[VAL_24]] : f64
 // CHECK-LIR:             }
-// CHECK-LIR:             memref.store %[[VAL_17]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-LIR:             memref.store %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-LIR:           }
-// CHECK-LIR:           return %[[VAL_9]] : memref<32xf64>
+// CHECK-LIR:           return %[[VAL_2]] : memref<32xf64>
 // CHECK-LIR:         }
 
 func.func @matvec(%arga: tensor<32x64xf64, #CSR>,

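At the CHECK-LIR level the same simplification is visible end to end: the lowered function stores into its out memref argument and returns that argument, rather than a fresh allocation. In minimal form (the @inplace_out function is hypothetical):

  func.func @inplace_out(%y: memref<32xf64>, %v: f64) -> memref<32xf64> {
    %c0 = arith.constant 0 : index
    // Writes land in the caller-provided buffer ...
    memref.store %v, %y[%c0] : memref<32xf64>
    // ... which is returned as is, with no alloc/copy pair in between.
    return %y : memref<32xf64>
  }
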
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
index ed737f7f28e51..3c9d7f6601497 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
@@ -34,9 +34,7 @@
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK-HIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
-// CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
-// CHECK-HIR-DAG:       %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
-// CHECK-HIR:           memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf64> to memref<32xf64>
+// CHECK-HIR-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-HIR:             %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
 // CHECK-HIR:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
@@ -66,9 +64,7 @@
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR-DAG:       %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
-// CHECK-MIR-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
-// CHECK-MIR-DAG:       %[[VAL_12:.*]] = memref.alloc() : memref<32xf64>
-// CHECK-MIR:           memref.copy %[[VAL_11]], %[[VAL_12]] : memref<32xf64> to memref<32xf64>
+// CHECK-MIR-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK-MIR:             %[[VAL_16:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<64xf64>
 // CHECK-MIR:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
@@ -97,8 +93,6 @@
 // CHECK-LIR-DAG:       %[[VAL_7:.*]] = call @sparsePointers0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR-DAG:       %[[VAL_8:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR-DAG:       %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
-// CHECK-LIR-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<32xf64>
-// CHECK-LIR:           memref.copy %[[VAL_2]], %[[VAL_10]] : memref<32xf64> to memref<32xf64>
 // CHECK-LIR:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK-LIR:             %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<64xf64>
 // CHECK-LIR:             %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
@@ -106,14 +100,14 @@
 // CHECK-LIR:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
 // CHECK-LIR:             scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_6]] {
 // CHECK-LIR:               %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK-LIR:               %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK-LIR:               %[[VAL_20:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK-LIR:               %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
 // CHECK-LIR:               %[[VAL_22:.*]] = arith.mulf %[[VAL_21]], %[[VAL_14]] : f64
 // CHECK-LIR:               %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_22]] : f64
-// CHECK-LIR:               memref.store %[[VAL_23]], %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK-LIR:               memref.store %[[VAL_23]], %[[VAL_2]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK-LIR:             }
 // CHECK-LIR:           }
-// CHECK-LIR:           return %[[VAL_10]] : memref<32xf64>
+// CHECK-LIR:           return %[[VAL_2]] : memref<32xf64>
 // CHECK-LIR:         }
 
 func.func @matvec(%arga: tensor<32x64xf64, #CSC>,

diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index b3e0d25a7548d..92237a27e18a7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -23,7 +23,7 @@
 // CHECK-HIR-LABEL:   func @matvec(
 // CHECK-HIR-SAME:      %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-HIR-SAME:      %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-HIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-HIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -54,7 +54,7 @@
 // CHECK-MIR-LABEL:   func @matvec(
 // CHECK-MIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
 // CHECK-MIR-SAME:      %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-MIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-MIR-SAME:      %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
 // CHECK-MIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-MIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-MIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -85,7 +85,7 @@
 // CHECK-LIR-LABEL:   func @matvec(
 // CHECK-LIR-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
 // CHECK-LIR-SAME:      %[[VAL_1:.*]]: memref<64xf64>,
-// CHECK-LIR-SAME:      %[[VAL_2:.*]]: memref<32xf64> {linalg.inplaceable = true}) -> memref<32xf64> {
+// CHECK-LIR-SAME:      %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
 // CHECK-LIR-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
 // CHECK-LIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-LIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -112,7 +112,7 @@
 
 func.func @matvec(%arga: tensor<32x64xf64, #CSR>,
              %argb: tensor<64xf64>,
-	     %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+	           %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait_matvec
       ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>)
       outs(%argx: tensor<32xf64>) {

diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
index 1e1e7f62737e6..ce2766f38b3a6 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
@@ -41,7 +41,7 @@
 // CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
-// CHECK-DAG:       %[[VAL_20:.*]] = memref.alloc() : memref<10x20x30x40x50x60x70x80xf32>
+// CHECK-DAG:       %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_20]] : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           scf.for %[[VAL_21:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_12]] {
 // CHECK:             scf.for %[[VAL_22:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_12]] {

diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 96409e1271a85..f1acd0ae8acf5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -46,7 +46,7 @@
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_18]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
-func.func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
+func.func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x16xf32, #DCSR> {
   %c = arith.constant 2.0 : f32
   %0 = linalg.generic #trait_scale_inpl
     outs(%argx: tensor<32x16xf32, #DCSR>) {
@@ -80,7 +80,7 @@ func.func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplac
 // CHECK:           %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           return %[[VAL_16]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
-func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
+func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x16xf32, #DCSR> {
   %0 = linalg.generic #trait_scale_inpl
     outs(%argx: tensor<32x16xf32, #DCSR>) {
       ^bb(%x: f32):

diff --git a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
index 8f9bd40355a97..6f44689e496e0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
@@ -12,7 +12,7 @@
 
 // CHECK-LABEL:   func.func @allout_inplace(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xi32, #{{.*}}>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32> {linalg.inplaceable = true}) -> tensor<10xf32> {
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
@@ -33,7 +33,7 @@
 // CHECK:           return %[[VAL_15]] : tensor<10xf32>
 // CHECK:         }
 func.func @allout_inplace(%arga: tensor<10xi32, #SV>,
-                          %argb: tensor<10xf32> {linalg.inplaceable = true}) -> tensor<10xf32> {
+                          %argb: tensor<10xf32>) -> tensor<10xf32> {
   %0 = linalg.generic #trait
   ins(%arga: tensor<10xi32, #SV>)
   outs(%argb: tensor<10xf32>) {
@@ -53,7 +53,7 @@ func.func @allout_inplace(%arga: tensor<10xi32, #SV>,
 // CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_1]] : tensor<10xi32, #{{.*}}> to memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_1]] : tensor<10xi32, #{{.*}}> to memref<?xindex>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
-// CHECK:           %[[VAL_8:.*]] = memref.alloc() : memref<10xf32>
+// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_4]] : memref<10xf32>
 // CHECK:           linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
@@ -78,44 +78,9 @@ func.func @allout_materialize(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> {
   return %0 : tensor<10xf32>
 }
 
-// CHECK-LABEL:   func.func @update_notinplace(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32, #{{.*}}>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32> {linalg.inplaceable = false}) -> tensor<10xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10xf32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<10xf32, #{{.*}}> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #{{.*}}> to memref<?xf32>
-// CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<10xf32>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<10xf32>
-// CHECK:           memref.copy %[[VAL_8]], %[[VAL_7]] : memref<10xf32> to memref<10xf32>
-// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
-// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<10xf32>
-// CHECK:             %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
-// CHECK:             memref.store %[[VAL_15]], %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<10xf32>
-// CHECK:           }
-// CHECK:           %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_7]] : memref<10xf32>
-// CHECK:           return %[[VAL_16]] : tensor<10xf32>
-// CHECK:         }
-func.func @update_notinplace(%arga: tensor<10xf32, #SV>,
-                             %argb: tensor<10xf32> {linalg.inplaceable = false}) -> tensor<10xf32> {
-  %0 = linalg.generic #trait
-  ins(%arga: tensor<10xf32, #SV>)
-  outs(%argb: tensor<10xf32>) {
-    ^bb(%a: f32, %x : f32):
-      %up = arith.addf %a, %x : f32
-      linalg.yield %up : f32
-  } -> tensor<10xf32>
-  return %0 : tensor<10xf32>
-}
-
 // CHECK-LABEL:   func.func @update_inplace(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32, #{{.*}}>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32> {linalg.inplaceable = true}) -> tensor<10xf32> {
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10xf32, #{{.*}}> to memref<?xindex>
@@ -135,7 +100,7 @@ func.func @update_notinplace(%arga: tensor<10xf32, #SV>,
 // CHECK:           return %[[VAL_15]] : tensor<10xf32>
 // CHECK:         }
 func.func @update_inplace(%arga: tensor<10xf32, #SV>,
-                          %argb: tensor<10xf32> {linalg.inplaceable = true}) -> tensor<10xf32> {
+                          %argb: tensor<10xf32>) -> tensor<10xf32> {
   %0 = linalg.generic #trait
   ins(%arga: tensor<10xf32, #SV>)
   outs(%argb: tensor<10xf32>) {

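The deleted @update_notinplace case is not lost coverage: under One-Shot Bufferize, "not inplaceable" is resolved before Sparsification, so the copy already exists as a tensor-level op when this lowering runs. A sketch of what TensorCopyInsertion would produce, assuming the bufferization.writable attribute and the bufferization.alloc_tensor ... copy(...) form (the @update kernel is hypothetical):

  #id = affine_map<(d0) -> (d0)>
  func.func @update(%arga: tensor<10xf32>,
                    %argb: tensor<10xf32> {bufferization.writable = false})
      -> tensor<10xf32> {
    // The non-writable out operand is replaced by an allocated copy.
    %copy = bufferization.alloc_tensor() copy(%argb) : tensor<10xf32>
    %0 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel"]}
        ins(%arga : tensor<10xf32>) outs(%copy : tensor<10xf32>) {
      ^bb0(%a: f32, %x: f32):
        %u = arith.addf %a, %x : f32
        linalg.yield %u : f32
    } -> tensor<10xf32>
    return %0 : tensor<10xf32>
  }
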
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index 1c7bccebaedf5..849f147246790 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -24,7 +24,7 @@
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<20x30x10xf32>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30x10xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_9]] : memref<20x30x10xf32>)
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
@@ -63,7 +63,7 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
 // CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?x?xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32>
 // CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_2]] : tensor<?x?x?xf32>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) : memref<?x?x?xf32>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
@@ -81,7 +81,7 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
 // CHECK:           return %[[VAL_19]] : tensor<?x?x?xf32>
 // CHECK:         }
 func.func @sparse_dynamic_dims(%arga: tensor<?x?x?xf32, #X>,
-                          %argx: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+                               %argx: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = linalg.generic #trait
     ins(%arga: tensor<?x?x?xf32, #X>)
     outs(%argx: tensor<?x?x?xf32>) {

diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index ce82c7d951f83..6230cf7492cc1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -26,10 +26,8 @@
 // CHECK-HIR-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-HIR-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<f32>
-// CHECK-HIR:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
-// CHECK-HIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
+// CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
 // CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-HIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
 // CHECK-HIR:               %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
@@ -60,10 +58,8 @@
 // CHECK-MIR-DAG:       %[[VAL_6:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_3]]) : (!llvm.ptr<i8>, index) -> index
 // CHECK-MIR-DAG:       %[[VAL_7:.*]] = call @sparseDimSize(%[[VAL_0]], %[[VAL_2]]) : (!llvm.ptr<i8>, index) -> index
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf32>
-// CHECK-MIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-MIR-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<f32>
-// CHECK-MIR:           memref.copy %[[VAL_9]], %[[VAL_10]] : memref<f32> to memref<f32>
-// CHECK-MIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
+// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
 // CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
 // CHECK-MIR:               %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index

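For 0-d reduction inputs, the old to_memref/alloc/copy/load chain collapses into a single tensor.extract of the seed value, as the updated CHECK-HIR and CHECK-MIR lines show. In minimal form (the @init_val function is hypothetical):

  func.func @init_val(%argx: tensor<f32>) -> f32 {
    // Read the reduction seed straight from the tensor; no memref round-trip.
    %0 = tensor.extract %argx[] : tensor<f32>
    return %0 : f32
  }
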
diff --git a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
index d5ee64701b8e2..e5928f63282df 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
@@ -23,7 +23,7 @@
 // CHECK-SAME:              %[[VAL_1:.*1]]: tensor<f32>,
 // CHECK-SAME:              %[[VAL_2:.*2]]: f32,
 // CHECK-SAME:              %[[VAL_3:.*3]]: f32,
-// CHECK-SAME:              %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+// CHECK-SAME:              %[[VAL_4:.*4]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
@@ -60,10 +60,10 @@
 // CHECK:           return %[[VAL_34]] : tensor<32x16xf32>
 // CHECK:         }
 func.func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
-          %argp: tensor<f32>,
-          %argq: f32,
-          %argr: f32,
-          %argx: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+               %argp: tensor<f32>,
+               %argq: f32,
+               %argr: f32,
+               %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
   %s = arith.addf %argq, %argr : f32
   %c = arith.constant 2.2 : f32
   %0 = linalg.generic #trait

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 1425a7b896213..816a5c75c7c80 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -585,7 +585,7 @@ func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x
 // CHECK-VEC4:       return
 //
 func.func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
-                %argx: tensor<33x64xf64> {linalg.inplaceable = true}) -> tensor<33x64xf64> {
+                %argx: tensor<33x64xf64>) -> tensor<33x64xf64> {
   %0 = linalg.generic #trait_affine
      ins(%arga: tensor<32x64xf64, #SparseMatrix>)
     outs(%argx: tensor<33x64xf64>) {

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index df55b8373e0ee..bd557c3d9e1fe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -18,7 +18,7 @@
 // while-loop are chained before horizontally reducing these back to scalar.
 //
 // CHECK-LABEL:   func @sparse_matrix_sum(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<f64> {linalg.inplaceable = true},
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<f64>,
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:      %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<f64> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64>
@@ -112,9 +112,9 @@
 // CHECK:           %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
 // CHECK:           return %[[VAL_87]] : tensor<f64>
 // CHECK:         }
-func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
-                         %arga: tensor<64x32xf64, #SparseMatrix>,
-                         %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor<f64> {
+func.func @sparse_matrix_sum(%argx: tensor<f64>,
+                             %arga: tensor<64x32xf64, #SparseMatrix>,
+                             %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor<f64> {
   %0 = linalg.generic #trait
      ins(%arga, %argb: tensor<64x32xf64, #SparseMatrix>,
                        tensor<64x32xf64, #SparseMatrix>)

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index 792e741931499..dfea43e73f0e9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -32,7 +32,8 @@
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
-// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<8xi64>
+// CHECK-DAG:       %[[VAL_10a:.*]] = linalg.init_tensor [8] : tensor<8xi64>
+// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_10a]] : memref<8xi64>
 // CHECK-DAG:       linalg.fill ins(%[[VAL_5]] : i64) outs(%[[VAL_10]] : memref<8xi64>)
 // CHECK-DAG:       %[[VAL_11:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK-DAG:       %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
@@ -72,7 +73,8 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<8xi64>
+// CHECK-DAG:       %[[VAL_9a:.*]] = linalg.init_tensor [8] : tensor<8xi64>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_9a]] : memref<8xi64>
 // CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_9]] : memref<8xi64>)
 // CHECK-DAG:       %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref<?xindex>

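The new linalg.init_tensor + bufferization.to_memref pairs above show how a freshly materialized output is now allocated: as a tensor-level op that One-Shot Bufferize lowers to a buffer. A hand-written sketch of the pattern (the @fresh_out function is hypothetical):

  func.func @fresh_out() -> memref<8xi64> {
    %zero = arith.constant 0 : i64
    // Tensor-level allocation instead of a direct memref.alloc.
    %t = linalg.init_tensor [8] : tensor<8xi64>
    %m = bufferization.to_memref %t : memref<8xi64>
    linalg.fill ins(%zero : i64) outs(%m : memref<8xi64>)
    return %m : memref<8xi64>
  }
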
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
index cb6b47a15fd37..de412bc0c0627 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
@@ -353,10 +353,8 @@ module {
     vector.print %1 : vector<16xf64>
     // Dump the dense vector to verify structure is correct.
     %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
-    %2 = bufferization.to_memref %dv : memref<?xf64>
-    %3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
+    %3 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<32xf64>
     vector.print %3 : vector<32xf64>
-    memref.dealloc %2 : memref<?xf64>
     return
   }
 
@@ -369,10 +367,8 @@ module {
     vector.print %1 : vector<24xi32>
     // Dump the dense vector to verify structure is correct.
     %dv = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
-    %2 = bufferization.to_memref %dv : memref<?xi32>
-    %3 = vector.transfer_read %2[%c0], %d0: memref<?xi32>, vector<32xi32>
+    %3 = vector.transfer_read %dv[%c0], %d0: tensor<?xi32>, vector<32xi32>
     vector.print %3 : vector<32xi32>
-    memref.dealloc %2 : memref<?xi32>
     return
   }
 
@@ -380,10 +376,8 @@ module {
     %d0 = arith.constant 0.0 : f64
     %c0 = arith.constant 0 : index
     %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
-    %0 = bufferization.to_memref %dm : memref<?x?xf64>
-    %1 = vector.transfer_read %0[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
+    %1 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<4x8xf64>
     vector.print %1 : vector<4x8xf64>
-    memref.dealloc %0 : memref<?x?xf64>
     return
   }
 
@@ -392,16 +386,13 @@ module {
     %du = arith.constant -1.0 : f64
 
     %c = sparse_tensor.convert %A : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
-    %m = bufferization.to_memref %c : memref<4x4xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x4xf64>, vector<4x4xf64>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v : vector<4x4xf64>
 
     %1 = sparse_tensor.values %A : tensor<4x4xf64, #DCSR> to memref<?xf64>
     %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<16xf64>
     vector.print %2 : vector<16xf64>
 
-    // Release the resources.
-    memref.dealloc %m : memref<4x4xf64>
     return
   }
 
@@ -410,16 +401,13 @@ module {
     %du = arith.constant -1 : i8
 
     %c = sparse_tensor.convert %A : tensor<4x4xi8, #DCSR> to tensor<4x4xi8>
-    %m = bufferization.to_memref %c : memref<4x4xi8>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x4xi8>, vector<4x4xi8>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x4xi8>, vector<4x4xi8>
     vector.print %v : vector<4x4xi8>
 
     %1 = sparse_tensor.values %A : tensor<4x4xi8, #DCSR> to memref<?xi8>
     %2 = vector.transfer_read %1[%c0], %du: memref<?xi8>, vector<16xi8>
     vector.print %2 : vector<16xi8>
-
-    // Release the resources.
-    memref.dealloc %m : memref<4x4xi8>
+
     return
   }
 

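The integration-test cleanups above all rely on the same simplification: vector.transfer_read accepts a tensor source, so the dump helpers no longer need a bufferization.to_memref plus a manual memref.dealloc. In minimal form (the @dump function is hypothetical):

  func.func @dump(%v: tensor<?xf64>) {
    %c0 = arith.constant 0 : index
    %d0 = arith.constant 0.0 : f64
    // Reading from the tensor directly leaves nothing to deallocate here.
    %0 = vector.transfer_read %v[%c0], %d0 : tensor<?xf64>, vector<32xf64>
    vector.print %0 : vector<32xf64>
    return
  }
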
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
index 577bca6a23f1f..b5c56e9a39a7b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
@@ -45,110 +45,110 @@ module {
   // Since all casts are "zero preserving" unary operations, lattice computation
   // and conversion to sparse code is straightforward.
   //
-  func.func @sparse_cast_s32_to_f32(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> {
-    %argx = arith.constant dense<0.0> : tensor<10xf32>
+  func.func @sparse_cast_s32_to_f32(%arga: tensor<10xi32, #SV>,
+                                    %argb: tensor<10xf32>) -> tensor<10xf32> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xi32, #SV>)
-      outs(%argx: tensor<10xf32>) {
+      outs(%argb: tensor<10xf32>) {
         ^bb(%a: i32, %x : f32):
           %cst = arith.sitofp %a : i32 to f32
           linalg.yield %cst : f32
     } -> tensor<10xf32>
     return %0 : tensor<10xf32>
   }
-  func.func @sparse_cast_u32_to_f32(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> {
-    %argx = arith.constant dense<0.0> : tensor<10xf32>
+  func.func @sparse_cast_u32_to_f32(%arga: tensor<10xi32, #SV>,
+                                    %argb: tensor<10xf32>) -> tensor<10xf32> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xi32, #SV>)
-      outs(%argx: tensor<10xf32>) {
+      outs(%argb: tensor<10xf32>) {
         ^bb(%a: i32, %x : f32):
           %cst = arith.uitofp %a : i32 to f32
           linalg.yield %cst : f32
     } -> tensor<10xf32>
     return %0 : tensor<10xf32>
   }
-  func.func @sparse_cast_f32_to_s32(%arga: tensor<10xf32, #SV>) -> tensor<10xi32> {
-    %argx = arith.constant dense<0> : tensor<10xi32>
+  func.func @sparse_cast_f32_to_s32(%arga: tensor<10xf32, #SV>,
+                                    %argb: tensor<10xi32>) -> tensor<10xi32> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xf32, #SV>)
-      outs(%argx: tensor<10xi32>) {
+      outs(%argb: tensor<10xi32>) {
         ^bb(%a: f32, %x : i32):
           %cst = arith.fptosi %a : f32 to i32
           linalg.yield %cst : i32
     } -> tensor<10xi32>
     return %0 : tensor<10xi32>
   }
-  func.func @sparse_cast_f64_to_u32(%arga: tensor<10xf64, #SV>) -> tensor<10xi32> {
-    %argx = arith.constant dense<0> : tensor<10xi32>
+  func.func @sparse_cast_f64_to_u32(%arga: tensor<10xf64, #SV>,
+                                    %argb: tensor<10xi32>) -> tensor<10xi32> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xf64, #SV>)
-      outs(%argx: tensor<10xi32>) {
+      outs(%argb: tensor<10xi32>) {
         ^bb(%a: f64, %x : i32):
           %cst = arith.fptoui %a : f64 to i32
           linalg.yield %cst : i32
     } -> tensor<10xi32>
     return %0 : tensor<10xi32>
   }
-  func.func @sparse_cast_f32_to_f64(%arga: tensor<10xf32, #SV>) -> tensor<10xf64> {
-    %argx = arith.constant dense<0.0> : tensor<10xf64>
+  func.func @sparse_cast_f32_to_f64(%arga: tensor<10xf32, #SV>,
+                                    %argb: tensor<10xf64>) -> tensor<10xf64> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xf32, #SV>)
-      outs(%argx: tensor<10xf64>) {
+      outs(%argb: tensor<10xf64>) {
         ^bb(%a: f32, %x : f64):
           %cst = arith.extf %a : f32 to f64
           linalg.yield %cst : f64
     } -> tensor<10xf64>
     return %0 : tensor<10xf64>
   }
-  func.func @sparse_cast_f64_to_f32(%arga: tensor<10xf64, #SV>) -> tensor<10xf32> {
-    %argx = arith.constant dense<0.0> : tensor<10xf32>
+  func.func @sparse_cast_f64_to_f32(%arga: tensor<10xf64, #SV>,
+                                    %argb: tensor<10xf32>) -> tensor<10xf32> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xf64, #SV>)
-      outs(%argx: tensor<10xf32>) {
+      outs(%argb: tensor<10xf32>) {
         ^bb(%a: f64, %x : f32):
           %cst = arith.truncf %a : f64 to f32
           linalg.yield %cst : f32
     } -> tensor<10xf32>
     return %0 : tensor<10xf32>
   }
-  func.func @sparse_cast_s32_to_u64(%arga: tensor<10xi32, #SV>) -> tensor<10xi64> {
-    %argx = arith.constant dense<0> : tensor<10xi64>
+  func.func @sparse_cast_s32_to_u64(%arga: tensor<10xi32, #SV>,
+                                    %argb: tensor<10xi64>) -> tensor<10xi64> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xi32, #SV>)
-      outs(%argx: tensor<10xi64>) {
+      outs(%argb: tensor<10xi64>) {
         ^bb(%a: i32, %x : i64):
           %cst = arith.extsi %a : i32 to i64
           linalg.yield %cst : i64
     } -> tensor<10xi64>
     return %0 : tensor<10xi64>
   }
-  func.func @sparse_cast_u32_to_s64(%arga: tensor<10xi32, #SV>) -> tensor<10xi64> {
-    %argx = arith.constant dense<0> : tensor<10xi64>
+  func.func @sparse_cast_u32_to_s64(%arga: tensor<10xi32, #SV>,
+                                    %argb: tensor<10xi64>) -> tensor<10xi64> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xi32, #SV>)
-      outs(%argx: tensor<10xi64>) {
+      outs(%argb: tensor<10xi64>) {
         ^bb(%a: i32, %x : i64):
           %cst = arith.extui %a : i32 to i64
           linalg.yield %cst : i64
     } -> tensor<10xi64>
     return %0 : tensor<10xi64>
   }
-  func.func @sparse_cast_i32_to_i8(%arga: tensor<10xi32, #SV>) -> tensor<10xi8> {
-    %argx = arith.constant dense<0> : tensor<10xi8>
+  func.func @sparse_cast_i32_to_i8(%arga: tensor<10xi32, #SV>,
+                                   %argb: tensor<10xi8>) -> tensor<10xi8> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xi32, #SV>)
-      outs(%argx: tensor<10xi8>) {
+      outs(%argb: tensor<10xi8>) {
         ^bb(%a: i32, %x : i8):
           %cst = arith.trunci %a : i32 to i8
           linalg.yield %cst : i8
     } -> tensor<10xi8>
     return %0 : tensor<10xi8>
   }
-  func.func @sparse_cast_f32_as_s32(%arga: tensor<10xf32, #SV>) -> tensor<10xi32> {
-    %argx = arith.constant dense<0> : tensor<10xi32>
+  func.func @sparse_cast_f32_as_s32(%arga: tensor<10xf32, #SV>,
+                                    %argb: tensor<10xi32>) -> tensor<10xi32> {
     %0 = linalg.generic #trait_cast
       ins(%arga: tensor<10xf32, #SV>)
-      outs(%argx: tensor<10xi32>) {
+      outs(%argb: tensor<10xi32>) {
         ^bb(%a: f32, %x : i32):
           %cst = arith.bitcast %a : f32 to i32
           linalg.yield %cst : i32
@@ -168,6 +168,12 @@ module {
     %f = arith.constant 0.0 : f32
     %d = arith.constant 0.0 : f64
 
+    %zero_b = arith.constant dense<0> : tensor<10xi8>
+    %zero_d = arith.constant dense<0.0> : tensor<10xf64>
+    %zero_f = arith.constant dense<0.0> : tensor<10xf32>
+    %zero_i = arith.constant dense<0> : tensor<10xi32>
+    %zero_l = arith.constant dense<0> : tensor<10xi64>
+
     // Initialize dense tensors, convert to sparse vectors.
     %0 = arith.constant dense<[ -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 ]> : tensor<10xi32>
     %1 = sparse_tensor.convert %0 : tensor<10xi32> to tensor<10xi32, #SV>
@@ -182,82 +188,72 @@ module {
     //
     // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 )
     //
-    %c0 = call @sparse_cast_s32_to_f32(%1) : (tensor<10xi32, #SV>) -> tensor<10xf32>
-    %m0 = bufferization.to_memref %c0 : memref<10xf32>
-    %v0 = vector.transfer_read %m0[%z], %f: memref<10xf32>, vector<10xf32>
+    %c0 = call @sparse_cast_s32_to_f32(%1, %zero_f) : (tensor<10xi32, #SV>, tensor<10xf32>) -> tensor<10xf32>
+    %v0 = vector.transfer_read %c0[%z], %f: tensor<10xf32>, vector<10xf32>
     vector.print %v0 : vector<10xf32>
 
     //
     // CHECK: ( 4.29497e+09, 4.29497e+09, 4.29497e+09, 4.29497e+09, 0, 1, 2, 3, 4, 305 )
     //
-    %c1 = call @sparse_cast_u32_to_f32(%1) : (tensor<10xi32, #SV>) -> tensor<10xf32>
-    %m1 = bufferization.to_memref %c1 : memref<10xf32>
-    %v1 = vector.transfer_read %m1[%z], %f: memref<10xf32>, vector<10xf32>
+    %c1 = call @sparse_cast_u32_to_f32(%1, %zero_f) : (tensor<10xi32, #SV>, tensor<10xf32>) -> tensor<10xf32>
+    %v1 = vector.transfer_read %c1[%z], %f: tensor<10xf32>, vector<10xf32>
     vector.print %v1 : vector<10xf32>
 
     //
     // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 )
     //
-    %c2 = call @sparse_cast_f32_to_s32(%3) : (tensor<10xf32, #SV>) -> tensor<10xi32>
-    %m2 = bufferization.to_memref %c2 : memref<10xi32>
-    %v2 = vector.transfer_read %m2[%z], %i: memref<10xi32>, vector<10xi32>
+    %c2 = call @sparse_cast_f32_to_s32(%3, %zero_i) : (tensor<10xf32, #SV>, tensor<10xi32>) -> tensor<10xi32>
+    %v2 = vector.transfer_read %c2[%z], %i: tensor<10xi32>, vector<10xi32>
     vector.print %v2 : vector<10xi32>
 
     //
     // CHECK: ( 4294967295, 4294967294, 4294967293, 4294967292, 0, 1, 2, 3, 4, 305 )
     //
-    %c3 = call @sparse_cast_f64_to_u32(%7) : (tensor<10xf64, #SV>) -> tensor<10xi32>
-    %m3 = bufferization.to_memref %c3 : memref<10xi32>
-    %v3 = vector.transfer_read %m3[%z], %i: memref<10xi32>, vector<10xi32>
+    %c3 = call @sparse_cast_f64_to_u32(%7, %zero_i) : (tensor<10xf64, #SV>, tensor<10xi32>) -> tensor<10xi32>
+    %v3 = vector.transfer_read %c3[%z], %i: tensor<10xi32>, vector<10xi32>
     %vu = vector.bitcast %v3 : vector<10xi32> to vector<10xui32>
     vector.print %vu : vector<10xui32>
 
     //
     // CHECK: ( -4.4, -3.3, -2.2, -1.1, 0, 1.1, 2.2, 3.3, 4.4, 305.5 )
     //
-    %c4 = call @sparse_cast_f32_to_f64(%3) : (tensor<10xf32, #SV>) -> tensor<10xf64>
-    %m4 = bufferization.to_memref %c4 : memref<10xf64>
-    %v4 = vector.transfer_read %m4[%z], %d: memref<10xf64>, vector<10xf64>
+    %c4 = call @sparse_cast_f32_to_f64(%3, %zero_d) : (tensor<10xf32, #SV>, tensor<10xf64>) -> tensor<10xf64>
+    %v4 = vector.transfer_read %c4[%z], %d: tensor<10xf64>, vector<10xf64>
     vector.print %v4 : vector<10xf64>
 
     //
     // CHECK: ( -4.4, -3.3, -2.2, -1.1, 0, 1.1, 2.2, 3.3, 4.4, 305.5 )
     //
-    %c5 = call @sparse_cast_f64_to_f32(%5) : (tensor<10xf64, #SV>) -> tensor<10xf32>
-    %m5 = bufferization.to_memref %c5 : memref<10xf32>
-    %v5 = vector.transfer_read %m5[%z], %f: memref<10xf32>, vector<10xf32>
+    %c5 = call @sparse_cast_f64_to_f32(%5, %zero_f) : (tensor<10xf64, #SV>, tensor<10xf32>) -> tensor<10xf32>
+    %v5 = vector.transfer_read %c5[%z], %f: tensor<10xf32>, vector<10xf32>
     vector.print %v5 : vector<10xf32>
 
     //
     // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 305 )
     //
-    %c6 = call @sparse_cast_s32_to_u64(%1) : (tensor<10xi32, #SV>) -> tensor<10xi64>
-    %m6 = bufferization.to_memref %c6 : memref<10xi64>
-    %v6 = vector.transfer_read %m6[%z], %l: memref<10xi64>, vector<10xi64>
+    %c6 = call @sparse_cast_s32_to_u64(%1, %zero_l) : (tensor<10xi32, #SV>, tensor<10xi64>) -> tensor<10xi64>
+    %v6 = vector.transfer_read %c6[%z], %l: tensor<10xi64>, vector<10xi64>
     vector.print %v6 : vector<10xi64>
 
     //
     // CHECK: ( 4294967292, 4294967293, 4294967294, 4294967295, 0, 1, 2, 3, 4, 305 )
     //
-    %c7 = call @sparse_cast_u32_to_s64(%1) : (tensor<10xi32, #SV>) -> tensor<10xi64>
-    %m7 = bufferization.to_memref %c7 : memref<10xi64>
-    %v7 = vector.transfer_read %m7[%z], %l: memref<10xi64>, vector<10xi64>
+    %c7 = call @sparse_cast_u32_to_s64(%1, %zero_l) : (tensor<10xi32, #SV>, tensor<10xi64>) -> tensor<10xi64>
+    %v7 = vector.transfer_read %c7[%z], %l: tensor<10xi64>, vector<10xi64>
     vector.print %v7 : vector<10xi64>
 
     //
     // CHECK: ( -4, -3, -2, -1, 0, 1, 2, 3, 4, 49 )
     //
-    %c8 = call @sparse_cast_i32_to_i8(%1) : (tensor<10xi32, #SV>) -> tensor<10xi8>
-    %m8 = bufferization.to_memref %c8 : memref<10xi8>
-    %v8 = vector.transfer_read %m8[%z], %b: memref<10xi8>, vector<10xi8>
+    %c8 = call @sparse_cast_i32_to_i8(%1, %zero_b) : (tensor<10xi32, #SV>, tensor<10xi8>) -> tensor<10xi8>
+    %v8 = vector.transfer_read %c8[%z], %b: tensor<10xi8>, vector<10xi8>
     vector.print %v8 : vector<10xi8>
 
     //
     // CHECK: ( -1064514355, -1068289229, -1072902963, -1081291571, 0, 1066192077, 1074580685, 1079194419, 1082969293, 1134084096 )
     //
-    %c9 = call @sparse_cast_f32_as_s32(%3) : (tensor<10xf32, #SV>) -> tensor<10xi32>
-    %m9 = bufferization.to_memref %c9 : memref<10xi32>
-    %v9 = vector.transfer_read %m9[%z], %i: memref<10xi32>, vector<10xi32>
+    %c9 = call @sparse_cast_f32_as_s32(%3, %zero_i) : (tensor<10xf32, #SV>, tensor<10xi32>) -> tensor<10xi32>
+    %v9 = vector.transfer_read %c9[%z], %i: tensor<10xi32>, vector<10xi32>
     vector.print %v9 : vector<10xi32>
 
     // Release the resources.
@@ -265,16 +261,6 @@ module {
     sparse_tensor.release %3 : tensor<10xf32, #SV>
     sparse_tensor.release %5 : tensor<10xf64, #SV>
     sparse_tensor.release %7 : tensor<10xf64, #SV>
-    memref.dealloc %m0 : memref<10xf32>
-    memref.dealloc %m1 : memref<10xf32>
-    memref.dealloc %m2 : memref<10xi32>
-    memref.dealloc %m3 : memref<10xi32>
-    memref.dealloc %m4 : memref<10xf64>
-    memref.dealloc %m5 : memref<10xf32>
-    memref.dealloc %m6 : memref<10xi64>
-    memref.dealloc %m7 : memref<10xi64>
-    memref.dealloc %m8 : memref<10xi8>
-    memref.dealloc %m9 : memref<10xi32>
 
     return
   }
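
Note how the calling convention changes throughout this test: the dense
output is no longer materialized as a constant inside the kernel but is
passed in by the caller, so One-Shot Bufferize can place the allocation
(and its automatic deallocation) at the call site. A minimal sketch of
the rewritten call pattern, using values from the test above:

  // The caller provides the "out" tensor; One-Shot Analysis decides
  // whether it can be written in place or needs a copy.
  %zero_f = arith.constant dense<0.0> : tensor<10xf32>
  %c0 = call @sparse_cast_s32_to_f32(%1, %zero_f)
      : (tensor<10xi32, #SV>, tensor<10xf32>) -> tensor<10xf32>
  // Results are read directly from the tensor, with no detour
  // through bufferization.to_memref.
  %v0 = vector.transfer_read %c0[%z], %f : tensor<10xf32>, vector<10xf32>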

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir
index 2515ac9558f07..6dd9ce96574bc 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir
@@ -51,57 +51,41 @@ module {
   }
   func.func @dumpAndRelease_234(%arg0: tensor<2x3x4xf64>) {
     call @dump(%arg0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<2x3x4xf64>
-    memref.dealloc %1 : memref<2x3x4xf64>
     return
   }
   func.func @dumpAndRelease_p34(%arg0: tensor<?x3x4xf64>) {
     %0 = tensor.cast %arg0 : tensor<?x3x4xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<?x3x4xf64>
-    memref.dealloc %1 : memref<?x3x4xf64>
     return
   }
   func.func @dumpAndRelease_2p4(%arg0: tensor<2x?x4xf64>) {
     %0 = tensor.cast %arg0 : tensor<2x?x4xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<2x?x4xf64>
-    memref.dealloc %1 : memref<2x?x4xf64>
     return
   }
   func.func @dumpAndRelease_23p(%arg0: tensor<2x3x?xf64>) {
     %0 = tensor.cast %arg0 : tensor<2x3x?xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<2x3x?xf64>
-    memref.dealloc %1 : memref<2x3x?xf64>
     return
   }
   func.func @dumpAndRelease_2pp(%arg0: tensor<2x?x?xf64>) {
     %0 = tensor.cast %arg0 : tensor<2x?x?xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<2x?x?xf64>
-    memref.dealloc %1 : memref<2x?x?xf64>
     return
   }
   func.func @dumpAndRelease_p3p(%arg0: tensor<?x3x?xf64>) {
     %0 = tensor.cast %arg0 : tensor<?x3x?xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<?x3x?xf64>
-    memref.dealloc %1 : memref<?x3x?xf64>
     return
   }
   func.func @dumpAndRelease_pp4(%arg0: tensor<?x?x4xf64>) {
     %0 = tensor.cast %arg0 : tensor<?x?x4xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<?x?x4xf64>
-    memref.dealloc %1 : memref<?x?x4xf64>
     return
   }
   func.func @dumpAndRelease_ppp(%arg0: tensor<?x?x?xf64>) {
     %0 = tensor.cast %arg0 : tensor<?x?x?xf64> to tensor<2x3x4xf64>
     call @dump(%0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<?x?x?xf64>
-    memref.dealloc %1 : memref<?x?x?xf64>
     return
   }
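
All dumpAndRelease_* helpers lose their manual cleanup: dense buffers
that do not escape a block are now deallocated automatically by One-Shot
Bufferize, so the explicit round trip below (sketch of the removed code)
is no longer needed:

  %1 = bufferization.to_memref %arg0 : memref<2x3x4xf64>
  memref.dealloc %1 : memref<2x3x4xf64>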
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
index a66a0dcb29aa0..50a798e297f62 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2sparse.mlir
@@ -32,8 +32,6 @@ module {
   }
   func.func @dumpAndRelease_234(%arg0: tensor<2x3x4xf64>) {
     call @dump(%arg0) : (tensor<2x3x4xf64>) -> ()
-    %1 = bufferization.to_memref %arg0 : memref<2x3x4xf64>
-    memref.dealloc %1 : memref<2x3x4xf64>
     return
   }
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
index 7c1b37b6ddf7a..49610a5fe9eb7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
@@ -11,8 +11,8 @@ module {
   // Sparse kernel.
   //
   func.func @sparse_dot(%a: tensor<1024xf32, #SparseVector>,
-                   %b: tensor<1024xf32, #SparseVector>) -> tensor<f32> {
-    %x = linalg.init_tensor [] : tensor<f32>
+                        %b: tensor<1024xf32, #SparseVector>,
+                        %x: tensor<f32>) -> tensor<f32> {
     %dot = linalg.dot ins(%a, %b: tensor<1024xf32, #SparseVector>,
                                   tensor<1024xf32, #SparseVector>)
          outs(%x: tensor<f32>) -> tensor<f32>
@@ -37,16 +37,16 @@ module {
     //
     // CHECK: 53
     //
-    %0 = call @sparse_dot(%s1, %s2) : (tensor<1024xf32, #SparseVector>,
-                                       tensor<1024xf32, #SparseVector>) -> tensor<f32>
+    %x = bufferization.alloc_tensor() : tensor<f32>
+    %0 = call @sparse_dot(%s1, %s2, %x) : (tensor<1024xf32, #SparseVector>,
+                                           tensor<1024xf32, #SparseVector>,
+                                           tensor<f32>) -> tensor<f32>
     %1 = tensor.extract %0[] : tensor<f32>
     vector.print %1 : f32
 
     // Release the resources.
     sparse_tensor.release %s1 : tensor<1024xf32, #SparseVector>
     sparse_tensor.release %s2 : tensor<1024xf32, #SparseVector>
-    %m = bufferization.to_memref %0 : memref<f32>
-    memref.dealloc %m : memref<f32>
 
     return
   }
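
Zero-dimensional outputs follow the same scheme: the caller creates the
destination with bufferization.alloc_tensor and reads the scalar result
back with tensor.extract. Sketch, taken from the test above:

  %x = bufferization.alloc_tensor() : tensor<f32>
  %0 = call @sparse_dot(%s1, %s2, %x)
      : (tensor<1024xf32, #SparseVector>,
         tensor<1024xf32, #SparseVector>, tensor<f32>) -> tensor<f32>
  %1 = tensor.extract %0[] : tensor<f32>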

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
index 6f86ac56a0596..336173ad6c329 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
@@ -63,14 +63,12 @@ module {
     // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
     // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
     //
-    %m = bufferization.to_memref %0 : memref<6x6xi32>
-    %v = vector.transfer_read %m[%c0, %c0], %i0
-      : memref<6x6xi32>, vector<6x6xi32>
+    %v = vector.transfer_read %0[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v : vector<6x6xi32>
 
     // Release the resources.
     sparse_tensor.release %sparse_filter : tensor<3x3xi32, #DCSR>
-    memref.dealloc %m : memref<6x6xi32>
 
     return
   }
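
The simplification above works because vector.transfer_read accepts
tensors as well as memrefs; reading a kernel result therefore needs
neither a bufferization.to_memref nor a matching memref.dealloc:

  %v = vector.transfer_read %0[%c0, %c0], %i0
    : tensor<6x6xi32>, vector<6x6xi32>
  vector.print %v : vector<6x6xi32>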

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir
index 855704d6f60ee..9b2095f67f48c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir
@@ -45,7 +45,7 @@ module {
   // A kernel that flattens a rank 8 tensor into a dense matrix.
   //
   func.func @kernel_flatten(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>,
-                       %argx: tensor<7x3xf64> {linalg.inplaceable = true})
+                            %argx: tensor<7x3xf64>)
 		       -> tensor<7x3xf64> {
     %0 = linalg.generic #trait_flatten
       ins(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>)
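
The linalg.inplaceable argument attribute is dropped from all kernels:
One-Shot Analysis now determines in-placeness itself, with
TensorCopyInsertion adding any required allocations or copies up front.
A test that wants to force a copy can mark an argument as non-writable
instead; a sketch, with the attribute name assumed from the
bufferization dialect rather than taken from this diff:

  // "Not writable" triggers an alloc/copy during TensorCopyInsertion,
  // replacing the old "not inplaceable" marking.
  func.func @kernel_flatten(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>,
                            %argx: tensor<7x3xf64> {bufferization.writable = false})
      -> tensor<7x3xf64>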

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
index f9e0cbc918a76..25ff5fd99d1ab 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
@@ -138,8 +138,7 @@ module {
     // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
     // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
     //
-    %m0 = bufferization.to_memref %0 : memref<4x4xf64>
-    %v0 = vector.transfer_read %m0[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v0 = vector.transfer_read %0[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v0 : vector<4x4xf64>
 
     //
@@ -149,8 +148,7 @@ module {
     // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
     //
     %c1 = sparse_tensor.convert %1 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
-    %m1 = bufferization.to_memref %c1 : memref<4x4xf64>
-    %v1 = vector.transfer_read %m1[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v1 = vector.transfer_read %c1[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v1 : vector<4x4xf64>
 
     //
@@ -160,8 +158,7 @@ module {
     // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
     //
     %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
-    %m2 = bufferization.to_memref %c2 : memref<4x4xf64>
-    %v2 = vector.transfer_read %m2[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v2 = vector.transfer_read %c2[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v2 : vector<4x4xf64>
 
     //
@@ -170,8 +167,7 @@ module {
     // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
     // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
     //
-    %m3 = bufferization.to_memref %3 : memref<4x4xf64>
-    %v3 = vector.transfer_read %m3[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v3 = vector.transfer_read %3[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v3 : vector<4x4xf64>
 
     //
@@ -181,8 +177,7 @@ module {
     // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
     //
     %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
-    %m4 = bufferization.to_memref %c4 : memref<4x4xf64>
-    %v4 = vector.transfer_read %m4[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v4 = vector.transfer_read %c4[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v4 : vector<4x4xf64>
 
     //
@@ -192,31 +187,27 @@ module {
     // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
     //
     %c5 = sparse_tensor.convert %5 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
-    %m5 = bufferization.to_memref %c5 : memref<4x4xf64>
-    %v5 = vector.transfer_read %m5[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v5 = vector.transfer_read %c5[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v5 : vector<4x4xf64>
 
     //
     // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
     //
-    %m6 = bufferization.to_memref %6 : memref<4x4xf64>
-    %v6 = vector.transfer_read %m6[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v6 = vector.transfer_read %6[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v6 : vector<4x4xf64>
 
     //
     // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
     //
     %c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
-    %m7 = bufferization.to_memref %c7 : memref<4x4xf64>
-    %v7 = vector.transfer_read %m7[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v7 = vector.transfer_read %c7[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v7 : vector<4x4xf64>
 
     //
     // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
     //
     %c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
-    %m8 = bufferization.to_memref %c8 : memref<4x4xf64>
-    %v8 = vector.transfer_read %m8[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    %v8 = vector.transfer_read %c8[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v8 : vector<4x4xf64>
 
     //
@@ -247,15 +238,6 @@ module {
     sparse_tensor.release %5 : tensor<4x4xf64, #DCSR>
     sparse_tensor.release %7 : tensor<4x4xf64, #CSR>
     sparse_tensor.release %8 : tensor<4x4xf64, #DCSR>
-    memref.dealloc %m0 : memref<4x4xf64>
-    memref.dealloc %m1 : memref<4x4xf64>
-    memref.dealloc %m2 : memref<4x4xf64>
-    memref.dealloc %m3 : memref<4x4xf64>
-    memref.dealloc %m4 : memref<4x4xf64>
-    memref.dealloc %m5 : memref<4x4xf64>
-    memref.dealloc %m6 : memref<4x4xf64>
-    memref.dealloc %m7 : memref<4x4xf64>
-    memref.dealloc %m8 : memref<4x4xf64>
 
     return
   }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
index 79fd9c6c814a4..dabecaec9f7fa 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
@@ -54,8 +54,7 @@ module {
   }
 
   // Scales a sparse matrix in place.
-  func.func @matrix_scale_inplace(%argx: tensor<?x?xf64, #DCSR>
-                             {linalg.inplaceable = true}) -> tensor<?x?xf64, #DCSR> {
+  func.func @matrix_scale_inplace(%argx: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
     %s = arith.constant 2.0 : f64
     %0 = linalg.generic #trait_scale_inpl
       outs(%argx: tensor<?x?xf64, #DCSR>) {
@@ -68,7 +67,7 @@ module {
 
   // Adds two sparse matrices element-wise into a new sparse matrix.
   func.func @matrix_add(%arga: tensor<?x?xf64, #DCSR>,
-                   %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+                        %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
@@ -86,7 +85,7 @@ module {
 
   // Multiplies two sparse matrices element-wise into a new sparse matrix.
   func.func @matrix_mul(%arga: tensor<?x?xf64, #DCSR>,
-                   %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+                        %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
@@ -107,10 +106,8 @@ module {
     %d0 = arith.constant 0.0 : f64
     %c0 = arith.constant 0 : index
     %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
-    %0 = bufferization.to_memref %dm : memref<?x?xf64>
-    %1 = vector.transfer_read %0[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
+    %1 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<4x8xf64>
     vector.print %1 : vector<4x8xf64>
-    memref.dealloc %0 : memref<?x?xf64>
     return
   }
 
@@ -129,22 +126,24 @@ module {
          [6.0, 5.0, 4.0, 3.0, 2.0, 1.0 ]
     > : tensor<4x8xf64>
     %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
+    // TODO: Use %sm1 when we support sparse tensor copies.
+    %sm1_dup = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
     %sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
 
     // Call sparse matrix kernels.
     %0 = call @matrix_scale(%sm1)
       : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
-    %1 = call @matrix_scale_inplace(%sm1)
+    %1 = call @matrix_scale_inplace(%sm1_dup)
       : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
-    %2 = call @matrix_add(%sm1, %sm2)
+    %2 = call @matrix_add(%1, %sm2)
       : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
-    %3 = call @matrix_mul(%sm1, %sm2)
+    %3 = call @matrix_mul(%1, %sm2)
       : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
 
     //
     // Verify the results.
     //
-    // CHECK:      ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
+    // CHECK:      ( ( 1, 2, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 8, 9, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( ( 6, 0, 0, 0, 0, 0, 0, 5 ), ( 4, 0, 0, 0, 0, 0, 3, 0 ), ( 0, 2, 0, 0, 0, 0, 0, 1 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
@@ -160,6 +159,7 @@ module {
 
     // Release the resources.
     sparse_tensor.release %sm1 : tensor<?x?xf64, #DCSR>
+    sparse_tensor.release %sm1_dup : tensor<?x?xf64, #DCSR>
     sparse_tensor.release %sm2 : tensor<?x?xf64, #DCSR>
     sparse_tensor.release %0 : tensor<?x?xf64, #DCSR>
     sparse_tensor.release %2 : tensor<?x?xf64, #DCSR>
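
Because @matrix_scale_inplace now genuinely scales its operand in place
(there is no copy to hide behind), the test converts %m1 twice and feeds
the scaled result %1 into the add/mul kernels; the first CHECK line
changes accordingly, since %sm1 itself stays untouched. Sketch of the
duplication workaround until sparse tensor copies are supported:

  %sm1     = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
  %sm1_dup = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
  %1 = call @matrix_scale_inplace(%sm1_dup)
    : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>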

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
index 0a3a7cd8f5e92..2bfc58f7cf28f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
@@ -44,8 +44,8 @@ module {
   // into a dense vector x.
   //
   func.func @kernel_matvec(%arga: tensor<?x?xi32, #SparseMatrix>,
-                      %argb: tensor<?xi32>,
-                      %argx: tensor<?xi32> {linalg.inplaceable = true})
+                           %argb: tensor<?xi32>,
+                           %argx: tensor<?xi32>)
 		      -> tensor<?xi32> {
     %0 = linalg.generic #matvec
       ins(%arga, %argb: tensor<?x?xi32, #SparseMatrix>, tensor<?xi32>)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
index c74864c3e71c8..db49380828a82 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
@@ -42,9 +42,9 @@ module {
   // http://tensor-compiler.org/docs/data_analytics/index.html.
   //
   func.func @kernel_mttkrp(%argb: tensor<?x?x?xf64, #SparseTensor>,
-                      %argc: tensor<?x?xf64>,
-                      %argd: tensor<?x?xf64>,
-                      %arga: tensor<?x?xf64> {linalg.inplaceable = true})
+                           %argc: tensor<?x?xf64>,
+                           %argd: tensor<?x?xf64>,
+                           %arga: tensor<?x?xf64>)
 		      -> tensor<?x?xf64> {
     %0 = linalg.generic #mttkrp
       ins(%argb, %argc, %argd:

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
index 16b26131308f5..9d1960329b0e7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
@@ -77,15 +77,13 @@ module {
     vector.print %vv : vector<4xi32>
     %dm = sparse_tensor.convert %0
       : tensor<?x?xi32, #SparseMatrix> to tensor<?x?xi32>
-    %db = bufferization.to_memref %dm : memref<?x?xi32>
-    %vm = vector.transfer_read %db[%c0, %c0], %i0: memref<?x?xi32>, vector<3x3xi32>
+    %vm = vector.transfer_read %dm[%c0, %c0], %i0: tensor<?x?xi32>, vector<3x3xi32>
     vector.print %vm : vector<3x3xi32>
 
     // Release the resources.
     sparse_tensor.release %st1 : tensor<?x?x?xi32, #SparseTensor>
     sparse_tensor.release %st2 : tensor<?x?x?xi32, #SparseTensor>
     sparse_tensor.release %0 : tensor<?x?xi32, #SparseMatrix>
-    memref.dealloc %db : memref<?x?xi32>
     return
   }
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir
index d16ba1dcc604d..d279f134e1cb4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir
@@ -41,7 +41,7 @@ module {
   // a sparse tensor as output, but although the values of the
   // sparse tensor change, its nonzero structure remains the same.
   //
-  func.func @kernel_eltwise_mult(%argx: tensor<?x?xf64, #DCSR> {linalg.inplaceable = true})
+  func.func @kernel_eltwise_mult(%argx: tensor<?x?xf64, #DCSR>)
     -> tensor<?x?xf64, #DCSR> {
     %0 = linalg.generic #eltwise_mult
       outs(%argx: tensor<?x?xf64, #DCSR>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir
index 33daf749247b0..fba1c0cb07643 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir
@@ -65,14 +65,12 @@ module {
     // CHECK-SAME: ( -254, 0, 256, -300, -30, -6 ),
     // CHECK-SAME: ( 1397, 0, -1408, 100, 10, 33 ) )
     //
-    %m = bufferization.to_memref %0 : memref<5x6xi32>
-    %v = vector.transfer_read %m[%c0, %c0], %i0
-      : memref<5x6xi32>, vector<5x6xi32>
+    %v = vector.transfer_read %0[%c0, %c0], %i0
+      : tensor<5x6xi32>, vector<5x6xi32>
     vector.print %v : vector<5x6xi32>
 
     // Release the resources.
     sparse_tensor.release %sparse_input2 : tensor<3x6xi8, #DCSR>
-    memref.dealloc %m : memref<5x6xi32>
 
     return
   }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir
index 83406e6952977..0ae1f17be0bef 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir
@@ -109,14 +109,14 @@ module {
     return %0 : tensor<i32>
   }
 
-  func.func @dump_i32(%arg0 : memref<i32>) {
-    %v = memref.load %arg0[] : memref<i32>
+  func.func @dump_i32(%arg0 : tensor<i32>) {
+    %v = tensor.extract %arg0[] : tensor<i32>
     vector.print %v : i32
     return
   }
 
-  func.func @dump_f32(%arg0 : memref<f32>) {
-    %v = memref.load %arg0[] : memref<f32>
+  func.func @dump_f32(%arg0 : tensor<f32>) {
+    %v = tensor.extract %arg0[] : tensor<f32>
     vector.print %v : f32
     return
   }
@@ -185,33 +185,19 @@ module {
     // CHECK: 15
     // CHECK: 10
     //
-    %m0 = bufferization.to_memref %0 : memref<i32>
-    call @dump_i32(%m0) : (memref<i32>) -> ()
-    %m1 = bufferization.to_memref %1 : memref<f32>
-    call @dump_f32(%m1) : (memref<f32>) -> ()
-    %m2 = bufferization.to_memref %2 : memref<i32>
-    call @dump_i32(%m2) : (memref<i32>) -> ()
-    %m3 = bufferization.to_memref %3 : memref<f32>
-    call @dump_f32(%m3) : (memref<f32>) -> ()
-    %m4 = bufferization.to_memref %4 : memref<i32>
-    call @dump_i32(%m4) : (memref<i32>) -> ()
-    %m5 = bufferization.to_memref %5 : memref<i32>
-    call @dump_i32(%m5) : (memref<i32>) -> ()
-    %m6 = bufferization.to_memref %6 : memref<i32>
-    call @dump_i32(%m6) : (memref<i32>) -> ()
+    call @dump_i32(%0) : (tensor<i32>) -> ()
+    call @dump_f32(%1) : (tensor<f32>) -> ()
+    call @dump_i32(%2) : (tensor<i32>) -> ()
+    call @dump_f32(%3) : (tensor<f32>) -> ()
+    call @dump_i32(%4) : (tensor<i32>) -> ()
+    call @dump_i32(%5) : (tensor<i32>) -> ()
+    call @dump_i32(%6) : (tensor<i32>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sparse_input_i32 : tensor<32xi32, #SV>
     sparse_tensor.release %sparse_input_f32 : tensor<32xf32, #SV>
     sparse_tensor.release %dense_input_i32  : tensor<32xi32, #DV>
     sparse_tensor.release %dense_input_f32  : tensor<32xf32, #DV>
-    memref.dealloc %m0 : memref<i32>
-    memref.dealloc %m1 : memref<f32>
-    memref.dealloc %m2 : memref<i32>
-    memref.dealloc %m3 : memref<f32>
-    memref.dealloc %m4 : memref<i32>
-    memref.dealloc %m5 : memref<i32>
-    memref.dealloc %m6 : memref<i32>
 
     return
   }
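
With the dump helpers taking tensors, a scalar is read with
tensor.extract and the kernel results are passed along directly; the
seven to_memref/dealloc pairs disappear. Sketch, from the test above:

  func.func @dump_i32(%arg0 : tensor<i32>) {
    %v = tensor.extract %arg0[] : tensor<i32>
    vector.print %v : i32
    return
  }
  // Called directly on a kernel result:
  call @dump_i32(%0) : (tensor<i32>) -> ()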

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
index 57d2a931fad5d..be539c3cd7c1b 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -129,6 +129,8 @@ module {
     sparse_tensor.release %collapse3 : tensor<12xf64, #SparseVector>
 
     // Release dense resources.
+    // TODO(springerm): Replace these with a bufferization.release op (operating
+    // on tensors).
     %meme1 = bufferization.to_memref %expand1 : memref<3x4xf64>
     memref.dealloc %meme1 : memref<3x4xf64>
     %memc1 = bufferization.to_memref %collapse1 : memref<12xf64>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
index 25b8019c46995..f43359482f9f7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
@@ -45,9 +45,9 @@ module {
   // A kernel that computes a sampled matrix matrix multiplication.
   //
   func.func @sampled_dense_dense(%args: tensor<?x?xf32, #SparseMatrix>,
-                            %arga: tensor<?x?xf32>,
-                            %argb: tensor<?x?xf32>,
-                            %argx: tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
+                                 %arga: tensor<?x?xf32>,
+                                 %argb: tensor<?x?xf32>,
+                                 %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
     %0 = linalg.generic #trait_sampled_dense_dense
       ins(%args, %arga, %argb: tensor<?x?xf32, #SparseMatrix>, tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%argx: tensor<?x?xf32>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
index 09a17e10ba5a4..79f7f8417ecfd 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
@@ -31,8 +31,7 @@ module {
   //
   // A kernel that scales a sparse matrix A by a factor of 2.0.
   //
-  func.func @sparse_scale(%argx: tensor<8x8xf32, #CSR>
-                     {linalg.inplaceable = true}) -> tensor<8x8xf32, #CSR> {
+  func.func @sparse_scale(%argx: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
     %c = arith.constant 2.0 : f32
     %0 = linalg.generic #trait_scale
       outs(%argx: tensor<8x8xf32, #CSR>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir
index f0b503f7e32b1..f3a7cc9b77738 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir
@@ -41,8 +41,8 @@ module {
   // into a dense matrix X.
   //
   func.func @kernel_spmm(%arga: tensor<?x?xf64, #SparseMatrix>,
-                    %argb: tensor<?x?xf64>,
-                    %argx: tensor<?x?xf64> {linalg.inplaceable = true}) -> tensor<?x?xf64> {
+                         %argb: tensor<?x?xf64>,
+                         %argx: tensor<?x?xf64>) -> tensor<?x?xf64> {
     %0 = linalg.generic #spmm
       ins(%arga, %argb: tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>)
       outs(%argx: tensor<?x?xf64>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
index 9396543ea851c..0e27da8cdeee4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
@@ -39,7 +39,7 @@ module {
   // A kernel that sum-reduces a matrix to a single scalar.
   //
   func.func @kernel_sum_reduce(%arga: tensor<?x?xf64, #SparseMatrix>,
-                          %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
+                               %argx: tensor<f64>) -> tensor<f64> {
     %0 = linalg.generic #trait_sum_reduce
       ins(%arga: tensor<?x?xf64, #SparseMatrix>)
       outs(%argx: tensor<f64>) {
@@ -61,9 +61,7 @@ module {
 
     // Set up memory for a single reduction scalar,
     // initialized to zero.
-    %xdata = memref.alloc() : memref<f64>
-    memref.store %d0, %xdata[] : memref<f64>
-    %x = bufferization.to_tensor %xdata : memref<f64>
+    %x = tensor.from_elements %d0 : tensor<f64>
 
     // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
@@ -77,12 +75,10 @@ module {
     //
     // CHECK: 30.2
     //
-    %m = bufferization.to_memref %0 : memref<f64>
-    %v = memref.load %m[] : memref<f64>
+    %v = tensor.extract %0[] : tensor<f64>
     vector.print %v : f64
 
     // Release the resources.
-    memref.dealloc %xdata : memref<f64>
     sparse_tensor.release %a : tensor<?x?xf64, #SparseMatrix>
 
     return
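
Initializing a zero-dimensional tensor no longer requires a staging
memref; tensor.from_elements builds it directly, and the result is read
back with tensor.extract. Sketch, from the test above:

  // Single reduction scalar, initialized to zero.
  %x = tensor.from_elements %d0 : tensor<f64>
  // ... after the kernel runs:
  %v = tensor.extract %0[] : tensor<f64>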

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_bf16.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_bf16.mlir
index 6e9f8b17eff84..5863b43ffdd9a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_bf16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_bf16.mlir
@@ -24,7 +24,7 @@ module {
   // A kernel that sum-reduces a matrix to a single scalar.
   //
   func.func @kernel_sum_reduce(%arga: tensor<?x?xbf16, #SparseMatrix>,
-                          %argx: tensor<bf16> {linalg.inplaceable = true}) -> tensor<bf16> {
+                               %argx: tensor<bf16>) -> tensor<bf16> {
     %0 = linalg.generic #trait_sum_reduce
       ins(%arga: tensor<?x?xbf16, #SparseMatrix>)
       outs(%argx: tensor<bf16>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
index fe460f5b535b1..793ddefd02751 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
@@ -30,7 +30,7 @@ module {
   // A kernel that sum-reduces a matrix to a single scalar.
   //
   func.func @kernel_sum_reduce(%arga: tensor<?x?xcomplex<f64>, #SparseMatrix>,
-                          %argx: tensor<complex<f64>> {linalg.inplaceable = true}) -> tensor<complex<f64>> {
+                               %argx: tensor<complex<f64>>) -> tensor<complex<f64>> {
     %0 = linalg.generic #trait_sum_reduce
       ins(%arga: tensor<?x?xcomplex<f64>, #SparseMatrix>)
       outs(%argx: tensor<complex<f64>>) {
@@ -53,9 +53,9 @@ module {
 
     // Set up memory for a single reduction scalar,
     // initialized to zero.
-    %xdata = memref.alloc() : memref<complex<f64>>
-    memref.store %d0, %xdata[] : memref<complex<f64>>
-    %x = bufferization.to_tensor %xdata : memref<complex<f64>>
+    // TODO: tensor.from_elements does not support complex.
+    %alloc = bufferization.alloc_tensor() : tensor<complex<f64>>
+    %x = tensor.insert %d0 into %alloc[] : tensor<complex<f64>>
 
     // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
@@ -70,15 +70,13 @@ module {
     // CHECK: 30.2
     // CHECK-NEXT: 22.2
     //
-    %m = bufferization.to_memref %0 : memref<complex<f64>>
-    %v = memref.load %m[] : memref<complex<f64>>
+    %v = tensor.extract %0[] : tensor<complex<f64>>
     %real = complex.re %v : complex<f64>
     %imag = complex.im %v : complex<f64>
     vector.print %real : f64
     vector.print %imag : f64
 
     // Release the resources.
-    memref.dealloc %xdata : memref<complex<f64>>
     sparse_tensor.release %a : tensor<?x?xcomplex<f64>, #SparseMatrix>
 
     return
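
For complex element types, tensor.from_elements is not available yet,
so the initial value is written with tensor.insert into a
bufferization.alloc_tensor instead. Sketch, from the test above:

  %alloc = bufferization.alloc_tensor() : tensor<complex<f64>>
  %x = tensor.insert %d0 into %alloc[] : tensor<complex<f64>>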

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_f16.mlir
index 085f28769072f..a50f2e0b66617 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_f16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_f16.mlir
@@ -24,7 +24,7 @@ module {
   // A kernel that sum-reduces a matrix to a single scalar.
   //
   func.func @kernel_sum_reduce(%arga: tensor<?x?xf16, #SparseMatrix>,
-                          %argx: tensor<f16> {linalg.inplaceable = true}) -> tensor<f16> {
+                               %argx: tensor<f16>) -> tensor<f16> {
     %0 = linalg.generic #trait_sum_reduce
       ins(%arga: tensor<?x?xf16, #SparseMatrix>)
       outs(%argx: tensor<f16>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tanh.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tanh.mlir
index 83a3447797031..dab12bed7a75b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tanh.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tanh.mlir
@@ -16,8 +16,7 @@
 
 module {
   // Performs zero-preserving math on a sparse vector.
-  func.func @sparse_tanh(%vec: tensor<?xf64, #SparseVector>
-                          {linalg.inplaceable = true})
+  func.func @sparse_tanh(%vec: tensor<?xf64, #SparseVector>)
                        -> tensor<?xf64, #SparseVector> {
     %0 = linalg.generic #trait_op
       outs(%vec: tensor<?xf64, #SparseVector>) {
@@ -40,10 +39,8 @@ module {
     // Dump the dense vector to verify structure is correct.
     %dv = sparse_tensor.convert %arg0
         : tensor<?xf64, #SparseVector> to tensor<?xf64>
-    %2 = bufferization.to_memref %dv : memref<?xf64>
-    %3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
+    %3 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<32xf64>
     vector.print %3 : vector<32xf64>
-    memref.dealloc %2 : memref<?xf64>
     return
   }
 
@@ -67,7 +64,7 @@ module {
     // CHECK: {{( -0.761[0-9]*, 0.761[0-9]*, 0.96[0-9]*, 0.99[0-9]*, 0.99[0-9]*, 0.99[0-9]*, 0.99[0-9]*, 0.99[0-9]*, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )}}
     // CHECK-NEXT: {{( -0.761[0-9]*, 0, 0, 0.761[0-9]*, 0, 0, 0, 0, 0, 0, 0, 0.96[0-9]*, 0, 0, 0, 0, 0, 0.99[0-9]*, 0, 0, 0.99[0-9]*, 0.99[0-9]*, 0, 0, 0, 0, 0, 0, 0.99[0-9]*, 0.99[0-9]*, 0, 1 )}}
     //
-    call @dump_vec_f64(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_vec_f64(%0) : (tensor<?xf64, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xf64, #SparseVector>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
index 496da608483fc..13b4737f18811 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
@@ -100,24 +100,19 @@ module {
     // CHECK-NEXT: ( 1.4, 0, 3.4 )
     //
     %x = sparse_tensor.convert %0 : tensor<4x3xf64, #DCSR> to tensor<4x3xf64>
-    %m = bufferization.to_memref %x : memref<4x3xf64>
     scf.for %i = %c0 to %c4 step %c1 {
-      %v1 = vector.transfer_read %m[%i, %c0], %du: memref<4x3xf64>, vector<3xf64>
+      %v1 = vector.transfer_read %x[%i, %c0], %du: tensor<4x3xf64>, vector<3xf64>
       vector.print %v1 : vector<3xf64>
     }
     %y = sparse_tensor.convert %1 : tensor<4x3xf64, #DCSR> to tensor<4x3xf64>
-    %n = bufferization.to_memref %y : memref<4x3xf64>
     scf.for %i = %c0 to %c4 step %c1 {
-      %v2 = vector.transfer_read %n[%i, %c0], %du: memref<4x3xf64>, vector<3xf64>
+      %v2 = vector.transfer_read %y[%i, %c0], %du: tensor<4x3xf64>, vector<3xf64>
       vector.print %v2 : vector<3xf64>
     }
 
     // Release resources.
     sparse_tensor.release %a : tensor<3x4xf64, #DCSR>
     sparse_tensor.release %0 : tensor<4x3xf64, #DCSR>
-    sparse_tensor.release %1 : tensor<4x3xf64, #DCSR>
-    memref.dealloc %m : memref<4x3xf64>
-    memref.dealloc %n : memref<4x3xf64>
 
     return
   }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
index af794da47b828..607cb238f6636 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
@@ -166,10 +166,8 @@ module {
     vector.print %1 : vector<32xf64>
     // Dump the dense vector to verify structure is correct.
     %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
-    %2 = bufferization.to_memref %dv : memref<?xf64>
-    %3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
+    %3 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<32xf64>
     vector.print %3 : vector<32xf64>
-    memref.dealloc %2 : memref<?xf64>
     return
   }
 
@@ -183,10 +181,8 @@ module {
     vector.print %1 : vector<24xi32>
     // Dump the dense vector to verify structure is correct.
     %dv = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
-    %2 = bufferization.to_memref %dv : memref<?xi32>
-    %3 = vector.transfer_read %2[%c0], %d0: memref<?xi32>, vector<32xi32>
+    %3 = vector.transfer_read %dv[%c0], %d0: tensor<?xi32>, vector<32xi32>
     vector.print %3 : vector<32xi32>
-    memref.dealloc %2 : memref<?xi32>
     return
   }
 
@@ -198,10 +194,8 @@ module {
     %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
     vector.print %1 : vector<16xf64>
     %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
-    %2 = bufferization.to_memref %dm : memref<?x?xf64>
-    %3 = vector.transfer_read %2[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
+    %3 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<4x8xf64>
     vector.print %3 : vector<4x8xf64>
-    memref.dealloc %2 : memref<?x?xf64>
     return
   }
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
index e9b4e33f8325a..eef2a0d146c54 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -62,8 +62,7 @@ module {
   }
 
   // Scales a sparse vector in place.
-  func.func @vector_scale_inplace(%argx: tensor<?xf64, #SparseVector>
-                             {linalg.inplaceable = true}) -> tensor<?xf64, #SparseVector> {
+  func.func @vector_scale_inplace(%argx: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %s = arith.constant 2.0 : f64
     %0 = linalg.generic #trait_scale_inpl
       outs(%argx: tensor<?xf64, #SparseVector>) {
@@ -125,7 +124,7 @@ module {
   // Sum-reduces the dot product of two sparse vectors.
   func.func @vector_dotprod(%arga: tensor<?xf64, #SparseVector>,
                        %argb: tensor<?xf64, #SparseVector>,
-                       %argx: tensor<f64> {linalg.inplaceable = true}) -> tensor<f64> {
+                       %argx: tensor<f64>) -> tensor<f64> {
     %0 = linalg.generic #trait_dot
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%argx: tensor<f64>) {
@@ -147,10 +146,8 @@ module {
     vector.print %1 : vector<16xf64>
     // Dump the dense vector to verify structure is correct.
     %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
-    %2 = bufferization.to_memref %dv : memref<?xf64>
-    %3 = vector.transfer_read %2[%c0], %d0: memref<?xf64>, vector<32xf64>
-    vector.print %3 : vector<32xf64>
-    memref.dealloc %2 : memref<?xf64>
+    %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<32xf64>
+    vector.print %2 : vector<32xf64>
     return
   }
 
@@ -169,36 +166,36 @@ module {
          [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0 ]
     > : tensor<32xf64>
     %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+    // TODO: Use %sv1 when copying sparse tensors is supported.
+    %sv1_dup = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
     %sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor<?xf64, #SparseVector>
 
     // Set up memory for a single reduction scalar.
-    %xdata = memref.alloc() : memref<f64>
-    memref.store %d1, %xdata[] : memref<f64>
-    %x = bufferization.to_tensor %xdata : memref<f64>
+    %x = tensor.from_elements %d1 : tensor<f64>
 
     // Call sparse vector kernels.
     %0 = call @vector_scale(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
-    %1 = call @vector_scale_inplace(%sv1)
+    %1 = call @vector_scale_inplace(%sv1_dup)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
-    %2 = call @vector_add(%sv1, %sv2)
+    %2 = call @vector_add(%1, %sv2)
        : (tensor<?xf64, #SparseVector>,
           tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
-    %3 = call @vector_mul(%sv1, %sv2)
+    %3 = call @vector_mul(%1, %sv2)
        : (tensor<?xf64, #SparseVector>,
           tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
-    %4 = call @vector_mul_d(%sv1, %sv2)
+    %4 = call @vector_mul_d(%1, %sv2)
        : (tensor<?xf64, #SparseVector>,
           tensor<?xf64, #SparseVector>) -> tensor<?xf64, #DenseVector>
-    %5 = call @vector_dotprod(%sv1, %sv2, %x)
+    %5 = call @vector_dotprod(%1, %sv2, %x)
        : (tensor<?xf64, #SparseVector>,
           tensor<?xf64, #SparseVector>, tensor<f64>) -> tensor<f64>
 
     //
     // Verify the results.
     //
-    // CHECK:      ( 2, 4, 6, 8, 10, 12, 14, 16, 18, -1, -1, -1, -1, -1, -1, -1 )
-    // CHECK-NEXT: ( 2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 8, 0, 0, 10, 12, 0, 0, 0, 0, 0, 0, 14, 16, 0, 18 )
+    // CHECK:      ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
     // CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 )
     // CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 )
     // CHECK-NEXT: ( 2, 4, 6, 8, 10, 12, 14, 16, 18, -1, -1, -1, -1, -1, -1, -1 )
@@ -212,6 +209,7 @@ module {
     // CHECK-NEXT: ( 0, 0, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 204, 0, 0, 0, 0, 0, 0, 252, 304, 0, 360 )
     // CHECK-NEXT: 1169.1
     //
+
     call @dump(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump(%sv2) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump(%0) : (tensor<?xf64, #SparseVector>) -> ()
@@ -227,12 +225,12 @@ module {
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %sv1_dup : tensor<?xf64, #SparseVector>
     sparse_tensor.release %sv2 : tensor<?xf64, #SparseVector>
     sparse_tensor.release %0 : tensor<?xf64, #SparseVector>
     sparse_tensor.release %2 : tensor<?xf64, #SparseVector>
     sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
     sparse_tensor.release %4 : tensor<?xf64, #DenseVector>
-    memref.dealloc %xdata : memref<f64>
     return
   }
 }


        

