[Mlir-commits] [mlir] [mlir][sparse] implements tensor.insert on sparse tensors. (PR #70737)
Peiming Liu
llvmlistbot at llvm.org
Mon Oct 30 15:23:23 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70737
None
>From 9c20453b7198b728ad5f02c382874724a0336c1a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 30 Oct 2023 22:17:57 +0000
Subject: [PATCH] [mlir][sparse] implements tensor.insert on sparse tensors.
---
.../SparseTensor/IR/SparseTensorType.h | 9 +++
.../Transforms/SparseReinterpretMap.cpp | 61 ++++++++++++++++++-
.../Transforms/SparseTensorRewriting.cpp | 56 +++--------------
.../SparsificationAndBufferizationPass.cpp | 4 +-
.../SparseTensor/convert_dense2sparse.mlir | 14 ++---
.../SparseTensor/convert_sparse2sparse.mlir | 6 +-
.../Dialect/SparseTensor/sparse_concat.mlir | 12 ++--
7 files changed, 98 insertions(+), 64 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 7a1f1e2144e049d..34f56c1947cc27c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -251,6 +251,15 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}
+  /// Returns the "demapped" counterpart of this type: a ranked tensor type
+  /// whose shape is this type's level shape and whose sparse encoding keeps
+  /// the level types, position/coordinate widths and dimension slices, but
+  /// carries neither a dim-to-lvl nor a lvl-to-dim map (so dimensions and
+  /// levels coincide). Must only be called on a type with a sparse encoding.
+  ///
+  /// NOTE(review): the dimension slices are forwarded unchanged into a type
+  /// whose dimensions are the original levels; if the dim-to-lvl map
+  /// permutes dimensions they may need reordering -- TODO confirm.
+  RankedTensorType getDemappedType() const {
+    // `enc` is dereferenced below, so a tensor without an encoding is a bug.
+    assert(enc && "getDemappedType requires a sparse tensor encoding");
+    auto lvlShape = getLvlShape();
+    return RankedTensorType::get(
+        lvlShape, rtp.getElementType(),
+        SparseTensorEncodingAttr::get(rtp.getContext(), getLvlTypes(),
+                                      AffineMap(), AffineMap(), getPosWidth(),
+                                      getCrdWidth(), enc.getDimSlices()));
+  }
+
/// Safely looks up the requested dimension-DynSize. If you intend
/// to check the result with `ShapedType::isDynamic`, then see the
/// `getStaticDimSize` method instead.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 10722ccb6eea743..66fd2e4d94a28bd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -6,9 +6,15 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineMap.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
namespace {
@@ -17,7 +23,60 @@ namespace {
// (2) rewrite linalg.generic ops traits on level crds
// (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
+//===----------------------------------------------------------------------===//
+// Reinterpret-map rewriters for operations other than linalg.generic
+//===----------------------------------------------------------------------===//
+
+/// Lowers `sparse_tensor.crd_translate` into one `affine.apply` per result of
+/// the encoder's dim-to-lvl map (for the dim2lvl direction) or lvl-to-dim map
+/// (otherwise); the applied values replace the op's results in order.
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+                        ? op.getEncoder().getDimToLvl()
+                        : op.getEncoder().getLvlToDim();
+    // Bail out cleanly if the encoder does not carry the requested map,
+    // rather than dereferencing a null AffineMap below.
+    if (!map)
+      return rewriter.notifyMatchFailure(op, "missing coordinate map");
+
+    SmallVector<Value> outCrds;
+    outCrds.reserve(map.getNumResults());
+    for (AffineExpr result : map.getResults()) {
+      // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the
+      // coordinates we provide must always be signless.
+      Value trans = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+          op.getInCrds());
+      outCrds.push_back(trans);
+    }
+    rewriter.replaceOp(op, outCrds);
+    return success();
+  }
+};
+
+/// Rewrites a `tensor.insert` into a sparse tensor as a `sparse_tensor.insert`
+/// on the level-space view of the destination:
+///   1. the dimension indices are translated to level coordinates,
+///   2. the destination is reinterpreted without its dim-to-lvl map,
+///   3. the scalar is inserted at the level coordinates, and
+///   4. the result is reinterpreted back to the original tensor type.
+struct TensorInsertRewriter : public OpRewritePattern<tensor::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only sparse destinations are handled. Checking for *any* encoding is
+    // not sufficient: a non-sparse encoding attribute would leave `stt`
+    // without a sparse encoding for `withoutDimToLvl()` below.
+    if (!getSparseTensorEncoding(op.getResult().getType()))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto stt = getSparseTensorType(op.getResult());
+    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
+                                          CrdTransDirectionKind::dim2lvl);
+
+    Value t = rewriter.create<ReinterpretMapOp>(
+        loc, stt.getEncoding().withoutDimToLvl(), op.getDest());
+    t = rewriter.create<sparse_tensor::InsertOp>(loc, op.getScalar(), t,
+                                                 lvlCrd);
+    rewriter.replaceOpWithNewOp<ReinterpretMapOp>(op, op.getType(), t);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
- ReinterpretMapScope scope) {}
+                                        ReinterpretMapScope scope) {
+  // The rewriters registered here target operations other than
+  // linalg.generic; they are enabled for the kAll and kExceptGeneric scopes.
+  const bool coversNonGenericOps =
+      scope == ReinterpretMapScope::kAll ||
+      scope == ReinterpretMapScope::kExceptGeneric;
+  if (!coversNonGenericOps)
+    return;
+
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(ctx);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 528e70bd3b1ef5f..2d45087aa5801cd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -846,11 +846,7 @@ struct TensorLike {
}
void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
- // TODO: Unify these two.
- if (isSparse())
- val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
- else
- val = builder.create<tensor::InsertOp>(loc, v, val, crds);
+ // Emits a plain tensor.insert for both dense and sparse destinations; the
+ // callers in this file pass dimension coordinates. For sparse destinations
+ // the tensor.insert is later rewritten by TensorInsertRewriter into a
+ // sparse_tensor.insert on level coordinates.
+ val = builder.create<tensor::InsertOp>(loc, v, val, crds);
}
Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
@@ -866,28 +862,6 @@ struct TensorLike {
Value val;
};
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(CrdTranslateOp op,
- PatternRewriter &rewriter) const override {
- AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
- ? op.getEncoder().getDimToLvl()
- : op.getEncoder().getLvlToDim();
- SmallVector<Value> outCrds;
- for (AffineExpr result : map.getResults()) {
- // TODO: we should probably expand the affine map to IR using our own
- // rules, since affine.apply assume signed value, while the cooridinates
- // we provided must always be signless.
- Value trans = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
- op.getInCrds());
- outCrds.push_back(trans);
- }
- rewriter.replaceOp(op, outCrds);
- return success();
- }
-};
-
struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::DimOp op,
@@ -969,15 +943,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
loc, input, iterArg,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
- SmallVector<Value> dstLcvs(dstTp.getLvlRank());
- for (Dimension d = 0; d < dimRank; d++) {
- Value crd = dcvs[d];
- // Transforms coordinates for the concatenating dim.
- if (d == conDim)
- crd = builder.create<arith::AddIOp>(loc, crd, offset);
- // FIXME: `toStoredDim` is deprecated
- dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
- }
+ SmallVector<Value> offDimCrd(dcvs);
+ offDimCrd[conDim] =
+ builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
+
// Enters foreach, updates the SSA chain.
dstBuf.val = reduc.front();
if (!dstTp.isAllDense()) {
@@ -988,14 +957,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
builder.create<scf::YieldOp>(loc, dstBuf.val);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- dstBuf.insert(builder, loc, v, dstLcvs);
+ dstBuf.insert(builder, loc, v, offDimCrd);
builder.create<scf::YieldOp>(loc, dstBuf.val);
// Exits the ifOp, update the sparse tensor SSA value.
builder.setInsertionPointAfter(ifOp);
dstBuf.val = ifOp.getResult(0);
} else {
- dstBuf.insert(builder, loc, v, dstLcvs);
+ dstBuf.insert(builder, loc, v, offDimCrd);
}
builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
});
@@ -1064,10 +1033,6 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
ValueRange reduc) {
// Enters the loop, update the SSA value for insertion chain.
dstBuf.val = reduc.front();
-
- ValueRange lcvs = dstStt.translateCrds(
- builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
-
if (!skipZeroCheck) {
Value cond = genIsNonzero(builder, loc, v);
auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
@@ -1076,14 +1041,14 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
builder.create<scf::YieldOp>(loc, dstBuf.val);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- dstBuf.insert(builder, loc, v, lcvs);
+ dstBuf.insert(builder, loc, v, dcvs);
builder.create<scf::YieldOp>(loc, dstBuf.val);
// Exits the ifOp, update the sparse tensor SSA value.
builder.setInsertionPointAfter(ifOp);
dstBuf.val = ifOp.getResult(0);
} else {
- dstBuf.insert(builder, loc, v, lcvs);
+ dstBuf.insert(builder, loc, v, dcvs);
}
builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
});
@@ -1306,8 +1271,7 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
bool enableRT,
bool enableConvert) {
- patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
- ReshapeRewriter<tensor::ExpandShapeOp>,
+ patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index f3f3828e0c5bdff..41940f731e76c17 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -143,7 +143,9 @@ class SparsificationAndBufferizationPass
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
/*enableConvert=*/true));
- // TODO: DemapPass here!
+ // Handle dim-to-lvl maps on operations other than linalg.generic.
+ pm.addPass(
+ createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
if (vectorLength > 0) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 4f37ae9207be9cc..96a1140372bd6cd 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -19,7 +19,7 @@
// CHECK-LABEL: func.func @sparse_convert_1d
// CHECK: sparse_tensor.foreach
// CHECK: scf.if
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK-NOT: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.load
func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
@@ -30,7 +30,7 @@ func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVecto
// CHECK-LABEL: func.func @sparse_convert_complex
// CHECK: sparse_tensor.foreach
// CHECK: scf.if
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK-NOT: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.load
func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100xcomplex<f64>, #SparseVector> {
@@ -41,7 +41,7 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-LABEL: func.func @sparse_convert_2d
// CHECK: sparse_tensor.foreach
// CHECK: scf.if
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK-NOT: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.load
func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
@@ -52,7 +52,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-LABEL: func.func @sparse_constant
// CHECK: sparse_tensor.foreach
// CHECK-NOT: scf.if
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK-NOT: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.load
func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
@@ -66,7 +66,7 @@ func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
// CHECK-LABEL: func.func @sparse_constant_csc
// CHECK: sparse_tensor.foreach
// CHECK-NOT: scf.if
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK-NOT: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.load
func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
@@ -80,11 +80,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
// CHECK-LABEL: func.func @sparse_convert_3d
// CHECK: sparse_tensor.foreach
// CHECK: scf.if
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK: sparse_tensor.load
// CHECK: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.foreach
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK: sparse_tensor.load
func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 896bc02212971f0..0673f915a1cf626 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -66,11 +66,11 @@ func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32
// CHECK-LABEL: func.func @sparse_convert_permuted
// CHECK: sparse_tensor.foreach
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK: sparse_tensor.load
// CHECK: sparse_tensor.reorder_coo
// CHECK: sparse_tensor.foreach
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK: sparse_tensor.load
// CHECK: return
func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> tensor<?x?x?xf32, #TsssPermuted> {
@@ -80,7 +80,7 @@ func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> te
// CHECK-LABEL: func.func @sparse_convert_slice
// CHECK: sparse_tensor.foreach
-// CHECK: sparse_tensor.insert
+// CHECK: tensor.insert
// CHECK: sparse_tensor.load
// CHECK-NOT: sparse_tensor.reorder_coo
// CHECK: return
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index e4e2748112d78c4..86dc9a117507135 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -30,7 +30,7 @@
// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[NEW_1:.*]] = tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
// CHECK: scf.yield %[[NEW_1]]
// CHECK: }
// CHECK: scf.yield %[[RET_4]]
@@ -51,7 +51,7 @@
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[NEW_2:.*]] = tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
// CHECK: scf.yield %[[NEW_2]]
// CHECK: }
// CHECK: scf.yield %[[RET_5]]
@@ -72,7 +72,7 @@
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: %[[NEW_3:.*]] = tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<9x4xf64, #sparse_tensor
// CHECK: scf.yield %[[NEW_3]]
// CHECK: }
// CHECK: scf.yield %[[RET_6]]
@@ -116,7 +116,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[NEW_1:.*]] = tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
// CHECK: scf.yield %[[NEW_1]]
// CHECK: }
// CHECK: scf.yield %[[RET_4]]
@@ -137,7 +137,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[NEW_2:.*]] = tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
// CHECK: scf.yield %[[NEW_2]]
// CHECK: }
// CHECK: scf.yield %[[RET_5]]
@@ -158,7 +158,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: %[[NEW_3:.*]] = tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor<?x?xf64, #sparse_tensor
// CHECK: scf.yield %[[NEW_3]]
// CHECK: }
// CHECK: scf.yield %[[RET_6]]
More information about the Mlir-commits
mailing list