[Mlir-commits] [mlir] [mlir][sparse] reuse tensor.insert operation to insert elements into … (PR #84987)

llvmlistbot at llvm.org
Tue Mar 12 14:47:09 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes:

Reuse the `tensor.insert` operation to insert elements into a sparse tensor.
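
Concretely, the dedicated `sparse_tensor.insert` operation is removed and sparse insertions are written with the existing `tensor.insert` op. A minimal before/after sketch, adapted from the test updates in this patch (the `#SV` sparse-vector encoding is illustrative):

```mlir
// Before: dedicated sparse insertion operation.
%0 = sparse_tensor.insert %val into %t[%i] : tensor<128xf64, #SV>

// After: the dense tensor.insert operation is reused on sparse tensors.
%0 = tensor.insert %val into %t[%i] : tensor<128xf64, #SV>
```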

---

Patch is 48.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84987.diff


26 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (-42) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (-9) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp (-22) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+11-7) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+14-7) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+4-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+3-2) 
- (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/constant_index_map.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (-18) 
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+5-5) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_index.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_out.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+4-4) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_transpose.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_1d.mlir (+5-5) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_2d.mlir (+20-20) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir (+20-20) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index feed15d6af0544..0498576fcffc51 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -668,48 +668,6 @@ def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
 // refined over time as our sparse abstractions evolve.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_InsertOp : SparseTensor_Op<"insert",
-    [TypesMatchWith<"value type matches element type of tensor",
-                    "tensor", "value",
-                    "::llvm::cast<ShapedType>($_self).getElementType()">,
-     AllTypesMatch<["tensor", "result"]>]>,
-    Arguments<(ins AnyType:$value,
-               AnySparseTensor:$tensor,
-               Variadic<Index>:$lvlCoords)>,
-    Results<(outs AnySparseTensor:$result)> {
-  string summary = "Inserts a value into the sparse tensor";
-  string description = [{
-    Inserts the value into the underlying storage of the tensor at the
-    given level-coordinates. The arity of `lvlCoords` must match the
-    level-rank of the tensor. This operation can only be applied when
-    the tensor materializes unintialized from a `tensor.empty` operation
-    and the final tensor is constructed with a `load` operation which
-    has the `hasInserts` attribute set.
-
-    The level-properties of the sparse tensor type fully describe what
-    kind of insertion order is allowed.  When all levels have "unique"
-    and "ordered" properties, for example, insertions should occur in
-    strict lexicographical level-coordinate order.  Other properties
-    define different insertion regimens.  Inserting in a way contrary
-    to these properties results in undefined behavior.
-
-    Note that this operation is "impure" in the sense that even though
-    the result is modeled through an SSA value, the insertion is eventually
-    done "in place", and referencing the old SSA value is undefined behavior.
-    This operation is scheduled to be unified with the dense counterpart
-    `tensor.insert` that has pure SSA semantics.
-
-    Example:
-
-    ```mlir
-    %result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
-    ```
-  }];
-  let assemblyFormat = "$value `into` $tensor `[` $lvlCoords `]` attr-dict"
-                       "`:` type($tensor)";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
     [TypesMatchWith<"value type matches element type of inBuffer",
                     "inBuffer", "value",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c19907a945d3bb..7750efdd9add0f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1741,15 +1741,6 @@ LogicalResult ConcatenateOp::verify() {
   return success();
 }
 
-LogicalResult InsertOp::verify() {
-  const auto stt = getSparseTensorType(getTensor());
-  if (stt.getEncoding().getBatchLvlRank() > 0)
-    return emitOpError("batched sparse tensor insertion not implemented");
-  if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
-    return emitOpError("incorrect number of coordinates");
-  return success();
-}
-
 void PushBackOp::build(OpBuilder &builder, OperationState &result,
                        Value curSize, Value inBuffer, Value value) {
   build(builder, result, curSize, inBuffer, value, Value());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 3f4ae1f67de150..a942a721e218fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -187,27 +187,6 @@ struct DisassembleOpInterface
   }
 };
 
-struct InsertOpInterface : public SparseBufferizableOpInterfaceExternalModel<
-                               InsertOpInterface, sparse_tensor::InsertOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    // InsertOp writes to memory.
-    return true;
-  }
-
-  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
-                                      const AnalysisState &state) const {
-    // InsertOp returns an alias of its operand.
-    assert(op->getNumResults() == 1);
-    return {{op->getOpResult(0), BufferRelation::Equivalent}};
-  }
-};
-
 struct NumberOfEntriesOpInterface
     : public SparseBufferizableOpInterfaceExternalModel<
           NumberOfEntriesOpInterface, sparse_tensor::NumberOfEntriesOp> {
@@ -324,7 +303,6 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
     sparse_tensor::ConvertOp::attachInterface<ConvertOpInterface>(*ctx);
     sparse_tensor::LoadOp::attachInterface<LoadOpInterface>(*ctx);
     sparse_tensor::NewOp::attachInterface<NewOpInterface>(*ctx);
-    sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
     sparse_tensor::NumberOfEntriesOp::attachInterface<
         NumberOfEntriesOpInterface>(*ctx);
     sparse_tensor::AssembleOp::attachInterface<AssembleOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index fbe2fc31ab8b15..f93b59de29e57b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -640,14 +640,14 @@ struct TensorInsertDemapper
   using DemapInsRewriter::DemapInsRewriter;
   LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                           PatternRewriter &rewriter) const {
-    if (!hasAnySparseResult(op))
+    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
       return failure();
 
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                           CrdTransDirectionKind::dim2lvl);
-    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
+    auto insertOp = rewriter.create<tensor::InsertOp>(
         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
 
     Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 44c5d4dbe485bf..56c51c1ab1a84f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1014,24 +1014,28 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
 };
 
 /// Sparse codegen rule for the insert operator.
-class SparseInsertConverter : public OpConversionPattern<InsertOp> {
+class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (auto stt = getSparseTensorType(adaptor.getDest()); !stt.hasEncoding()) {
+      assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
+      return failure();
+    }
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
-    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
-    params.push_back(adaptor.getValue());
-    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
+    params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
+    params.push_back(adaptor.getScalar());
+    SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
     // Replace operation with resulting memrefs.
     rewriter.replaceOp(op,
-                       genTuple(rewriter, loc, op.getTensor().getType(), ret));
+                       genTuple(rewriter, loc, op.getDest().getType(), ret));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 010c3aa58b72c7..0937c10f257283 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -580,17 +580,24 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
 };
 
 /// Sparse conversion rule for the insertion operator.
-class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
+class SparseTensorInsertConverter
+    : public OpConversionPattern<tensor::InsertOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Note that the current regime only allows for strict lexicographic
     // coordinate order. All values are passed by reference through stack
     // allocated memrefs.
     Location loc = op->getLoc();
-    const auto stt = getSparseTensorType(op.getTensor());
+    const auto stt = getSparseTensorType(op.getDest());
+
+    // Dense tensor insertion.
+    if (!stt.hasEncoding())
+      return failure();
+
+    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
     const auto elemTp = stt.getElementType();
     const Level lvlRank = stt.getLvlRank();
     Value lvlCoords, vref;
@@ -608,12 +615,12 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
       lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
       vref = genAllocaScalar(rewriter, loc, elemTp);
     }
-    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
-    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
+    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
+    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
     SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
     createFuncCall(rewriter, loc, name, {},
-                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
-    rewriter.replaceOp(op, adaptor.getTensor());
+                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
+    rewriter.replaceOp(op, adaptor.getDest());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a65bce78d095cf..17f70d0796ccfc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -817,7 +817,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
           reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                      dstSizes, dstDcvs);
 
-          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          auto t =
+              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
           builder.create<sparse_tensor::YieldOp>(loc, t);
         });
 
@@ -901,7 +902,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           SmallVector<Value> dstDcvs;
           reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                      srcDcvs, dstSizes, dstDcvs);
-          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          auto t =
+              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
           builder.create<sparse_tensor::YieldOp>(loc, t);
         });
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 1fb70ed5035c03..cd046b670d9a8e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -428,7 +428,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
           /*else=*/true);
       // True branch.
       builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-      Value res = builder.create<InsertOp>(loc, rhs, chain, ivs);
+      Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
       builder.create<scf::YieldOp>(loc, res);
       // False branch.
       builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
@@ -438,7 +438,8 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       env.updateInsertionChain(ifValidLexInsert.getResult(0));
     } else {
       // Generates regular insertion chain.
-      env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
+      env.updateInsertionChain(
+          builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
     }
     return;
   }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index b63762485c961f..40bfa1e4e2a501 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -643,7 +643,7 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
 //       CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
-  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
+  %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
   return %1 : tensor<128xf64, #SV>
 }
@@ -666,7 +666,7 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
 //       CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
-  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+  %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
   return %1 : tensor<128xf64, #SparseVector>
 }
@@ -690,7 +690,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
 //       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
-  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
+  %0 = tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
   %1 = sparse_tensor.load %0 hasInserts : tensor<5x6xf64, #Coo>
   return %1 : tensor<5x6xf64, #Coo>
 }
diff --git a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
index eaef6a31585292..f9559ce648c783 100644
--- a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
@@ -20,7 +20,7 @@
 // CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
 // CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
 // CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1
-// CHECK:             %[[VAL_14:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
+// CHECK:             %[[VAL_14:.*]] = tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
 // CHECK:             scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}>
 // CHECK:           }
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}>
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 465f2108626606..f23f6ac4f181e2 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -318,7 +318,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
 func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
                          %arg1: index,
                          %arg2: f32) -> tensor<128xf32, #SparseVector> {
-  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
+  %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
   return %0 : tensor<128xf32, #SparseVector>
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index eac97f702f58bd..48f28ef390ed53 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -290,24 +290,6 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
 
 // -----
 
-func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: index, %arg2: f64) {
-  // expected-error@+1 {{'sparse_tensor.insert' 'tensor' must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
-  sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64>
-  return
-}
-
-// -----
-
-#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
-
-func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: index, %arg2: f64) {
-  // expected-error@+1 {{'sparse_tensor.insert' op incorrect number of coordinates}}
-  sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128x64xf64, #CSR>
-  return
-}
-
-// -----
-
 func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f32) -> (memref<?xf64>, index) {
  // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
   %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f32
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 41094fbad9218f..e9e458e805ba47 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -311,10 +311,10 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
 //  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse{{[0-9]*}}>,
 //  CHECK-SAME: %[[B:.*]]: index,
 //  CHECK-SAME: %[[C:.*]]: f64)
-//       CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
+//       CHECK: %[[T:.*]] = tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
 //       CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
 func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
-  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+  ...
[truncated]

``````````
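
For context, the surrounding insertion idiom is unchanged by this patch: insertions are chained through SSA values and the final tensor is materialized with `sparse_tensor.load ... hasInserts`. A minimal sketch based on the updated `codegen.mlir` test above (the `#SV` encoding is illustrative):

```mlir
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>,
                         %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
  // Insert through the reused tensor.insert operation.
  %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
  // Finalize the insertion chain.
  %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
  return %1 : tensor<128xf64, #SV>
}
```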



https://github.com/llvm/llvm-project/pull/84987


More information about the Mlir-commits mailing list