[Mlir-commits] [mlir] [mlir][sparse] implements tensor.insert on sparse tensors. (PR #70737)

Peiming Liu llvmlistbot at llvm.org
Mon Oct 30 15:23:23 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70737

Implements tensor.insert on sparse tensors: inserts whose result type carries a sparse encoding are rewritten to sparse_tensor.insert on level coordinates through the demapped type, and the reinterpret-map pass (scope kExceptGeneric) is wired into the sparsification pipeline to handle dim-to-lvl maps on operations other than linalg.generic.
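
For illustration only, a sketch of the rewrite this patch implements (the encoding and function below are an editor's example, not taken from the patch): a tensor.insert whose result type has a sparse encoding gets its dimension coordinates translated to level coordinates, and the insertion goes through the demapped type, bracketed by a pair of sparse_tensor.reinterpret_map ops.

  // Hypothetical CSC-like encoding: dim order (d0, d1) differs from level
  // order (d1, d0), so coordinates must be translated before insertion.
  #CSC = #sparse_tensor.encoding<{
    map = (d0, d1) -> (d1 : dense, d0 : compressed)
  }>

  func.func @insert(%t: tensor<8x8xf64, #CSC>, %v: f64,
                    %i: index, %j: index) -> tensor<8x8xf64, #CSC> {
    // After the rewrite, conceptually:
    //   %demap = sparse_tensor.reinterpret_map %t    // drop the dimToLvl map
    //   %ins   = sparse_tensor.insert %v into %demap[%j, %i]  // lvl coords
    //   %res   = sparse_tensor.reinterpret_map %ins  // restore the CSC type
    %0 = tensor.insert %v into %t[%i, %j] : tensor<8x8xf64, #CSC>
    return %0 : tensor<8x8xf64, #CSC>
  }

Dense inserts are left untouched; the pattern only fires when the result type has a sparse encoding, and the dim-to-lvl translation itself is expanded through affine.apply by CrdTranslateRewriter.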

From 9c20453b7198b728ad5f02c382874724a0336c1a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 30 Oct 2023 22:17:57 +0000
Subject: [PATCH] [mlir][sparse] implements tensor.insert on sparse tensors.

---
 .../SparseTensor/IR/SparseTensorType.h        |  9 +++
 .../Transforms/SparseReinterpretMap.cpp       | 61 ++++++++++++++++++-
 .../Transforms/SparseTensorRewriting.cpp      | 56 +++--------------
 .../SparsificationAndBufferizationPass.cpp    |  4 +-
 .../SparseTensor/convert_dense2sparse.mlir    | 14 ++---
 .../SparseTensor/convert_sparse2sparse.mlir   |  6 +-
 .../Dialect/SparseTensor/sparse_concat.mlir   | 12 ++--
 7 files changed, 98 insertions(+), 64 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 7a1f1e2144e049d..34f56c1947cc27c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -251,6 +251,17 @@ class SparseTensorType {
                                        CrdTransDirectionKind::dim2lvl);
   }
 
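+  /// Returns the demapped type of this sparse tensor type, i.e., the level
+  /// shape with an encoding that carries no dimToLvl/lvlToDim maps.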
+  RankedTensorType getDemappedType() const {
+    auto lvlShape = getLvlShape();
+    return RankedTensorType::get(
+        lvlShape, rtp.getElementType(),
+        SparseTensorEncodingAttr::get(rtp.getContext(), getLvlTypes(),
+                                      AffineMap(), AffineMap(), getPosWidth(),
+                                      getCrdWidth(), enc.getDimSlices()));
+  }
+
   /// Safely looks up the requested dimension-DynSize.  If you intend
   /// to check the result with `ShapedType::isDynamic`, then see the
   /// `getStaticDimSize` method instead.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 10722ccb6eea743..66fd2e4d94a28bd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -6,9 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineMap.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
 
 namespace {
 
@@ -17,7 +23,60 @@ namespace {
 //   (2) rewrite linalg.generic ops traits on level crds
 //   (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
 
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generic
+//===----------------------------------------------------------------------===//
+
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+                        ? op.getEncoder().getDimToLvl()
+                        : op.getEncoder().getLvlToDim();
+    SmallVector<Value> outCrds;
+    for (AffineExpr result : map.getResults()) {
+    // TODO: we should probably expand the affine map to IR using our own
+    // rules, since affine.apply assumes signed values, while the coordinates
+    // we provide must always be signless.
+      Value trans = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+          op.getInCrds());
+      outCrds.push_back(trans);
+    }
+    rewriter.replaceOp(op, outCrds);
+    return success();
+  }
+};
+
+struct TensorInsertRewriter : public OpRewritePattern<tensor::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only rewrite tensor.insert on sparse tensors; dense inserts are kept.
+    if (!op.getResult().getType().getEncoding())
+      return failure();
+    Location loc = op.getLoc();
+    auto stt = getSparseTensorType(op.getResult());
+    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
+                                          CrdTransDirectionKind::dim2lvl);
+
+    Value t = rewriter.create<ReinterpretMapOp>(
+        loc, stt.getEncoding().withoutDimToLvl(), op.getDest());
+    t = rewriter.create<sparse_tensor::InsertOp>(loc, op.getScalar(), t,
+                                                 lvlCrd);
+    rewriter.replaceOpWithNewOp<ReinterpretMapOp>(op, op.getType(), t);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                        ReinterpretMapScope scope) {}
+                                        ReinterpretMapScope scope) {
+  if (scope == ReinterpretMapScope::kAll ||
+      scope == ReinterpretMapScope::kExceptGeneric) {
+    patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
+        patterns.getContext());
+  }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 528e70bd3b1ef5f..2d45087aa5801cd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -846,11 +846,7 @@ struct TensorLike {
   }
 
   void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
-    // TODO: Unify these two.
-    if (isSparse())
-      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
-    else
-      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
+    val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
@@ -866,28 +862,6 @@ struct TensorLike {
   Value val;
 };
 
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(CrdTranslateOp op,
-                                PatternRewriter &rewriter) const override {
-    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
-                        ? op.getEncoder().getDimToLvl()
-                        : op.getEncoder().getLvlToDim();
-    SmallVector<Value> outCrds;
-    for (AffineExpr result : map.getResults()) {
-      // TODO: we should probably expand the affine map to IR using our own
-      // rules, since affine.apply assume signed value, while the cooridinates
-      // we provided must always be signless.
-      Value trans = rewriter.create<affine::AffineApplyOp>(
-          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
-          op.getInCrds());
-      outCrds.push_back(trans);
-    }
-    rewriter.replaceOp(op, outCrds);
-    return success();
-  }
-};
-
 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(tensor::DimOp op,
@@ -969,15 +943,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
           loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
-            SmallVector<Value> dstLcvs(dstTp.getLvlRank());
-            for (Dimension d = 0; d < dimRank; d++) {
-              Value crd = dcvs[d];
-              // Transforms coordinates for the concatenating dim.
-              if (d == conDim)
-                crd = builder.create<arith::AddIOp>(loc, crd, offset);
-              // FIXME: `toStoredDim` is deprecated
-              dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
-            }
+            SmallVector<Value> offDimCrd(dcvs);
+            offDimCrd[conDim] =
+                builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
+
             // Enters foreach, updates the SSA chain.
             dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
@@ -988,14 +957,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insert(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, offDimCrd);
               builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
               dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insert(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, offDimCrd);
             }
             builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
@@ -1064,10 +1033,6 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
           dstBuf.val = reduc.front();
-
-          ValueRange lcvs = dstStt.translateCrds(
-              builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
-
           if (!skipZeroCheck) {
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
@@ -1076,14 +1041,14 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
             builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insert(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, dcvs);
             builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
             dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insert(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, dcvs);
           }
           builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });
@@ -1306,8 +1271,7 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                    bool enableRT,
                                                    bool enableConvert) {
-  patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
-               ReshapeRewriter<tensor::ExpandShapeOp>,
+  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index f3f3828e0c5bdff..41940f731e76c17 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -143,7 +143,9 @@ class SparsificationAndBufferizationPass
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                    /*enableConvert=*/true));
-      // TODO: DemapPass here!
+      // Handle dim-to-lvl maps on operations other than linalg.generic.
+      pm.addPass(
+          createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
       pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
       if (vectorLength > 0) {
         pm.addPass(mlir::createLoopInvariantCodeMotionPass());
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 4f37ae9207be9cc..96a1140372bd6cd 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -19,7 +19,7 @@
 // CHECK-LABEL:   func.func @sparse_convert_1d
 // CHECK:           sparse_tensor.foreach
 // CHECK:            scf.if
-// CHECK:              sparse_tensor.insert
+// CHECK:              tensor.insert
 // CHECK-NOT:       sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.load
 func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
@@ -30,7 +30,7 @@ func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVecto
 // CHECK-LABEL:   func.func @sparse_convert_complex
 // CHECK:           sparse_tensor.foreach
 // CHECK:            scf.if
-// CHECK:              sparse_tensor.insert
+// CHECK:              tensor.insert
 // CHECK-NOT:       sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.load
 func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100xcomplex<f64>, #SparseVector> {
@@ -41,7 +41,7 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 // CHECK-LABEL:   func.func @sparse_convert_2d
 // CHECK:           sparse_tensor.foreach
 // CHECK:            scf.if
-// CHECK:              sparse_tensor.insert
+// CHECK:              tensor.insert
 // CHECK-NOT:       sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.load
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
@@ -52,7 +52,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 // CHECK-LABEL:   func.func @sparse_constant
 // CHECK:           sparse_tensor.foreach
 // CHECK-NOT:         scf.if
-// CHECK:               sparse_tensor.insert
+// CHECK:               tensor.insert
 // CHECK-NOT:       sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.load
 func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
@@ -66,7 +66,7 @@ func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
 // CHECK-LABEL:   func.func @sparse_constant_csc
 // CHECK:           sparse_tensor.foreach
 // CHECK-NOT:         scf.if
-// CHECK:               sparse_tensor.insert
+// CHECK:               tensor.insert
 // CHECK-NOT:       sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.load
 func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
@@ -80,11 +80,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
 // CHECK-LABEL:   func.func @sparse_convert_3d
 // CHECK:           sparse_tensor.foreach
 // CHECK:             scf.if
-// CHECK:               sparse_tensor.insert
+// CHECK:               tensor.insert
 // CHECK:           sparse_tensor.load
 // CHECK:           sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.foreach
-// CHECK:             sparse_tensor.insert
+// CHECK:             tensor.insert
 // CHECK:           sparse_tensor.load
 func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 896bc02212971f0..0673f915a1cf626 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -66,11 +66,11 @@ func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32
 
 // CHECK-LABEL:   func.func @sparse_convert_permuted
 // CHECK:           sparse_tensor.foreach
-// CHECK:             sparse_tensor.insert
+// CHECK:             tensor.insert
 // CHECK:           sparse_tensor.load
 // CHECK:           sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.foreach
-// CHECK:             sparse_tensor.insert
+// CHECK:             tensor.insert
 // CHECK:           sparse_tensor.load
 // CHECK:           return
 func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> tensor<?x?x?xf32, #TsssPermuted> {
@@ -80,7 +80,7 @@ func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> te
 
 // CHECK-LABEL:   func.func @sparse_convert_slice
 // CHECK:           sparse_tensor.foreach
-// CHECK:             sparse_tensor.insert
+// CHECK:             tensor.insert
 // CHECK:           sparse_tensor.load
 // CHECK-NOT:       sparse_tensor.reorder_coo
 // CHECK:           return
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index e4e2748112d78c4..86dc9a117507135 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -30,7 +30,7 @@
 //       CHECK:    %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
 //       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
 //       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+//       CHECK:      %[[NEW_1:.*]] = tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
 //       CHECK:      scf.yield %[[NEW_1]]
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_4]]
@@ -51,7 +51,7 @@
 //       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
 //       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
 //       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-//       CHECK:      %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+//       CHECK:      %[[NEW_2:.*]] = tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
 //       CHECK:      scf.yield %[[NEW_2]]
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_5]]
@@ -72,7 +72,7 @@
 //       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
 //       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
 //       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-//       CHECK:      %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+//       CHECK:      %[[NEW_3:.*]] = tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
 //       CHECK:      scf.yield %[[NEW_3]]
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_6]]
@@ -116,7 +116,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //       CHECK:    %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
 //       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
 //       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+//       CHECK:      %[[NEW_1:.*]] = tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
 //       CHECK:      scf.yield %[[NEW_1]]
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_4]]
@@ -137,7 +137,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
 //       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
 //       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-//       CHECK:      %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+//       CHECK:      %[[NEW_2:.*]] = tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
 //       CHECK:      scf.yield %[[NEW_2]]
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_5]]
@@ -158,7 +158,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
 //       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
 //       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-//       CHECK:      %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+//       CHECK:      %[[NEW_3:.*]] = tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
 //       CHECK:      scf.yield %[[NEW_3]]
 //       CHECK:    }
 //       CHECK:    scf.yield %[[RET_6]]
