[Mlir-commits] [mlir] [mlir][sparse] More allocate -> empty tensor migration (PR #66720)

Aart Bik llvmlistbot at llvm.org
Mon Sep 18 16:26:34 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/66720

This also allows tensor.empty in the "conversion" path of the sparse compiler, further paving the way to
deprecate the bufferization.alloc_tensor() op.

From cb2f8022fbab464f96a7f6fe96d441683d458200 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 18 Sep 2023 16:22:29 -0700
Subject: [PATCH] [mlir][sparse] More allocate -> empty tensor migration

This also allows tensor.empty in the "conversion" path
of the sparse compiler, further paving the way to
deprecate the bufferization.alloc_tensor() op.
---
 .../Transforms/SparseTensorConversion.cpp     | 58 +++++++++++++----
 .../SparseTensor/constant_index_map.mlir      |  4 +-
 .../Dialect/SparseTensor/sparse_affine.mlir   | 63 +++++++++----------
 .../SparseTensor/sparse_broadcast.mlir        |  4 +-
 .../Dialect/SparseTensor/sparse_expand.mlir   |  6 +-
 .../Dialect/SparseTensor/sparse_fp_ops.mlir   | 36 +++++------
 .../Dialect/SparseTensor/sparse_kernels.mlir  |  4 +-
 .../test/Dialect/SparseTensor/sparse_out.mlir | 12 ++--
 .../SparseTensor/sparse_vector_ops.mlir       |  2 +-
 9 files changed, 110 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 871686a4ada0f70..d75601e369a0d25 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -830,6 +830,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 };
 
 /// Sparse conversion rule for the alloc operator.
+/// TODO(springerm): remove when bufferization.alloc_tensor is gone
 class SparseTensorAllocConverter
     : public OpConversionPattern<bufferization::AllocTensorOp> {
 public:
@@ -864,6 +865,37 @@ class SparseTensorAllocConverter
   }
 };
 
+/// Sparse conversion rule for tensor.empty with a sparse encoding.
+class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    const auto stt = getSparseTensorType(op);
+    if (!stt.hasEncoding())
+      return failure();
+    // Gather dim sizes: constants for static dims, operands for dynamic.
+    const Dimension dimRank = stt.getDimRank();
+    SmallVector<Value> dimSizes;
+    dimSizes.reserve(dimRank);
+    auto shape = op.getType().getShape();
+    unsigned operandCtr = 0; // Index of the next dynamic-size operand.
+    for (Dimension d = 0; d < dimRank; ++d) {
+      dimSizes.push_back(stt.isDynamicDim(d)
+                             ? adaptor.getOperands()[operandCtr++]
+                             : constantIndex(rewriter, loc, shape[d]));
+    }
+    // Generate the call to construct the empty tensor. The sizes are the
+    // static shape plus the dynamic-size operands of tensor.empty.
+    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
+                               .genBuffers(stt, dimSizes)
+                               .genNewCall(Action::kEmpty));
+    return success();
+  }
+};
+
 /// Sparse conversion rule for the convert operator.
 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
 public:
