[Mlir-commits] [mlir] [mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules (PR #69362)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 17 11:01:31 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

[mlir][sparse] Avoid tensor-to-memref conversion in the sparse tensor rewriting rules.

---

Patch is 27.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69362.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+41-66) 
- (modified) mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir (+14-21) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+77-71) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bfee3aa1d7ee8e..e50b14975e83d63 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+// A trivial wrapper to help generate different operations for dense/sparse
+// tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else*/ true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
     ValueRange vs;
     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
 
-    Value iterArg = dstBuf.getSSA();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
           const Dimension dimRank = dstStt.getDimRank();
           const Level lvlRank = dstStt.getLvlRank();
           SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           }
 
           if (!skipZeroCheck) {
-            assert(!reduc.empty());
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                   /*else*/ true);
             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
           }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });
 
     rewriter.setInsertionPointAfter(foreachOp);
 
     // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);
 
     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
     rewriter.replaceOp(op, ret);
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index c22f051a0d5854d..e2dcb068e11851e 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,11 +14,10 @@
 
 // CHECK-LABEL:  func.func @sparse_convert_1d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
@@ -26,11 +25,10 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
 
 // CHECK-LABEL:  func.func @sparse_convert_1d_dyn
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
@@ -38,11 +36,10 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
 
 // CHECK-LABEL:  func.func @sparse_convert_2d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
@@ -50,11 +47,10 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
@@ -62,11 +58,10 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn1
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
@@ -74,11 +69,10 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn2
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
@@ -86,11 +80,10 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_3d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index bdfab54dc6daeb5..f3d3dd28563e891 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -176,77 +176,83 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
     return %0 : tensor<?x?xf64, #DCSR>
 }
 
-// CHECK-LABEL: @concat_sparse_sparse_dense(
-//  CHECK-SAME:  %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-//  CHECK-SAME:  %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-//  CHECK-SAME:  %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-//   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:  %[[TMP_c5:.*]] = arith.constant 5 : index
-//   CHECK-DAG:  %[[TMP_c2:.*]] = arith.constant 2 : index
-//   CHECK-DAG:  %[[TMP_c9:.*]] = arith.constant 9 : index
-//   CHECK-DAG:  %[[TMP_c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:  %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
-//       CHECK:  %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
-//       CHECK:  linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-//       CHECK:    %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-//       CHECK:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
-//       CHECK:  return %[[R]] : tensor<?x?xf64>
+// CHECK-LABEL:   func.func @concat_sparse_sparse_dense(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME:      %[[VAL_2...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/69362


More information about the Mlir-commits mailing list