[Mlir-commits] [mlir] dc46d5c - [mlir][sparse] improve dimop rewriting during conversion

Aart Bik llvmlistbot at llvm.org
Thu Sep 8 13:04:41 PDT 2022


Author: Aart Bik
Date: 2022-09-08T13:04:28-07:00
New Revision: dc46d5c979101e73fd61c1b6de942e17a2e8e480

URL: https://github.com/llvm/llvm-project/commit/dc46d5c979101e73fd61c1b6de942e17a2e8e480
DIFF: https://github.com/llvm/llvm-project/commit/dc46d5c979101e73fd61c1b6de942e17a2e8e480.diff

LOG: [mlir][sparse] improve dimop rewriting during conversion

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133512

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/sparse_expand.mlir
    mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index df7a7de4a6580..e758a9b7314a9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1167,14 +1167,13 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
     // Determine the size for access expansion (always the innermost stored
-    // dimension size, but we need to translate it back to the original
-    // dimension since the dim size utility applies dimension ordering).
+    // dimension size, translated back to original dimension). Note that we
+    // recursively rewrite the new DimOp on the **original** tensor.
     auto enc = getSparseTensorEncoding(srcType);
-    Value src = adaptor.getOperands()[0];
     unsigned innerDim = srcType.getRank() - 1;
     if (AffineMap p = enc.getDimOrdering())
       innerDim = p.getDimPosition(innerDim);
-    Value sz = genDimSizeCall(rewriter, loc, enc, src, innerDim);
+    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 23bec35245e69..99dab303b8f77 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -16,10 +16,15 @@
   indexBitWidth = 32
 }>
 