@@ -1503,19 +1535,21 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
 void mlir::populateSparseTensorConversionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     const SparseTensorConversionOptions &options) {
-  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
-               SparseCastConverter, SparseTensorNewConverter,
-               SparseReshapeConverter<tensor::ExpandShapeOp>,
-               SparseReshapeConverter<tensor::CollapseShapeOp>,
-               SparseTensorConcatConverter, SparseTensorAllocConverter,
-               SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
-               SparseTensorToCoordinatesConverter,
-               SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
-               SparseTensorLoadConverter, SparseTensorInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter, SparseTensorPackConverter>(
-      typeConverter, patterns.getContext());
-
+  // These rules take only the type converter and the MLIR context.
+  patterns
+      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+           SparseCastConverter, SparseTensorNewConverter,
+           SparseReshapeConverter<tensor::ExpandShapeOp>,
+           SparseReshapeConverter<tensor::CollapseShapeOp>,
+           SparseTensorConcatConverter, SparseTensorAllocConverter,
+           SparseTensorEmptyConverter, SparseTensorDeallocConverter,
+           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
+           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
+           SparseTensorLoadConverter, SparseTensorInsertConverter,
+           SparseTensorExpandConverter, SparseTensorCompressConverter,
+           SparseTensorOutConverter, SparseTensorPackConverter>(
+          typeConverter, patterns.getContext());
+  // The convert rule additionally takes the conversion options.
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
index bfb4503edbc4e40..9eb535385790146 100644
--- a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
@@ -13,7 +13,7 @@
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 77 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<77xi1, #{{.*}}>
+// CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty() : tensor<77xi1, #{{.*}}>
 // CHECK-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<1x77xi1>
 // CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<1x77xi1>
 // CHECK:           %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_5]]) -> (tensor<77xi1, #{{.*}}>) {
@@ -27,7 +27,7 @@
 // CHECK:           return %[[VAL_15]] : tensor<77xi1, #{{.*}}>
 // CHECK:         }
 func.func @main(%arg0: tensor<1x77xi1>, %arg1: tensor<1x77xi1>) -> tensor<77xi1, #SpVec> {
-  %0 = bufferization.alloc_tensor() : tensor<77xi1, #SpVec>
+  %0 = tensor.empty() : tensor<77xi1, #SpVec>
   %1 = linalg.generic {
     indexing_maps = [#map1, #map1, #map2],
     iterator_types = ["parallel"]}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index b3f6ae9f12ee4d6..fc97685b8378bf5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
 #SpVec = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -17,9 +16,9 @@
 }
 
 // CHECK-LABEL:   func @mul_inv_dense1d(
-// CHECK-SAME:                          %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                          %[[VAL_1:.*]]: tensor<4xf32>,
-// CHECK-SAME:                          %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4xf32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 3 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -57,13 +56,13 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
 }
 
 // CHECK-LABEL:   func.func @mul_inv_sparse1d(
-// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                                %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>)
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_6:.*]] = bufferization.alloc_tensor() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -95,7 +94,7 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
 // CHECK:           return %[[VAL_32]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @mul_inv_sparse1d(%arga: tensor<32xf32, #SpVec>,
                             %argb: tensor<4xf32, #SpVec>) -> tensor<32xf32, #SpVec> {
