[Mlir-commits] [mlir] [mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules. (PR #69362)

Peiming Liu llvmlistbot at llvm.org
Tue Oct 17 10:59:52 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69362

Avoid tensor->memref conversion in sparse tensor rewriting rules.

>From f461a5c264367a8e9eeef5b5dd3107696a5aa24d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 17 Oct 2023 17:37:27 +0000
Subject: [PATCH] [mlir][sparse] avoid tensor->memref conversion in sparse
 tensor rewriting rules.

---
 .../Transforms/SparseTensorRewriting.cpp      | 107 +++++--------
 .../SparseTensor/convert_sparse2dense.mlir    |  35 ++---
 .../Dialect/SparseTensor/sparse_concat.mlir   | 148 +++++++++---------
 3 files changed, 132 insertions(+), 158 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bfee3aa1d7ee8e..e50b14975e83d63 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+// A trivial wrapper to help generate different operations for dense/sparse
+// tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else*/ true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
     ValueRange vs;
     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
 
-    Value iterArg = dstBuf.getSSA();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
           const Dimension dimRank = dstStt.getDimRank();
           const Level lvlRank = dstStt.getLvlRank();
           SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           }
 
           if (!skipZeroCheck) {
-            assert(!reduc.empty());
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                   /*else*/ true);
             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
           }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });
 
     rewriter.setInsertionPointAfter(foreachOp);
 
     // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);
 
     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
     rewriter.replaceOp(op, ret);
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index c22f051a0d5854d..e2dcb068e11851e 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,11 +14,10 @@
 
 // CHECK-LABEL:  func.func @sparse_convert_1d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
@@ -26,11 +25,10 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
 
 // CHECK-LABEL:  func.func @sparse_convert_1d_dyn
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
@@ -38,11 +36,10 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
 
 // CHECK-LABEL:  func.func @sparse_convert_2d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
@@ -50,11 +47,10 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
@@ -62,11 +58,10 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn1
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
@@ -74,11 +69,10 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn2
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
@@ -86,11 +80,10 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_3d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index bdfab54dc6daeb5..f3d3dd28563e891 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -176,77 +176,83 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
     return %0 : tensor<?x?xf64, #DCSR>
 }
 
-// CHECK-LABEL: @concat_sparse_sparse_dense(
-//  CHECK-SAME:  %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-//  CHECK-SAME:  %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-//  CHECK-SAME:  %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-//   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:  %[[TMP_c5:.*]] = arith.constant 5 : index
-//   CHECK-DAG:  %[[TMP_c2:.*]] = arith.constant 2 : index
-//   CHECK-DAG:  %[[TMP_c9:.*]] = arith.constant 9 : index
-//   CHECK-DAG:  %[[TMP_c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:  %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
-//       CHECK:  %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
-//       CHECK:  linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-//       CHECK:    %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-//       CHECK:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
-//       CHECK:  return %[[R]] : tensor<?x?xf64>
+// CHECK-LABEL:   func.func @concat_sparse_sparse_dense(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 9 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_10:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_3]]) : tensor<?x?xf64>
+// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : f64) outs(%[[VAL_10]] : tensor<?x?xf64>) -> tensor<?x?xf64>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_8]] iter_args(%[[VAL_21:.*]] = %[[VAL_11]]) -> (tensor<?x?xf64>) {
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_24:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index
+// CHECK:             %[[VAL_25:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:             %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_8]] iter_args(%[[VAL_28:.*]] = %[[VAL_21]]) -> (tensor<?x?xf64>) {
+// CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK:               %[[VAL_30:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xf64>
+// CHECK:               %[[VAL_31:.*]] = tensor.insert %[[VAL_30]] into %[[VAL_28]]{{\[}}%[[VAL_22]], %[[VAL_29]]] : tensor<?x?xf64>
+// CHECK:               scf.yield %[[VAL_31]] : tensor<?x?xf64>
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_26]] : tensor<?x?xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_32:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_33:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_34:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_35:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_36:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_37:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_38:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_39:.*]] = scf.for %[[VAL_40:.*]] = %[[VAL_37]] to %[[VAL_38]] step %[[VAL_8]] iter_args(%[[VAL_41:.*]] = %[[VAL_19]]) -> (tensor<?x?xf64>) {
+// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK:             %[[VAL_43:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK:             %[[VAL_44:.*]] = arith.addi %[[VAL_40]], %[[VAL_8]] : index
+// CHECK:             %[[VAL_45:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK:             %[[VAL_46:.*]] = scf.for %[[VAL_47:.*]] = %[[VAL_43]] to %[[VAL_45]] step %[[VAL_8]] iter_args(%[[VAL_48:.*]] = %[[VAL_41]]) -> (tensor<?x?xf64>) {
+// CHECK:               %[[VAL_49:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// CHECK:               %[[VAL_50:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_47]]] : memref<?xf64>
+// CHECK:               %[[VAL_51:.*]] = arith.addi %[[VAL_42]], %[[VAL_9]] : index
+// CHECK:               %[[VAL_52:.*]] = tensor.insert %[[VAL_50]] into %[[VAL_48]]{{\[}}%[[VAL_51]], %[[VAL_49]]] : tensor<?x?xf64>
+// CHECK:               scf.yield %[[VAL_52]] : tensor<?x?xf64>
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_46]] : tensor<?x?xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_53:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_54:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_55:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_56:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_57:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<4x4xf64, #sparse_tensor
+// CHECK:           %[[VAL_58:.*]] = memref.load %[[VAL_53]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_59:.*]] = memref.load %[[VAL_53]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_58]] to %[[VAL_59]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_39]]) -> (tensor<?x?xf64>) {
+// CHECK:             %[[VAL_63:.*]] = memref.load %[[VAL_54]]{{\[}}%[[VAL_61]]] : memref<?xindex>
+// CHECK:             %[[VAL_64:.*]] = memref.load %[[VAL_55]]{{\[}}%[[VAL_61]]] : memref<?xindex>
+// CHECK:             %[[VAL_65:.*]] = arith.addi %[[VAL_61]], %[[VAL_8]] : index
+// CHECK:             %[[VAL_66:.*]] = memref.load %[[VAL_55]]{{\[}}%[[VAL_65]]] : memref<?xindex>
+// CHECK:             %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_8]] iter_args(%[[VAL_69:.*]] = %[[VAL_62]]) -> (tensor<?x?xf64>) {
+// CHECK:               %[[VAL_70:.*]] = memref.load %[[VAL_56]]{{\[}}%[[VAL_68]]] : memref<?xindex>
+// CHECK:               %[[VAL_71:.*]] = memref.load %[[VAL_57]]{{\[}}%[[VAL_68]]] : memref<?xf64>
+// CHECK:               %[[VAL_72:.*]] = arith.addi %[[VAL_63]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_73:.*]] = tensor.insert %[[VAL_71]] into %[[VAL_69]]{{\[}}%[[VAL_72]], %[[VAL_70]]] : tensor<?x?xf64>
+// CHECK:               scf.yield %[[VAL_73]] : tensor<?x?xf64>
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_67]] : tensor<?x?xf64>
+// CHECK:           }
+// CHECK:           return %[[VAL_60]] : tensor<?x?xf64>
+// CHECK:         }
 func.func @concat_sparse_sparse_dense(%arg0: tensor<2x4xf64, #DCSR>,
                                 %arg1: tensor<3x4xf64, #DCSR>,
                                 %arg2: tensor<4x4xf64, #DCSR>)



More information about the Mlir-commits mailing list