[Mlir-commits] [mlir] 6232a8f - [mlir][sparse][NFC] Switch InitOp to bufferization::AllocTensorOp

Matthias Springer llvmlistbot at llvm.org
Wed Jun 1 15:07:08 PDT 2022


Author: Matthias Springer
Date: 2022-06-02T00:03:52+02:00
New Revision: 6232a8f3d61e5856c17e7b314385e9ea8068cdc1

URL: https://github.com/llvm/llvm-project/commit/6232a8f3d61e5856c17e7b314385e9ea8068cdc1
DIFF: https://github.com/llvm/llvm-project/commit/6232a8f3d61e5856c17e7b314385e9ea8068cdc1.diff

LOG: [mlir][sparse][NFC] Switch InitOp to bufferization::AllocTensorOp

Now that we have an AllocTensorOp (previously InitTensorOp) in the bufferization dialect, the InitOp in the sparse dialect is no longer needed.
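
The change is mechanical: static sizes now come from the result type instead of explicit `arith.constant` operands, and only dynamic dimensions remain as operands. A representative before/after, taken from the updated test files (the `#DCSR` encoding and the dynamic sizes `%d0`, `%d1` are as defined in those tests):

```mlir
// Before: sparse_tensor.init took every dimension size as an operand,
// even when the size was static in the tensor type.
%c4 = arith.constant 4 : index
%C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #DCSR>

// After: bufferization.alloc_tensor reads static sizes from the result type
// and takes operands only for dynamic dimensions.
%C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
%X = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
```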

Differential Revision: https://reviews.llvm.org/D126180

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_expand.mlir
    mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
    mlir/test/Dialect/SparseTensor/sparse_index.mlir
    mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
    mlir/test/Dialect/SparseTensor/sparse_out.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_mult_elt.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sign.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4ae773af33d7c..a86a44e0411f3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -29,13 +29,30 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
   let summary = "buffer allocation in tensor land";
 
   let description = [{
-    `bufferization.alloc_tensor` is an operation that bufferizes to a buffer
-    allocation of a given shape. The shape could be dynamic or static.
-    Reading from the result of an `alloc_tensor` op yields an undefined value.
+    `bufferization.alloc_tensor` materializes an uninitialized tensor with a
+    given shape (dynamic or static). It always bufferizes to a new buffer
+    allocation of the given shape. Reading from the result of an `alloc_tensor`
+    op yields an undefined value.
+
+    `alloc_tensor` is a helper op for bufferization. The operation is provided
+    as an anchor that marks the beginning of a new tensor SSA use-def chain. It
+    can be used to control in-place bufferization decisions during One-Shot
+    Bufferize: The bufferized result of a `bufferization.alloc_tensor` does not
+    alias with any other buffer, so it can be used to resolve read-after-write
+    conflicts that would have been introduced by the in-place bufferization of
+    another op.
+
+    Both dense and sparse tensor types are supported. The result of a
+    `bufferization.alloc_tensor` is a tensor value that can be used like any
+    other tensor value. In practice, it is often used as the "out" operand of
+    another op. E.g.:
 
-    `alloc_tensor` is a helper op for bufferization. It marks the beginning of
-    a new tensor SSA use-def chain and is used to control in-place bufferization
-    decisions during One-Shot Bufferize.
+    ```mlir
+    %c = sparse_tensor.init_tensor [%d1, %d2] : tensor<?x?xf32, #SparseMatrix>
+    %0 = linalg.matmul
+      ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%c: tensor<?x?xf32, #SparseMatrix>) -> tensor<?x?xf32, #SparseMatrix>
+    ```
   }];
 
   let arguments = (ins Variadic<Index>:$dynamicSizes);

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c554330fa204f..4f31031b1fe84 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -49,29 +49,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
-    Arguments<(ins Variadic<Index>:$sizes)>,
-    Results<(outs AnyTensor:$result)> {
-  string summary = "Materializes an unitialized sparse tensor";
-  string description = [{
-    Materializes an uninitialized sparse tensor with given shape (either static
-    or dynamic). The operation is provided as an anchor that materializes a
-    properly typed but uninitialized sparse tensor into the output clause of
-    a subsequent operation that yields a sparse tensor as the result.
-
-    Example:
-
-    ```mlir
-    %c = sparse_tensor.init_tensor [%d1, %d2] : tensor<?x?xf32, #SparseMatrix>
-    %0 = linalg.matmul
-      ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%c: tensor<?x?xf32, #SparseMatrix>) -> tensor<?x?xf32, #SparseMatrix>
-    ```
-  }];
-  let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   [NoSideEffect, SameOperandsAndResultElementType]>,
     Arguments<(ins AnyTensor:$source)>,
@@ -417,7 +394,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
 
       Example of isEqual applied to intersecting elements only.
       ```mlir
-      %C = sparse_tensor.init...
+      %C = bufferization.alloc_tensor...
       %0 = linalg.generic #trait
         ins(%A: tensor<?xf64, #SparseVec>, %B: tensor<?xf64, #SparseVec>)
         outs(%C: tensor<?xi8, #SparseVec>) {
@@ -438,7 +415,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
       Example of A+B in upper triangle, A-B in lower triangle
       (not working yet, but construct will be available soon).
       ```mlir
-      %C = sparse_tensor.init...
+      %C = bufferization.alloc_tensor...
       %1 = linalg.generic #trait
         ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64, #CSR>
         outs(%C: tensor<?x?xf64, #CSR> {
@@ -470,7 +447,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
       is *not* overlapped by B. The element type of B can be different than A
       because we never use its values, only its sparse structure.
       ```mlir
-      %C = sparse_tensor.init...
+      %C = bufferization.alloc_tensor...
       %2 = linalg.generic #trait
         ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xi32, #CSR>
         outs(%C: tensor<?x?xf64, #CSR> {
@@ -517,7 +494,7 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect]>,
 
       Example of A+1, restricted to existing elements:
       ```mlir
-      %C = sparse_tensor.init...
+      %C = bufferization.alloc_tensor...
       %0 = linalg.generic #trait
         ins(%A: tensor<?xf64, #SparseVec>)
         outs(%C: tensor<?xf64, #SparseVec>) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
index 7084baae8d2be..8494012315907 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
@@ -16,6 +16,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -24,6 +25,7 @@
 #include "mlir/Support/LLVM.h"
 
 using namespace mlir;
+using namespace mlir::bufferization;
 using namespace mlir::linalg;
 using namespace mlir::sparse_tensor;
 
@@ -47,7 +49,8 @@ static bool isSparseTensor(OpOperand *op) {
 static bool isEmptyInit(OpOperand *op) {
   Value val = op->get();
   return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
-         val.getDefiningOp<InitTensorOp>() || val.getDefiningOp<InitOp>();
+         val.getDefiningOp<InitTensorOp>() ||
+         val.getDefiningOp<AllocTensorOp>();
 }
 
 // Helper to detect sampling operation.

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0804280b231b9..b860f07528956 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -214,28 +214,6 @@ LogicalResult NewOp::verify() {
   return success();
 }
 
-LogicalResult InitOp::verify() {
-  if (!getSparseTensorEncoding(result().getType()))
-    return emitError("expected a sparse tensor result");
-  RankedTensorType ttp = getType().cast<RankedTensorType>();
-  unsigned rank = ttp.getRank();
-  if (rank != sizes().size())
-    return emitError("unexpected mismatch between tensor rank and sizes: ")
-           << rank << " vs. " << sizes().size();
-  auto shape = ttp.getShape();
-  for (unsigned i = 0; i < rank; i++) {
-    if (shape[i] == ShapedType::kDynamicSize)
-      continue;
-    IntegerAttr constantAttr;
-    if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
-        constantAttr.getInt() != shape[i]) {
-      return emitError("unexpected mismatch with static dimension size ")
-             << shape[i];
-    }
-  }
-  return success();
-}
-
 LogicalResult ConvertOp::verify() {
   if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
     if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index cae91b68be812..0e7a96e03f08c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -459,22 +459,33 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
   }
 };
 
-/// Sparse conversion rule for the init operator.
-class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
+/// Sparse conversion rule for the alloc operator.
+class SparseTensorAllocConverter
+    : public OpConversionPattern<bufferization::AllocTensorOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(InitOp op, OpAdaptor adaptor,
+  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
+    RankedTensorType resType = op.getType();
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
+    // Gather all dimension sizes as SSA values.
+    SmallVector<Value> sizes;
+    unsigned int operandCtr = 0;
+    for (int64_t i = 0; i < resType.getRank(); ++i) {
+      if (resType.isDynamicDim(i)) {
+        sizes.push_back(adaptor.getOperands()[operandCtr++]);
+      } else {
+        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), op.getStaticSize(i)));
+      }
+    }
     // Generate the call to construct empty tensor. The sizes are
-    // explicitly defined by the arguments to the init operator.
+    // explicitly defined by the arguments to the alloc operator.
     SmallVector<Value, 8> params;
     ShapedType stp = resType.cast<ShapedType>();
-    newParams(rewriter, params, op, stp, enc, Action::kEmpty,
-              adaptor.getOperands());
+    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -912,7 +923,7 @@ void mlir::populateSparseTensorConversionPatterns(
     const SparseTensorConversionOptions &options) {
   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
                SparseCastConverter, SparseTensorNewConverter,
-               SparseTensorInitConverter, SparseTensorReleaseConverter,
+               SparseTensorAllocConverter, SparseTensorReleaseConverter,
                SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
                SparseTensorToValuesConverter, SparseTensorLoadConverter,
                SparseTensorLexInsertConverter, SparseTensorExpandConverter,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 8c1f296a39e98..5bbb4a5e1ddac 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -119,6 +119,7 @@ struct SparseTensorConversionPass
     target
         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect, scf::SCFDialect>();
+    target.addIllegalOp<bufferization::AllocTensorOp>();
     // Translate strategy flags to strategy options.
     SparseTensorConversionOptions options(
         sparseToSparseConversionStrategy(sparseToSparse));

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6453cf99e09fe..b6aa7b99e54a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -293,7 +293,7 @@ static bool isInPlace(Value val) {
 /// Returns true if tensor materializes uninitialized into the computation.
 static bool isMaterializing(Value val) {
   return val.getDefiningOp<linalg::InitTensorOp>() ||
-         val.getDefiningOp<InitOp>();
+         val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
 /// Returns true when the tensor expression is admissable for codegen.

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 64d204601686e..5bc04016a15a3 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -135,7 +135,7 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
-  %0 = sparse_tensor.init [%arg0, %arg1] : tensor<?x?xf64, #SparseMatrix>
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
   return %0 : tensor<?x?xf64, #SparseMatrix>
 }
 
@@ -512,8 +512,7 @@ func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
 //   CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
 //       CHECK: return %[[C]] : memref<?xindex>
 func.func @sparse_expansion() -> memref<?xindex> {
-  %c = arith.constant 8 : index
-  %0 = sparse_tensor.init [%c, %c] : tensor<8x8xf64, #SparseMatrix>
+  %0 = bufferization.alloc_tensor() : tensor<8x8xf64, #SparseMatrix>
   %values, %filled, %added, %count = sparse_tensor.expand %0
     : tensor<8x8xf64, #SparseMatrix> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return %added : memref<?xindex>

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index c840975c222d6..68f3cf3b7c5e6 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -16,36 +16,6 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
 
 // -----
 
-func.func @invalid_init_dense(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
-  // expected-error@+1 {{expected a sparse tensor result}}
-  %0 = sparse_tensor.init [%arg0, %arg1] : tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
-}
-
-// -----
-
-#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-
-func.func @invalid_init_rank(%arg0: index) -> tensor<?xf32, #SparseVector> {
-  // expected-error@+1 {{unexpected mismatch between tensor rank and sizes: 1 vs. 2}}
-  %0 = sparse_tensor.init [%arg0, %arg0] : tensor<?xf32, #SparseVector>
-  return %0 : tensor<?xf32, #SparseVector>
-}
-
-// -----
-
-#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
-
-func.func @invalid_init_size() -> tensor<?x10xf32, #SparseMatrix> {
-  %c10 = arith.constant 10 : index
-  %c20 = arith.constant 20 : index
-  // expected-error@+1 {{unexpected mismatch with static dimension size 10}}
-  %0 = sparse_tensor.init [%c10, %c20] : tensor<?x10xf32, #SparseMatrix>
-  return %0 : tensor<?x10xf32, #SparseMatrix>
-}
-
-// -----
-
 func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
   // expected-error@+1 {{expected a sparse tensor to get pointers}}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 6b7e7619e78bc..f347b030045f9 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -13,22 +13,6 @@ func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
 
 // -----
 
-#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
-
-// CHECK-LABEL: func @sparse_init()
-//   CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-//   CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
-//       CHECK: %[[T:.*]] = sparse_tensor.init[%[[C16]], %[[C32]]] : tensor<?x32xf64, #{{.*}}>
-//       CHECK: return %[[T]] : tensor<?x32xf64, #{{.*}}>
-func.func @sparse_init() -> tensor<?x32xf64, #SparseMatrix> {
-  %d1 = arith.constant 16 : index
-  %d2 = arith.constant 32 : index
-  %0 = sparse_tensor.init [%d1, %d2] : tensor<?x32xf64, #SparseMatrix>
-  return %0 : tensor<?x32xf64, #SparseMatrix>
-}
-
-// -----
-
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_release(

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index c6621efc859e4..6b8a0f8009073 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -49,7 +49,7 @@
 func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
   %c0 = arith.constant 0 : index
   %n = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSC>
-  %v = sparse_tensor.init [%n] : tensor<?xf64, #SV>
+  %v = bufferization.alloc_tensor(%n) : tensor<?xf64, #SV>
   %0 = linalg.generic #rowsum
     ins(%arga: tensor<?x?xf64, #DCSC>)
     outs(%v: tensor<?xf64, #SV>) {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index 7f49401e64a80..9b7506c0a9064 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -355,8 +355,7 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK:         %[[VAL_4:.*]] = sparse_tensor.init{{\[}}%[[VAL_3]]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:         %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_1]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:         %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_1]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:         %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
@@ -382,7 +381,7 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 // CHECK:       }
 func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #SV> {
   %c32 = arith.constant 32 : index
-  %xinp = sparse_tensor.init [%c32] : tensor<32xf64, #SV>
+  %xinp = bufferization.alloc_tensor() : tensor<32xf64, #SV>
   %0 = linalg.generic #trait1
      ins(%arga: tensor<32xf64, #SV>)
     outs(%xinp: tensor<32xf64, #SV>) {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index aa6d3c885e3a4..bafeafaefbf0b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -24,7 +24,7 @@
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.init{{\[}}%[[VAL_3]], %[[VAL_4]]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.alloc_tensor(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_5]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
@@ -52,7 +52,7 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
   %c1 = arith.constant 0 : index
   %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
   %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
-  %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #DenseMatrix>
+  %init = bufferization.alloc_tensor(%0, %1) : tensor<?x?xi64, #DenseMatrix>
   %r = linalg.generic #trait
       ins(%arga: tensor<?x?xi64, #DenseMatrix>)
      outs(%init: tensor<?x?xi64, #DenseMatrix>) {
@@ -75,7 +75,7 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.init{{\[}}%[[VAL_4]], %[[VAL_5]]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_6:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]]) : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
@@ -110,7 +110,7 @@ func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
   %c1 = arith.constant 0 : index
   %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
   %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
-  %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #SparseMatrix>
+  %init = bufferization.alloc_tensor(%0, %1) : tensor<?x?xi64, #SparseMatrix>
   %r = linalg.generic #trait
       ins(%arga: tensor<?x?xi64, #SparseMatrix>)
      outs(%init: tensor<?x?xi64, #SparseMatrix>) {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 172882e174a27..af99f00a4807b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -60,13 +60,12 @@ func.func @matmul1(%a: tensor<10x20xf32, #DCSR>,
 // CHECK-LABEL:   func @matmul2(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 4 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_2]], %[[VAL_2]]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -147,7 +146,7 @@ func.func @matmul1(%a: tensor<10x20xf32, #DCSR>,
 func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
               %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
   %c4 = arith.constant 4 : index
-  %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #DCSR>
+  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
   %D = linalg.matmul
     ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
        outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 7bb7f3229f3ed..167b778b17f8e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -103,11 +103,10 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR> {linalg.inplac
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 10 : index
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 20 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.init{{\[}}%[[VAL_2]], %[[VAL_3]]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_7:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
@@ -130,9 +129,7 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR> {linalg.inplac
 // CHECK:         }
 func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> {
   %s = arith.constant 2.0 : f32
-  %d10 = arith.constant 10 : index
-  %d20 = arith.constant 20 : index
-  %xm = sparse_tensor.init [%d10, %d20] : tensor<10x20xf32, #DCSR>
+  %xm = bufferization.alloc_tensor() : tensor<10x20xf32, #DCSR>
   %0 = linalg.generic #trait_scale
      ins(%arga: tensor<10x20xf32, #CSR>)
       outs(%xm: tensor<10x20xf32, #DCSR>) {
@@ -162,7 +159,7 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>>
 // CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_6]], %[[VAL_7]]] : tensor<?x?xi32, #{{.*}}>>
+// CHECK:           %[[VAL_8:.*]] = bufferization.alloc_tensor(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xi32, #{{.*}}>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
@@ -288,7 +285,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
   %c1 = arith.constant 1 : index
   %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
   %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
-  %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #DCSR>
+  %xinit = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xi32, #DCSR>
   %0 = linalg.generic #trait_sumred
     ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
                       tensor<?x?x?xi32, #SparseTensor>)
@@ -321,7 +318,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
 // CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.init{{\[}}%[[VAL_7]], %[[VAL_8]]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_9:.*]] = bufferization.alloc_tensor(%[[VAL_7]], %[[VAL_8]]) : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -405,7 +402,7 @@ func.func @matmat(%arga: tensor<?x?xf32, #DCSR>,
   %c1 = arith.constant 1 : index
   %d0 = tensor.dim %arga, %c0 : tensor<?x?xf32, #DCSR>
   %d1 = tensor.dim %argb, %c1 : tensor<?x?xf32, #DCSR>
-  %cinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xf32, #DCSR>
+  %cinit = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32, #DCSR>
   %0 = linalg.generic #trait_matmat
        ins(%arga, %argb: tensor<?x?xf32, #DCSR>,
                          tensor<?x?xf32, #DCSR>)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir
index c7d2dcf58c9f0..0018c07b342c5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir
@@ -46,7 +46,7 @@ module {
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #SparseMatrix>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #SparseMatrix>
-    %init = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DenseMatrix>
+    %init = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DenseMatrix>
     %0 = linalg.generic #trait_assign
        ins(%arga: tensor<?x?xf64, #SparseMatrix>)
       outs(%init: tensor<?x?xf64, #DenseMatrix>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
index ee61ddca536a5..5d67f3d363a88 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
@@ -42,7 +42,7 @@ module {
                         %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_vec_op
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -67,7 +67,7 @@ module {
                         %argb: tensor<?xf64>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_vec_op
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -91,7 +91,7 @@ module {
                             %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_vec_op
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -109,7 +109,7 @@ module {
   func.func @vector_index(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xi32, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
     %0 = linalg.generic #trait_vec_scale
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xi32, #SparseVector>) {
@@ -136,7 +136,7 @@ module {
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
-    %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
     %0 = linalg.generic #trait_mat_op
        ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
@@ -291,4 +291,4 @@ module {
     sparse_tensor.release %5 : tensor<?x?xf64, #DCSR>
     return
   }
-}
\ No newline at end of file
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
index 976f1c5f778c5..82a00591ecf2c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
@@ -22,7 +22,7 @@ module {
                       -> tensor<?xcomplex<f32>, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f32>, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xcomplex<f32>, #SparseVector>,
                          tensor<?xcomplex<f32>, #SparseVector>)
@@ -39,7 +39,7 @@ module {
                       -> tensor<?xcomplex<f32>, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f32>, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xcomplex<f32>, #SparseVector>,
                          tensor<?xcomplex<f32>, #SparseVector>)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
index f544fae1449b9..700c4336f3730 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
@@ -22,7 +22,7 @@ module {
                       -> tensor<?xcomplex<f64>, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
                          tensor<?xcomplex<f64>, #SparseVector>)
@@ -39,7 +39,7 @@ module {
                       -> tensor<?xcomplex<f64>, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
                          tensor<?xcomplex<f64>, #SparseVector>)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index fa3d4c62b82cf..ba13e1df10b31 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -31,7 +31,7 @@ module {
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op2
        ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
                          tensor<?xcomplex<f64>, #SparseVector>)
@@ -48,7 +48,7 @@ module {
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op1
        ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
         outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
@@ -63,7 +63,7 @@ module {
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op1
        ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
         outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
@@ -78,7 +78,7 @@ module {
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op1
        ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
         outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
@@ -93,7 +93,7 @@ module {
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %0 = linalg.generic #trait_op1
        ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
         outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
@@ -109,7 +109,7 @@ module {
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xcomplex<f64>, #SparseVector>
     %c = complex.constant [2.0 : f64, 0.0 : f64] : complex<f64>
     %0 = linalg.generic #trait_op1
        ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
@@ -125,7 +125,7 @@ module {
                  -> tensor<?xf64, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_op1
        ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
index 2f663e4fa8bf0..3abc138ed6ec7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
@@ -40,8 +40,7 @@ module {
   //
   func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>)
                                  -> tensor<8xi64, #SparseVector> {
-    %d0 = arith.constant 8 : index
-    %init = sparse_tensor.init [%d0] : tensor<8xi64, #SparseVector>
+    %init = bufferization.alloc_tensor() : tensor<8xi64, #SparseVector>
     %r = linalg.generic #trait_1d
         ins(%arga: tensor<8xi64, #SparseVector>)
        outs(%init: tensor<8xi64, #SparseVector>) {
@@ -59,8 +58,7 @@ module {
   //
   func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>)
                                  -> tensor<8xi64, #SparseVector> {
-    %d0 = arith.constant 8 : index
-    %init = sparse_tensor.init [%d0] : tensor<8xi64, #SparseVector>
+    %init = bufferization.alloc_tensor() : tensor<8xi64, #SparseVector>
     %r = linalg.generic #trait_1d
         ins(%arga: tensor<8xi64, #SparseVector>)
        outs(%init: tensor<8xi64, #SparseVector>) {
@@ -78,9 +76,7 @@ module {
   //
   func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>)
                                  -> tensor<3x4xi64, #SparseMatrix> {
-    %d0 = arith.constant 3 : index
-    %d1 = arith.constant 4 : index
-    %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
+    %init = bufferization.alloc_tensor() : tensor<3x4xi64, #SparseMatrix>
     %r = linalg.generic #trait_2d
         ins(%arga: tensor<3x4xi64, #SparseMatrix>)
        outs(%init: tensor<3x4xi64, #SparseMatrix>) {
@@ -101,9 +97,7 @@ module {
   //
   func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>)
                                  -> tensor<3x4xi64, #SparseMatrix> {
-    %d0 = arith.constant 3 : index
-    %d1 = arith.constant 4 : index
-    %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
+    %init = bufferization.alloc_tensor() : tensor<3x4xi64, #SparseMatrix>
     %r = linalg.generic #trait_2d
         ins(%arga: tensor<3x4xi64, #SparseMatrix>)
        outs(%init: tensor<3x4xi64, #SparseMatrix>) {
@@ -121,9 +115,7 @@ module {
 
   func.func @add_outer_2d(%arg0: tensor<2x3xf32, #SparseMatrix>)
                          -> tensor<2x3xf32, #SparseMatrix> {
-    %c2 = arith.constant 2 : index
-    %c3 = arith.constant 3 : index
-    %0 = sparse_tensor.init[%c2, %c3] : tensor<2x3xf32, #SparseMatrix>
+    %0 = bufferization.alloc_tensor() : tensor<2x3xf32, #SparseMatrix>
     %1 = linalg.generic #trait_2d
       ins(%arg0 : tensor<2x3xf32, #SparseMatrix>)
       outs(%0 : tensor<2x3xf32, #SparseMatrix>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
index 16bb2a9ef8dae..f9e0cbc918a76 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
@@ -31,8 +31,7 @@ module {
   //
   func.func @matmul2(%A: tensor<4x8xf64, #CSR>,
                 %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
-    %c4 = arith.constant 4 : index
-    %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #CSR>
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
     %D = linalg.matmul
       ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
          outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
@@ -44,8 +43,7 @@ module {
   //
   func.func @matmul3(%A: tensor<4x8xf64, #DCSR>,
                 %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
-    %c4 = arith.constant 4 : index
-    %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #DCSR>
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
     %D = linalg.matmul
       ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
          outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
index f3a95ec0aab49..79fd9c6c814a4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
@@ -42,7 +42,7 @@ module {
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
-    %xm = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
     %0 = linalg.generic #trait_scale
        ins(%arga: tensor<?x?xf64, #DCSR>)
         outs(%xm: tensor<?x?xf64, #DCSR>) {
@@ -73,7 +73,7 @@ module {
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
-    %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
@@ -91,7 +91,7 @@ module {
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
-    %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_mult_elt.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_mult_elt.mlir
index 91735f3b2c543..c79d7badd632f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_mult_elt.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_mult_elt.mlir
@@ -22,9 +22,7 @@ module {
   // Sparse kernel.
   func.func @sparse_mult_elt(
       %arga: tensor<32x16xf32, #DCSR>, %argb: tensor<32x16xf32, #DCSR>) -> tensor<32x16xf32, #DCSR> {
-    %c16 = arith.constant 16 : index
-    %c32 = arith.constant 32 : index
-    %argx = sparse_tensor.init [%c32, %c16] : tensor<32x16xf32, #DCSR>
+    %argx = bufferization.alloc_tensor() : tensor<32x16xf32, #DCSR>
     %0 = linalg.generic #trait_mult_elt
       ins(%arga, %argb: tensor<32x16xf32, #DCSR>, tensor<32x16xf32, #DCSR>)
       outs(%argx: tensor<32x16xf32, #DCSR>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
index e56e2dc2f7e5f..16b26131308f5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
@@ -30,7 +30,7 @@ module {
     %c1 = arith.constant 1 : index
     %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
-    %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #SparseMatrix>
+    %xinit = bufferization.alloc_tensor(%d0, %d1): tensor<?x?xi32, #SparseMatrix>
     %0 = linalg.generic #redsum
       ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
                         tensor<?x?x?xi32, #SparseTensor>)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
index 2656eb5551652..76cad83bb8ae3 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
@@ -20,7 +20,7 @@ module {
                 -> tensor<?xf32, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf32, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
         outs(%xv: tensor<?xf32, #SparseVector>) {
@@ -35,7 +35,7 @@ module {
                 -> tensor<?xf32, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf32, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
         outs(%xv: tensor<?xf32, #SparseVector>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
index aadb8b3e88e11..ba76b65b49ac6 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -101,8 +101,7 @@ module {
   func.func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
                           %arga: tensor<8x8xf64>,
                           %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
-    %c8 = arith.constant 8 : index
-    %1 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %1 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
     %2 = linalg.generic #trait_sampled_dense_dense
       ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
                                tensor<8x8xf64>, tensor<8x8xf64>)
@@ -135,8 +134,7 @@ module {
           linalg.yield %q : f64
     } -> tensor<8x8xf64>
     // Sample the result with elements-wise multiplication with sparse matrix.
-    %c8 = arith.constant 8 : index
-    %3 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
     %4 = linalg.generic #trait_scale
       ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
       outs(%3 : tensor<8x8xf64, #SM>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sign.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sign.mlir
index acce488155670..8a8a71e576402 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sign.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sign.mlir
@@ -30,7 +30,7 @@ module {
                              -> tensor<?xf64, #SparseVector> {
     %c0 = arith.constant 0 : index
     %d = tensor.dim %arg0, %c0 : tensor<?xf64, #SparseVector>
-    %xin = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xin = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_op
       ins(%arg0: tensor<?xf64, #SparseVector>)
       outs(%xin: tensor<?xf64, #SparseVector>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
index 67e3d183b1e69..bd528bd56f34b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
@@ -29,7 +29,7 @@ module {
     %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xf64, #ST1>
     %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xf64, #ST1>
     %d2 = tensor.dim %arga, %c2 : tensor<?x?x?xf64, #ST1>
-    %xm = sparse_tensor.init [%d0, %d1, %d2] : tensor<?x?x?xf64, #ST2>
+    %xm = bufferization.alloc_tensor(%d0, %d1, %d2) : tensor<?x?x?xf64, #ST2>
     %0 = linalg.generic #trait_scale
        ins(%arga: tensor<?x?x?xf64, #ST1>)
         outs(%xm: tensor<?x?x?xf64, #ST2>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
index 905bf6adc264c..a73229ffe5666 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
@@ -32,9 +32,7 @@ module {
   func.func @sparse_transpose(%arga: tensor<3x4xf64, #DCSR>) -> tensor<4x3xf64, #DCSR> {
     %t = sparse_tensor.convert %arga : tensor<3x4xf64, #DCSR> to tensor<3x4xf64, #DCSC>
 
-    %c3 = arith.constant 3 : index
-    %c4 = arith.constant 4 : index
-    %i = sparse_tensor.init [%c4, %c3] : tensor<4x3xf64, #DCSR>
+    %i = bufferization.alloc_tensor() : tensor<4x3xf64, #DCSR>
 
     %0 = linalg.generic #transpose_trait
        ins(%t: tensor<3x4xf64, #DCSC>)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
index c041b2b8f1934..1a8e8d997a3de 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir
@@ -32,7 +32,7 @@ module {
     %c = arith.constant 0 : index
     %ci1 = arith.constant 1 : i32
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xi32, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
     %0 = linalg.generic #trait_vec_scale
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xi32, #SparseVector>) {
@@ -52,7 +52,7 @@ module {
     %c = arith.constant 0 : index
     %cf1 = arith.constant 1.0 : f64
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_vec_scale
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -79,7 +79,7 @@ module {
     %cfmax = arith.constant 7.0 : f64
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
-    %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
     %0 = linalg.generic #trait_mat_scale
        ins(%argx: tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
@@ -202,4 +202,4 @@ module {
     sparse_tensor.release %2 : tensor<?x?xf64, #DCSR>
     return
   }
-}
\ No newline at end of file
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
index b0d6bcc57b0f3..e9b4e33f8325a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -50,7 +50,7 @@ module {
     %s = arith.constant 2.0 : f64
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_scale
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -79,7 +79,7 @@ module {
                    %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -95,7 +95,7 @@ module {
                    %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
@@ -111,7 +111,7 @@ module {
                      %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #DenseVector> {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = sparse_tensor.init [%d] : tensor<?xf64, #DenseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #DenseVector>
     %0 = linalg.generic #trait_op
        ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #DenseVector>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
index f647066a79eda..aaad99561bc33 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
@@ -35,9 +35,7 @@
 
 func.func @sparse_add_elt(
     %arga: tensor<3x4xf64, #DCSR>, %argb: tensor<3x4xf64, #DCSR>) -> tensor<3x4xf64, #DCSR> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %argx = sparse_tensor.init [%c3, %c4] : tensor<3x4xf64, #DCSR>
+  %argx = bufferization.alloc_tensor() : tensor<3x4xf64, #DCSR>
   %0 = linalg.generic #trait_add_elt
     ins(%arga, %argb: tensor<3x4xf64, #DCSR>, tensor<3x4xf64, #DCSR>)
     outs(%argx: tensor<3x4xf64, #DCSR>) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 220a9ce5f9ec4..4b8cfeb1bd36d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -34,6 +34,7 @@
 from mlir import ir
 from mlir import runtime
 from mlir.dialects import arith
+from mlir.dialects import bufferization
 from mlir.dialects import builtin
 from mlir.dialects import func
 from mlir.dialects import linalg
@@ -889,8 +890,7 @@ def emit_tensor_init(self) -> ir.RankedTensorType:
     mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
                                   self.dst_format.mlir_tensor_attr())
     index_type = ir.IndexType.get()
-    dims = [arith.ConstantOp(index_type, d).result for d in mlir_type.shape]
-    return sparse_tensor.InitOp(mlir_type, dims)
+    return bufferization.AllocTensorOp(mlir_type, [])
 
 
 class _Stats:


        