-  %argx = bufferization.alloc_tensor() : tensor<32xf32, #SpVec>
+  %argx = tensor.empty() : tensor<32xf32, #SpVec>
   %0 = linalg.generic #trait1
      ins(%arga, %argb: tensor<32xf32, #SpVec>, tensor<4xf32, #SpVec>)
     outs(%argx: tensor<32xf32, #SpVec>) {
@@ -109,13 +108,13 @@ func.func @mul_inv_sparse1d(%arga: tensor<32xf32, #SpVec>,
 
 
 // CHECK-LABEL:   func.func @mul_inv_enc_dense1d(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK:           %[[VAL_2:.*]] = arith.constant 32 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = bufferization.alloc_tensor() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_6]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
@@ -132,7 +131,7 @@ func.func @mul_inv_sparse1d(%arga: tensor<32xf32, #SpVec>,
 // CHECK:         }
 func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
                             %argb: tensor<4xf32, #EncDenseVec>) -> tensor<32xf32, #EncDenseVec> {
-  %argx = bufferization.alloc_tensor() : tensor<32xf32, #EncDenseVec>
+  %argx = tensor.empty() : tensor<32xf32, #EncDenseVec>
   %0 = linalg.generic #trait1
      ins(%arga, %argb: tensor<32xf32, #EncDenseVec>, tensor<4xf32, #EncDenseVec>)
     outs(%argx: tensor<32xf32, #EncDenseVec>) {
@@ -155,9 +154,9 @@ func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
 }
 
 // CHECK-LABEL:   func @and_affine_dense1d(
-// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<34xi32>,
-// CHECK-SAME:                             %[[VAL_2:.*]]: tensor<32xi32>) -> tensor<32xi32> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<34xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32xi32>) -> tensor<32xi32> {
 // CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
@@ -195,12 +194,12 @@ func.func @and_affine_dense1d(%arga: tensor<32xi32, #SpVec>,
 }
 
 // CHECK-LABEL:   func.func @and_affine_sparse1d(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<34xi32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<34xi32, #sparse_tensor.encoding<{{{.*}}}>>)
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_5:.*]] = tensor.empty() : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi32>
@@ -234,7 +233,7 @@ func.func @and_affine_dense1d(%arga: tensor<32xi32, #SpVec>,
 // CHECK:           return %[[VAL_33]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @and_affine_sparse1d(%arga: tensor<32xi32, #SpVec>,
                                %argb: tensor<34xi32, #SpVec>) -> tensor<32xi32, #SpVec> {
-  %argx = bufferization.alloc_tensor() : tensor<32xi32, #SpVec>
+  %argx = tensor.empty() : tensor<32xi32, #SpVec>
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi32, #SpVec>, tensor<34xi32, #SpVec>)
     outs(%argx: tensor<32xi32, #SpVec>) {
@@ -256,9 +255,9 @@ func.func @and_affine_sparse1d(%arga: tensor<32xi32, #SpVec>,
 }
 
 // CHECK-LABEL:   func @mul_affine_dense2d(
-// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<34x19xf64>,
-// CHECK-SAME:                             %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<34x19xf64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
@@ -304,8 +303,8 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 
 
 // CHECK-LABEL:   func.func @mul_affine_sparse2d(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 32 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
@@ -314,7 +313,7 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
 // CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_FALSE:.*]] = arith.constant false
-// CHECK:           %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
@@ -360,7 +359,7 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK:           return %[[VAL_45]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @mul_affine_sparse2d(%arga: tensor<32x16xf64, #CSR>,
                               %argb: tensor<34x19xf64, #CSR>) -> tensor<32x16xf64, #CSR> {
-  %argx = bufferization.alloc_tensor() : tensor<32x16xf64, #CSR>
+  %argx = tensor.empty() : tensor<32x16xf64, #CSR>
   %0 = linalg.generic #trait3
      ins(%arga, %argb: tensor<32x16xf64, #CSR>, tensor<34x19xf64, #CSR>)
     outs(%argx: tensor<32x16xf64, #CSR>) {
@@ -383,9 +382,9 @@ func.func @mul_affine_sparse2d(%arga: tensor<32x16xf64, #CSR>,
 }
 
 // CHECK-LABEL:   func.func @mul_affine_dense_dim_2d(
-// CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<34x16xf64, #sparse_tensor.encoding
-// CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                                       %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<34x16xf64, #sparse_tensor.encoding
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 19 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -447,9 +446,9 @@ func.func @mul_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>,
 }
 
 // CHECK-LABEL:   func.func @mul_const_affine_dense_dim_2d(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: tensor<34x16xf64,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<34x16xf64,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 19 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
index ae5b941259f6515..1a5f79d23cba29a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
@@ -16,7 +16,7 @@
 //   CHECK-DAG:  %[[TMP_c3:.*]] = arith.constant 3 : index
 //   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
-//       CHECK:  %[[TMP_0:.*]] = bufferization.alloc_tensor()
+//       CHECK:  %[[TMP_0:.*]] = tensor.empty()
 //       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index}
 //       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index}
 //       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index}
@@ -44,7 +44,7 @@
 //       CHECK:  return %[[TMP_8]]
 module @func_sparse {
   func.func public @main(%arg0: tensor<4x5xi32, #DCSR>) -> tensor<4x3x5xi32, #SparseTensor> {
-    %0 = bufferization.alloc_tensor() : tensor<4x3x5xi32, #SparseTensor>
+    %0 = tensor.empty() : tensor<4x3x5xi32, #SparseTensor>
     %1 = linalg.generic #trait
     ins(%arg0 : tensor<4x5xi32, #DCSR>) outs(%0 : tensor<4x3x5xi32, #SparseTensor>) {
     ^bb0(%in: i32, %out: i32):
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index ee3613a268def5e..d19d7fe2871d674 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -67,7 +67,7 @@
 func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
   %c0 = arith.constant 0 : index
   %n = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSC>
-  %v = bufferization.alloc_tensor(%n) : tensor<?xf64, #SV>
+  %v = tensor.empty(%n) : tensor<?xf64, #SV>
   %0 = linalg.generic #rowsum
     ins(%arga: tensor<?x?xf64, #DCSC>)
     outs(%v: tensor<?xf64, #SV>) {
@@ -119,7 +119,7 @@ func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
 //
 func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
                    %B: tensor<2x4xf64, #CSR>) -> tensor<8x4xf64, #CSR> {
-  %C = bufferization.alloc_tensor() : tensor<8x4xf64, #CSR>
+  %C = tensor.empty() : tensor<8x4xf64, #CSR>
   %D = linalg.matmul
     ins(%A, %B: tensor<8x2xf64, #CSR>, tensor<2x4xf64, #CSR>)
        outs(%C: tensor<8x4xf64, #CSR>) -> tensor<8x4xf64, #CSR>
@@ -167,7 +167,7 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
 //
 func.func @matmul2(%A: tensor<8x2xf64, #CSC>,
                    %B: tensor<2x4xf64, #CSC>) -> tensor<8x4xf64, #CSC> {
-  %C = bufferization.alloc_tensor() : tensor<8x4xf64, #CSC>
+  %C = tensor.empty() : tensor<8x4xf64, #CSC>
   %D = linalg.matmul
     ins(%A, %B: tensor<8x2xf64, #CSC>, tensor<2x4xf64, #CSC>)
        outs(%C: tensor<8x4xf64, #CSC>) -> tensor<8x4xf64, #CSC>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index 3e18b0c1b6c1bc6..dac34da30b49fda 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -351,13 +351,13 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 }
 
 // CHECK-LABEL:   func.func @zero_preserving_math(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>) -> tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_3:.*]] = bufferization.alloc_tensor() : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> to memref<?xf64>
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK:           %[[VAL_7:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[T:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_2]] {{.*}} {
@@ -371,15 +371,15 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 // CHECK:             %[[VAL_17:.*]] = math.log1p %[[VAL_16]] : f64
 // CHECK:             %[[VAL_18:.*]] = math.sin %[[VAL_17]] : f64
 // CHECK:             %[[VAL_19:.*]] = math.tanh %[[VAL_18]] : f64
-// CHECK:             %[[Y:.*]] = sparse_tensor.insert %[[VAL_19]] into %{{.*}}[%[[VAL_10]]] : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
+// CHECK:             %[[Y:.*]] = sparse_tensor.insert %[[VAL_19]] into %{{.*}}[%[[VAL_10]]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:             scf.yield %[[Y]]
 // CHECK:           }
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.load %[[T]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
-// CHECK:           return %[[VAL_20]] : tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.load %[[T]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           return %[[VAL_20]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         }
 func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #SV> {
   %c32 = arith.constant 32 : index
-  %xinp = bufferization.alloc_tensor() : tensor<32xf64, #SV>
+  %xinp = tensor.empty() : tensor<32xf64, #SV>
   %0 = linalg.generic #trait1
      ins(%arga: tensor<32xf64, #SV>)
     outs(%xinp: tensor<32xf64, #SV>) {
@@ -398,29 +398,29 @@ func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #S
 }
 
 // CHECK-LABEL:   func.func @complex_divbyc(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>) -> tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_3:.*]] = complex.constant [0.000000e+00, 1.000000e+00] : complex<f64>
-// CHECK:           %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> to memref<?xcomplex<f64>>
+// CHECK:           %[[VAL_4:.*]] = tensor.empty() : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xcomplex<f64>>
 // CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[T:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_2]] {{.*}} {
 // CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xindex>
 // CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_10]]] : memref<?xcomplex<f64>>
 // CHECK:             %[[VAL_13:.*]] = complex.div %[[VAL_12]], %[[VAL_3]] : complex<f64>
-// CHECK:             %[[Y:.*]] = sparse_tensor.insert %[[VAL_13]] into %{{.*}}[%[[VAL_11]]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
+// CHECK:             %[[Y:.*]] = sparse_tensor.insert %[[VAL_13]] into %{{.*}}[%[[VAL_11]]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:             scf.yield %[[Y]]
 // CHECK:           }
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.load %[[T]] hasInserts : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
-// CHECK:           return %[[VAL_14]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.load %[[T]] hasInserts : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           return %[[VAL_14]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         }
 func.func @complex_divbyc(%arg0: tensor<32xcomplex<f64>, #SV>) -> tensor<32xcomplex<f64>, #SV> {
   %c = complex.constant [0.0, 1.0] : complex<f64>
-  %init = bufferization.alloc_tensor() : tensor<32xcomplex<f64>, #SV>
+  %init = tensor.empty() : tensor<32xcomplex<f64>, #SV>
   %0 = linalg.generic #traitc
      ins(%arg0: tensor<32xcomplex<f64>, #SV>)
     outs(%init: tensor<32xcomplex<f64>, #SV>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 7f14934a4ef206e..a21fa7b35b54365 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -108,7 +108,7 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant true
-// CHECK-DAG:       %[[VAL_6:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.empty() : tensor<4x4xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
@@ -188,7 +188,7 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
               %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
   %c4 = arith.constant 4 : index
-  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+  %C = tensor.empty() : tensor<4x4xf64, #DCSR>
   %D = linalg.matmul
     ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
        outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 97d8da213423d9e..9aee12f0d3cef75 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -102,7 +102,7 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty() : tensor<10x20xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
 // CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref<?xf32>
@@ -124,7 +124,7 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x
 // CHECK:         }
 func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> {
   %s = arith.constant 2.0 : f32
-  %xm = bufferization.alloc_tensor() : tensor<10x20xf32, #DCSR>
+  %xm = tensor.empty() : tensor<10x20xf32, #DCSR>
   %0 = linalg.generic #trait_scale
      ins(%arga: tensor<10x20xf32, #CSR>)
       outs(%xm: tensor<10x20xf32, #DCSR>) {
@@ -155,7 +155,7 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
 // CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ] }>>
-// CHECK:           %[[VAL_7:.*]] = bufferization.alloc_tensor(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK:           %[[VAL_7:.*]] = tensor.empty(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
@@ -286,7 +286,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
   %c1 = arith.constant 1 : index
   %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
   %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
-  %xinit = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xi32, #DCSR>
+  %xinit = tensor.empty(%d0, %d1) : tensor<?x?xi32, #DCSR>
   %0 = linalg.generic #trait_sumred
     ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
                       tensor<?x?x?xi32, #SparseTensor>)
@@ -318,7 +318,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant true
 // CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK:           %[[VAL_8:.*]] = bufferization.alloc_tensor(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK:           %[[VAL_8:.*]] = tensor.empty(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
@@ -401,7 +401,7 @@ func.func @matmat(%arga: tensor<?x?xf32, #DCSR>,
   %c1 = arith.constant 1 : index
   %d0 = tensor.dim %arga, %c0 : tensor<?x?xf32, #DCSR>
   %d1 = tensor.dim %argb, %c1 : tensor<?x?xf32, #DCSR>
-  %cinit = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32, #DCSR>
+  %cinit = tensor.empty(%d0, %d1) : tensor<?x?xf32, #DCSR>
   %0 = linalg.generic #trait_matmat
        ins(%arga, %argb: tensor<?x?xf32, #DCSR>,
                          tensor<?x?xf32, #DCSR>)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
index 8e23f901da7654e..67841beaa6933f6 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
@@ -48,7 +48,7 @@
 // CHECK:           }
 func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
                 %argb: tensor<1024xf32, #DenseVector>) -> tensor<1024xf32> {
-  %init = bufferization.alloc_tensor() : tensor<1024xf32>
+  %init = tensor.empty() : tensor<1024xf32>
   %o = arith.constant 1.0 : f32
   %c = arith.constant 2.0 : f32
   %i = arith.constant 255 : i64



More information about the Mlir-commits mailing list