-#SparseMatrix = #sparse_tensor.encoding<{
+#CSR = #sparse_tensor.encoding<{
   dimLevelType = ["dense", "compressed"]
 }>
 
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
 #SparseTensor = #sparse_tensor.encoding<{
   dimLevelType = ["dense", "compressed", "compressed"],
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
@@ -97,9 +102,9 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
-  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
-  return %0 : tensor<?x?xf32, #SparseMatrix>
+func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
+  return %0 : tensor<?x?xf32, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_new3d(
@@ -135,10 +140,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
 //       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
-  %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
-  return %1 : tensor<?x?xf64, #SparseMatrix>
+func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #CSR> {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSR>
+  %1 = sparse_tensor.load %0 : tensor<?x?xf64, #CSR>
+  return %1 : tensor<?x?xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_release(
@@ -277,9 +282,9 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 //       CHECK: call @delSparseTensorCOOF64(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
-  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
-  return %0 : tensor<2x4xf64, #SparseMatrix>
+func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
+  return %0 : tensor<2x4xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_constant() -> !llvm.ptr<i8> {
@@ -309,12 +314,12 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseM
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 //       CHECK: call @delSparseTensorCOOF32(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func.func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
+func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
   // Initialize a tensor.
   %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
   // Convert the tensor to a sparse tensor.
-  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #SparseMatrix>
-  return %1 : tensor<8x7xf32, #SparseMatrix>
+  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #CSR>
+  return %1 : tensor<8x7xf32, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_convert_3d(
@@ -493,20 +498,52 @@ func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
   return
 }
 
-// CHECK-LABEL: func @sparse_expansion()
-//  CHECK-DAG:  %[[C:.*]] = arith.constant 1 : index
+// CHECK-LABEL: func @sparse_expansion1()
+//       CHECK: %[[N:.*]] = call @newSparseTensor
+//       CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+//       CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+//       CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+//       CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
+//       CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion1() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+// CHECK-LABEL: func @sparse_expansion2()
+//       CHECK: %[[N:.*]] = call @newSparseTensor
+//       CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+//       CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+//       CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+//       CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
+//       CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion2() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+// CHECK-LABEL: func @sparse_expansion3(
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[N:.*]] = call @newSparseTensor
-//       CHECK: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C]]) : (!llvm.ptr<i8>, index) -> index
+//       CHECK: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C1]]) : (!llvm.ptr<i8>, index) -> index
 //       CHECK: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
 //       CHECK: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
 //       CHECK: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
 //   CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<?xf64>)
 //   CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
 //       CHECK: return %[[C]] : memref<?xindex>
-func.func @sparse_expansion() -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #SparseMatrix>
+func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
   %values, %filled, %added, %count = sparse_tensor.expand %0
-    : tensor<4x8xf64, #SparseMatrix> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+    : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return %added : memref<?xindex>
 }
 
@@ -521,11 +558,11 @@ func.func @sparse_expansion() -> memref<?xindex> {
 //   CHECK-DAG: memref.dealloc %[[D]] : memref<?xi1>
 //   CHECK-DAG: memref.dealloc %[[E]] : memref<?xindex>
 //       CHECK: return
-func.func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
+func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
                          %arg1: memref<?xindex>, %arg2: memref<?xf64>, %arg3: memref<?xi1>,
                          %arg4: memref<?xindex>, %arg5: index) {
   sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
-    : tensor<8x8xf64, #SparseMatrix>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+    : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
 }
 
@@ -538,8 +575,8 @@ func.func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
 //       CHECK: call @outSparseTensorF64(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
 //       CHECK: call @delSparseTensorCOOF64(%[[COO]])
 //       CHECK: return
-func.func @sparse_out1(%arg0: tensor<?x?xf64, #SparseMatrix>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
+func.func @sparse_out1(%arg0: tensor<?x?xf64, #CSR>, %arg1: !llvm.ptr<i8>) {
+  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #CSR>, !llvm.ptr<i8>
   return
 }
 
@@ -562,9 +599,9 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
 //       CHECK: %[[D:.*]] = bufferization.alloc_tensor
 //       CHECK: return %[[S]], %[[D]] : !llvm.ptr<i8>, tensor<?x?xf64>
 func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
-           -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
-  %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
+           -> (tensor<?x?xf64, #CSR>, tensor<?x?xf64>) {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSR>
+  %1 = sparse_tensor.load %0 : tensor<?x?xf64, #CSR>
   %2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
-  return %1, %2 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
+  return %1, %2 : tensor<?x?xf64, #CSR>, tensor<?x?xf64>
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index d266d4b99a72a..fd4e297985f80 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -98,12 +98,12 @@ func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
 // CHECK-CONVERT-LABEL: func @matmul1(
 // CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONVERT-DAG: %[[C4:.*]] = arith.constant 4 : index
 // CHECK-CONVERT-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C1]])
-// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
-// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
-// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
+// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[C4]]) : memref<?xf64>
+// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[C4]]) : memref<?xi1>
+// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[C4]]) : memref<?xindex>
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<?xf64>)
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
 // CHECK-CONVERT: scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]] {
@@ -147,11 +147,11 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
 // CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-CONVERT-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-CONVERT-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C1]])
-// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
-// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
-// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
+// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[C8]]) : memref<?xf64>
+// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[C8]]) : memref<?xi1>
+// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[C8]]) : memref<?xindex>
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<?xf64>)
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
 // CHECK-CONVERT: scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index f2812cd7fb673..7b87411a08f2d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,122 +1,127 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
-// CHECK-LABEL:  func.func @fill_zero_after_alloc
-// CHECK-SAME:     %[[TMP_arg0:.*]]: !llvm.ptr<i8>,
-// CHECK-SAME:     %[[TMP_arg1:.*]]: !llvm.ptr<i8>
-// CHECK:    %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK:    %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK:    %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK:    %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK:    %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK:    %[[TMP_false:.*]] = arith.constant false
-// CHECK:    %[[TMP_true:.*]] = arith.constant true
-// CHECK:    %[[TMP_c100:.*]] = arith.constant 100 : index
-// CHECK:    %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
-// CHECK:    %[[TMP_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK:    %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref<?xi8>
-// CHECK:    memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:    memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK:    %[[TMP_2:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:    %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref<?xindex>
-// CHECK:    memref.store %[[TMP_c100]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK:    memref.store %[[TMP_c100]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK:    %[[TMP_4:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:    %[[TMP_5:.*]] = memref.cast %[[TMP_4]] : memref<2xindex> to memref<?xindex>
-// CHECK:    memref.store %[[TMP_c0]], %[[TMP_4]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK:    memref.store %[[TMP_c1]], %[[TMP_4]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK:    %[[TMP_6:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-// CHECK:    %[[TMP_7:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c0_i32]], %[[TMP_6]])
-// CHECK:    %[[TMP_8:.*]] = call @sparseDimSize(%[[TMP_7]], %[[TMP_c1]])
-// CHECK:    %[[TMP_9:.*]] = memref.alloc(%[[TMP_8]]) : memref<?xf64>
-// CHECK:    %[[TMP_10:.*]] = memref.alloc(%[[TMP_8]]) : memref<?xi1>
-// CHECK:    %[[TMP_11:.*]] = memref.alloc(%[[TMP_8]]) : memref<?xindex>
-// CHECK:    linalg.fill ins(%[[TMP_cst]] : f64) outs(%[[TMP_9]] : memref<?xf64>)
-// CHECK:    linalg.fill ins(%[[TMP_false]] : i1) outs(%[[TMP_10]] : memref<?xi1>)
-// CHECK:    %[[TMP_12:.*]] = call @sparsePointers0(%[[TMP_arg0]], %[[TMP_c0]])
-// CHECK:    %[[TMP_13:.*]] = call @sparseIndices0(%[[TMP_arg0]], %[[TMP_c0]])
-// CHECK:    %[[TMP_14:.*]] = call @sparsePointers0(%[[TMP_arg0]], %[[TMP_c1]])
-// CHECK:    %[[TMP_15:.*]] = call @sparseIndices0(%[[TMP_arg0]], %[[TMP_c1]])
-// CHECK:    %[[TMP_16:.*]] = call @sparseValuesF64(%[[TMP_arg0]])
-// CHECK:    %[[TMP_17:.*]] = call @sparsePointers0(%[[TMP_arg1]], %[[TMP_c0]])
-// CHECK:    %[[TMP_18:.*]] = call @sparseIndices0(%[[TMP_arg1]], %[[TMP_c0]])
-// CHECK:    %[[TMP_19:.*]] = call @sparsePointers0(%[[TMP_arg1]], %[[TMP_c1]])
-// CHECK:    %[[TMP_20:.*]] = call @sparseIndices0(%[[TMP_arg1]], %[[TMP_c1]])
-// CHECK:    %[[TMP_21:.*]] = call @sparseValuesF64(%[[TMP_arg1]])
-// CHECK:    %[[TMP_22:.*]] = memref.alloca() : memref<2xindex>
-// CHECK:    %[[TMP_23:.*]] = memref.cast %[[TMP_22]] : memref<2xindex> to memref<?xindex>
-// CHECK:    %[[TMP_24:.*]] = memref.load %[[TMP_12]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK:    %[[TMP_25:.*]] = memref.load %[[TMP_12]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK:    scf.for %[[TMP_arg2:.*]] = %[[TMP_24]] to %[[TMP_25]] step %[[TMP_c1]] {
-// CHECK:      %[[TMP_26:.*]] = memref.load %[[TMP_13]][%[[TMP_arg2]]] : memref<?xindex>
-// CHECK:      memref.store %[[TMP_26]], %[[TMP_22]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_14]][%[[TMP_arg2]]] : memref<?xindex>
-// CHECK:      %[[TMP_28:.*]] = arith.addi %[[TMP_arg2]], %[[TMP_c1]] : index
-// CHECK:      %[[TMP_29:.*]] = memref.load %[[TMP_14]][%[[TMP_28]]] : memref<?xindex>
-// CHECK:      %[[TMP_30:.*]] = memref.load %[[TMP_17]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK:      %[[TMP_31:.*]] = memref.load %[[TMP_17]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK:      %[[TMP_32:.*]]:3 = scf.while (%[[TMP_arg3:.*]] = %[[TMP_27]], %[[TMP_arg4:.*]] = %[[TMP_30]], %[[TMP_arg5:.*]] = %[[TMP_c0]]) : (index, index, index) -> (index, index, index) {
-// CHECK:        %[[TMP_33:.*]] = arith.cmpi ult, %[[TMP_arg3]], %[[TMP_29]] : index
-// CHECK:        %[[TMP_34:.*]] = arith.cmpi ult, %[[TMP_arg4]], %[[TMP_31]] : index
-// CHECK:        %[[TMP_35:.*]] = arith.andi %[[TMP_33]], %[[TMP_34]] : i1
-// CHECK:        scf.condition(%[[TMP_35]]) %[[TMP_arg3]], %[[TMP_arg4]], %[[TMP_arg5]] : index, index, index
-// CHECK:      } do {
-// CHECK:      ^bb0(%[[TMP_arg3:.*]]: index, %[[TMP_arg4:.*]]: index, %[[TMP_arg5:.*]]: index):
-// CHECK:        %[[TMP_33:.*]] = memref.load %[[TMP_15]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK:        %[[TMP_34:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK:        %[[TMP_35:.*]] = arith.cmpi ult, %[[TMP_34]], %[[TMP_33]] : index
-// CHECK:        %[[TMP_36:.*]] = arith.select %[[TMP_35]], %[[TMP_34]], %[[TMP_33]] : index
-// CHECK:        %[[TMP_37:.*]] = arith.cmpi eq, %[[TMP_33]], %[[TMP_36]] : index
-// CHECK:        %[[TMP_38:.*]] = arith.cmpi eq, %[[TMP_34]], %[[TMP_36]] : index
-// CHECK:        %[[TMP_39:.*]] = arith.andi %[[TMP_37]], %[[TMP_38]] : i1
-// CHECK:        %[[TMP_40:.*]] = scf.if %[[TMP_39]] -> (index) {
-// CHECK:          %[[TMP_45:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xf64>
-// CHECK:          %[[TMP_46:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xindex>
-// CHECK:          %[[TMP_47:.*]] = arith.addi %[[TMP_arg4]], %[[TMP_c1]] : index
-// CHECK:          %[[TMP_48:.*]] = memref.load %[[TMP_19]][%[[TMP_47]]] : memref<?xindex>
-// CHECK:          %[[TMP_49:.*]] = scf.for %[[TMP_arg6:.*]] = %[[TMP_46]] to %[[TMP_48]] step %[[TMP_c1]] iter_args(%[[TMP_arg7:.*]] = %[[TMP_arg5]]) -> (index) {
-// CHECK:            %[[TMP_50:.*]] = memref.load %[[TMP_20]][%[[TMP_arg6]]] : memref<?xindex>
-// CHECK:            %[[TMP_51:.*]] = memref.load %[[TMP_9]][%[[TMP_50]]] : memref<?xf64>
-// CHECK:            %[[TMP_52:.*]] = memref.load %[[TMP_21]][%[[TMP_arg6]]] : memref<?xf64>
-// CHECK:            %[[TMP_53:.*]] = arith.mulf %[[TMP_45]], %[[TMP_52]] : f64
-// CHECK:            %[[TMP_54:.*]] = arith.addf %[[TMP_51]], %[[TMP_53]] : f64
-// CHECK:            %[[TMP_55:.*]] = memref.load %[[TMP_10]][%[[TMP_50]]] : memref<?xi1>
-// CHECK:            %[[TMP_56:.*]] = arith.cmpi eq, %[[TMP_55]], %[[TMP_false]] : i1
-// CHECK:            %[[TMP_57:.*]] = scf.if %[[TMP_56]] -> (index) {
-// CHECK:              memref.store %[[TMP_true]], %[[TMP_10]][%[[TMP_50]]] : memref<?xi1>
-// CHECK:              memref.store %[[TMP_50]], %[[TMP_11]][%[[TMP_arg7]]] : memref<?xindex>
-// CHECK:              %[[TMP_58:.*]] = arith.addi %[[TMP_arg7]], %[[TMP_c1]] : index
-// CHECK:              scf.yield %[[TMP_58]] : index
-// CHECK:            } else {
-// CHECK:              scf.yield %[[TMP_arg7]] : index
-// CHECK:            }
-// CHECK:            memref.store %[[TMP_54]], %[[TMP_9]][%[[TMP_50]]] : memref<?xf64>
-// CHECK:            scf.yield %[[TMP_57]] : index
-// CHECK:          }
-// CHECK:          scf.yield %[[TMP_49]] : index
-// CHECK:        } else {
-// CHECK:          scf.yield %[[TMP_arg5]] : index
-// CHECK:        }
-// CHECK:        %[[TMP_41:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK:        %[[TMP_42:.*]] = arith.select %[[TMP_37]], %[[TMP_41]], %[[TMP_arg3]] : index
-// CHECK:        %[[TMP_43:.*]] = arith.addi %[[TMP_arg4]], %[[TMP_c1]] : index
-// CHECK:        %[[TMP_44:.*]] = arith.select %[[TMP_38]], %[[TMP_43]], %[[TMP_arg4]] : index
-// CHECK:        scf.yield %[[TMP_42]], %[[TMP_44]], %[[TMP_40]] : index, index, index
-// CHECK:      }
-// CHECK:      func.call @expInsertF64(%[[TMP_7]], %[[TMP_23]], %[[TMP_9]], %[[TMP_10]], %[[TMP_11]], %[[TMP_32]]#2)
-// CHECK:    }
-// CHECK:    memref.dealloc %[[TMP_9]] : memref<?xf64>
-// CHECK:    memref.dealloc %[[TMP_10]] : memref<?xi1>
-// CHECK:    memref.dealloc %[[TMP_11]] : memref<?xindex>
-// CHECK:    call @endInsert(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
-// CHECK:    return %[[TMP_7]] : !llvm.ptr<i8>
-func.func @fill_zero_after_alloc(%arg0: tensor<100x100xf64, #DCSR>,
-                                 %arg1: tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR> {
-  %0 = bufferization.alloc_tensor() : tensor<100x100xf64, #DCSR>
+
+// CHECK-LABEL:   func.func @fill_zero_after_alloc(
+// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant false
+// CHECK:           %[[VAL_8:.*]] = arith.constant true
+// CHECK:           %[[VAL_9:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_10:.*]] = arith.constant 300 : index
+// CHECK:           %[[VAL_11:.*]] = arith.constant 1 : i8
+// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
+// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
+// CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK:           memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:           %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
+// CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
+// CHECK:           %[[VAL_18:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_15]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
+// CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
+// CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
+// CHECK:           %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref<?xi1>
+// CHECK:           %[[VAL_24:.*]] = memref.alloc() : memref<300xindex>
+// CHECK:           %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref<?xindex>
+// CHECK:           linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
+// CHECK:           linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
+// CHECK:           %[[VAL_26:.*]] = call @sparsePointers0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_27:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_28:.*]] = call @sparsePointers0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_29:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK:           %[[VAL_31:.*]] = call @sparsePointers0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_32:.*]] = call @sparseIndices0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_33:.*]] = call @sparsePointers0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_34:.*]] = call @sparseIndices0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK:           %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK:           %[[VAL_36:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:           %[[VAL_37:.*]] = memref.cast %[[VAL_36]] : memref<2xindex> to memref<?xindex>
+// CHECK:           %[[VAL_38:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_39:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_40:.*]] = %[[VAL_38]] to %[[VAL_39]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_41:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK:             memref.store %[[VAL_41]], %[[VAL_36]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK:             %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_44:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK:             %[[VAL_45:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:             %[[VAL_46:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:             %[[VAL_47:.*]]:3 = scf.while (%[[VAL_48:.*]] = %[[VAL_42]], %[[VAL_49:.*]] = %[[VAL_45]], %[[VAL_50:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
+// CHECK:               %[[VAL_51:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_44]] : index
+// CHECK:               %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_49]], %[[VAL_46]] : index
+// CHECK:               %[[VAL_53:.*]] = arith.andi %[[VAL_51]], %[[VAL_52]] : i1
+// CHECK:               scf.condition(%[[VAL_53]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]] : index, index, index
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_54:.*]]: index, %[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
+// CHECK:               %[[VAL_57:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref<?xindex>
+// CHECK:               %[[VAL_58:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:               %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
+// CHECK:               %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
+// CHECK:               %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK:               %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK:               %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
+// CHECK:               %[[VAL_64:.*]] = scf.if %[[VAL_63]] -> (index) {
+// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<?xf64>
+// CHECK:                 %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:                 %[[VAL_67:.*]] = arith.addi %[[VAL_55]], %[[VAL_6]] : index
+// CHECK:                 %[[VAL_68:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_67]]] : memref<?xindex>
+// CHECK:                 %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_66]] to %[[VAL_68]] step %[[VAL_6]] iter_args(%[[VAL_71:.*]] = %[[VAL_56]]) -> (index) {
+// CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_70]]] : memref<?xindex>
+// CHECK:                   %[[VAL_73:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_72]]] : memref<300xf64>
+// CHECK:                   %[[VAL_74:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_70]]] : memref<?xf64>
+// CHECK:                   %[[VAL_75:.*]] = arith.mulf %[[VAL_65]], %[[VAL_74]] : f64
+// CHECK:                   %[[VAL_76:.*]] = arith.addf %[[VAL_73]], %[[VAL_75]] : f64
+// CHECK:                   %[[VAL_77:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_72]]] : memref<300xi1>
+// CHECK:                   %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_77]], %[[VAL_7]] : i1
+// CHECK:                   %[[VAL_79:.*]] = scf.if %[[VAL_78]] -> (index) {
+// CHECK:                     memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_72]]] : memref<300xi1>
+// CHECK:                     memref.store %[[VAL_72]], %[[VAL_24]]{{\[}}%[[VAL_71]]] : memref<300xindex>
+// CHECK:                     %[[VAL_80:.*]] = arith.addi %[[VAL_71]], %[[VAL_6]] : index
+// CHECK:                     scf.yield %[[VAL_80]] : index
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_71]] : index
+// CHECK:                   }
+// CHECK:                   memref.store %[[VAL_76]], %[[VAL_20]]{{\[}}%[[VAL_72]]] : memref<300xf64>
+// CHECK:                   scf.yield %[[VAL_79]] : index
+// CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_69]] : index
+// CHECK:               } else {
+// CHECK:                 scf.yield %[[VAL_56]] : index
+// CHECK:               }
+// CHECK:               %[[VAL_83:.*]] = arith.addi %[[VAL_54]], %[[VAL_6]] : index
+// CHECK:               %[[VAL_84:.*]] = arith.select %[[VAL_61]], %[[VAL_83]], %[[VAL_54]] : index
+// CHECK:               %[[VAL_85:.*]] = arith.addi %[[VAL_55]], %[[VAL_6]] : index
+// CHECK:               %[[VAL_86:.*]] = arith.select %[[VAL_62]], %[[VAL_85]], %[[VAL_55]] : index
+// CHECK:               scf.yield %[[VAL_84]], %[[VAL_86]], %[[VAL_64]] : index, index, index
+// CHECK:             }
+// CHECK:             func.call @expInsertF64(%[[VAL_19]], %[[VAL_37]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_47]]#2) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
+// CHECK:           }
+// CHECK:           memref.dealloc %[[VAL_20]] : memref<300xf64>
+// CHECK:           memref.dealloc %[[VAL_22]] : memref<300xi1>
+// CHECK:           memref.dealloc %[[VAL_24]] : memref<300xindex>
+// CHECK:           call @endInsert(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
+// CHECK:           return %[[VAL_19]] : !llvm.ptr<i8>
+// CHECK:         }
+func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
+                                 %arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
+  %0 = bufferization.alloc_tensor() : tensor<100x300xf64, #DCSR>
   %cst = arith.constant 0.000000e+00 : f64
   %1 = linalg.fill ins(%cst : f64)
-                   outs(%0 : tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR>
-  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<100x100xf64, #DCSR>, tensor<100x100xf64, #DCSR>)
-                     outs(%1 : tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR>
-  return %2 : tensor<100x100xf64, #DCSR>
+                   outs(%0 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<100x200xf64, #DCSR>, tensor<200x300xf64, #DCSR>)
+                     outs(%1 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
+  return %2 : tensor<100x300xf64, #DCSR>
 }


        


More information about the Mlir-commits mailing list