[Mlir-commits] [mlir] 9f596a7 - [mlir][sparse] implement simple codegen for insertion (and related ops)
Aart Bik
llvmlistbot at llvm.org
Mon Oct 17 18:02:16 PDT 2022
Author: Aart Bik
Date: 2022-10-17T18:02:08-07:00
New Revision: 9f596a7c67cc10fbb624f8b55bf9637e1dddd1c2
URL: https://github.com/llvm/llvm-project/commit/9f596a7c67cc10fbb624f8b55bf9637e1dddd1c2
DIFF: https://github.com/llvm/llvm-project/commit/9f596a7c67cc10fbb624f8b55bf9637e1dddd1c2.diff
LOG: [mlir][sparse] implement simple codegen for insertion (and related ops)
This is a proof-of-concept insertion implementation that sets up
the basic framework and implements it with push_back operations for
sparse vectors only. It adds insertion/compression through SSA values,
so that we properly update the memref after each push_back operation.
Note that properly using SSA values in sparsification is still TBD,
but I will wait until Peiming's loop emitter is in to avoid conflicts.
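For context, a minimal sketch of the new SSA-based insertion flow, mirroring
the roundtrip and codegen tests added below (the function name and the #SV
alias are illustrative):

```mlir
#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// Insert one value and finalize pending insertions. Both ops now return the
// updated tensor as an SSA value; referencing the old SSA value afterwards
// is undefined behavior.
func.func @insert_example(%t: tensor<128xf64, #SV>, %i: index, %v: f64)
    -> tensor<128xf64, #SV> {
  %0 = sparse_tensor.insert %v into %t[%i] : tensor<128xf64, #SV>
  %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
  return %1 : tensor<128xf64, #SV>
}
```

For a 1-D compressed vector like this, the codegen pass lowers the insert to
three sparse_tensor.push_back operations (pointers, indices, values) and the
finalizing load to one more push_back of the final pointer size, as checked
by the new @sparse_insert test in codegen.mlir.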
Reviewed By: wrengr
Differential Revision: https://reviews.llvm.org/D136008
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index fef73ada1743f..14e5af0403843 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -190,20 +190,21 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
//===----------------------------------------------------------------------===//
// Sparse Tensor Management Operations. These operations are "impure" in the
-// sense that they do not properly operate on SSA values. Instead, the behavior
-// is solely defined by side-effects. These operations provide a bridge between
-// "sparsification" on one hand and a support library or actual code generation
-// on the other hand. The semantics of these operations may be refined over time
-// as our sparse abstractions evolve.
+// sense that some behavior is defined by side-effects. These operations provide
+// a bridge between "sparsification" on one hand and a support library or actual
+// code generation on the other hand. The semantics of these operations may be
+// refined over time as our sparse abstractions evolve.
//===----------------------------------------------------------------------===//
def SparseTensor_InsertOp : SparseTensor_Op<"insert",
[TypesMatchWith<"value type matches element type of tensor",
"tensor", "value",
- "$_self.cast<ShapedType>().getElementType()">]>,
+ "$_self.cast<ShapedType>().getElementType()">,
+ AllTypesMatch<["tensor", "result"]>]>,
Arguments<(ins AnyType:$value,
AnySparseTensor:$tensor,
- Variadic<Index>:$indices)> {
+ Variadic<Index>:$indices)>,
+ Results<(outs AnySparseTensor:$result)> {
string summary = "Inserts a value into given sparse tensor";
string description = [{
Inserts the given value at given indices into the underlying
@@ -221,19 +222,19 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
different insertion regimens. Inserting in a way contrary to
these properties results in undefined behavior.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve. In
- particular, this operation is scheduled to be unified with the
- dense counterpart `tensor.insert` that has proper SSA semantics.
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, the insertion is eventually
+ done "in place", and referencing the old SSA value is undefined behavior.
+ This operation is scheduled to be unified with the dense counterpart
+ `tensor.insert` that has pure SSA semantics.
Example:
```mlir
- sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
+ %result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
```
}];
- let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict`:` type($tensor)";
+ let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
@@ -255,7 +256,8 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
the code for capacity check and reallocation. The typical usage will be for
"dynamic" sparse tensors for which a capacity can be set beforehand.
- The operation returns an SSA value for the memref. Referencing the memref
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, referencing the memref
through the old SSA value after this operation is undefined behavior.
Example:
@@ -302,9 +304,9 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
through an indirection using the added array, so that the operations are
kept proportional to the number of nonzeros.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ Note that this operation is "impure" in the sense that even though the
+ results are modeled through SSA values, the operation relies on a proper
+ side-effecting context that sets and resets the expanded arrays.
Example:
@@ -317,13 +319,15 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
" `,` type($filled) `,` type($added)";
}
-def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
+def SparseTensor_CompressOp : SparseTensor_Op<"compress",
+ [AllTypesMatch<["tensor", "result"]>]>,
Arguments<(ins AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
StridedMemRefRankOf<[Index],[1]>:$added,
Index:$count,
AnySparseTensor:$tensor,
- Variadic<Index>:$indices)> {
+ Variadic<Index>:$indices)>,
+ Results<(outs AnySparseTensor:$result)> {
string summary = "Compressed an access pattern for insertion";
string description = [{
Finishes a single access pattern expansion by moving inserted elements
@@ -335,14 +339,14 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
array, so that the operations are kept proportional to the number of
nonzeros. See the `sparse_tensor.expand` operation for more details.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, the insertion is eventually
+ done "in place", and referencing the old SSA value is undefined behavior.
Example:
```mlir
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %result = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>
```
}];
@@ -372,14 +376,16 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
sparse storage format needs to be finalized. Otherwise, the operation
simply folds away.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, the operation relies on
+ a proper context of materializing and inserting the tensor value.
- Example:
+ Examples:
```mlir
- %1 = sparse_tensor.load %0 : tensor<8xf64, #SV>
+ %result = sparse_tensor.load %tensor : tensor<8xf64, #SV>
+
+ %1 = sparse_tensor.load %0 hasInserts : tensor<16x32xf32, #CSR>
```
}];
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
@@ -397,8 +403,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
a buffer defined by a pointer.
Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ is solely defined by side-effects and not SSA values.
Example:
@@ -442,8 +447,7 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
be used to implement the operator.
Note that this operation is "impure" in the sense that its behavior is
- solely defined by side-effects and not SSA values. The semantics may be
- refined over time as our sparse abstractions evolve.
+ solely defined by side-effects and not SSA values.
Example:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e18e1ceadb67b..5e5815b60061b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -35,6 +35,18 @@ namespace {
// Helper methods.
//===----------------------------------------------------------------------===//
+/// Returns the "tuple" value of the adapted tensor.
+static UnrealizedConversionCastOp getTuple(Value tensor) {
+ return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
+/// Packs the given values as a "tuple" value.
+static Value genTuple(OpBuilder &rewriter, Location loc, Type tp,
+ ValueRange values) {
+ return rewriter.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+ .getResult(0);
+}
+
/// Flatten a list of operands that may contain sparse tensors.
static void flattenOperands(ValueRange operands,
SmallVectorImpl<Value> &flattened) {
@@ -43,14 +55,13 @@ static void flattenOperands(ValueRange operands,
// ==>
// memref ..., c, memref ...
for (auto operand : operands) {
- if (auto cast =
- dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
- cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+ if (auto tuple = getTuple(operand);
+ tuple && getSparseTensorEncoding(tuple->getResultTypes()[0]))
// An unrealized_conversion_cast will be inserted by type converter to
// inter-mix the gap between 1:N conversion between sparse tensors and
// fields. In this case, take the operands in the cast and replace the
// sparse tensor output with the flattened type array.
- flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+ flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
else
flattened.push_back(operand);
}
@@ -73,8 +84,7 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
// Any other query can consult the dimSizes array at field 0 using,
// accounting for the reordering applied to the sparse storage.
- auto tuple =
- llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp());
+ auto tuple = getTuple(adaptedValue);
Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
.getResult();
@@ -264,6 +274,54 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
return forOp;
}
+/// Creates a pushback op for given field and updates the fields array
+/// accordingly.
+static void createPushback(OpBuilder &builder, Location loc,
+ SmallVectorImpl<Value> &fields, unsigned field,
+ Value value) {
+ assert(field < fields.size());
+ fields[field] =
+ builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
+ fields[field], value, APInt(64, field));
+}
+
+/// Generates insertion code.
+//
+// TODO: generalize this for any rank and format; currently it is just sparse
+// vectors, as a proof of concept that we have everything in place!
+//
+static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields,
+ SmallVectorImpl<Value> &indices, Value value) {
+ unsigned rank = indices.size();
+ assert(rtp.getShape().size() == rank);
+ if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
+ !isOrderedDim(rtp, 0))
+ return; // TODO: add codegen
+ // push_back memSizes pointers-0 0
+ // push_back memSizes indices-0 index
+ // push_back memSizes values value
+ Value zero = constantIndex(builder, loc, 0);
+ createPushback(builder, loc, fields, 2, zero);
+ createPushback(builder, loc, fields, 3, indices[0]);
+ createPushback(builder, loc, fields, 4, value);
+}
+
+/// Generates insertion finalization code.
+//
+// TODO: this too only works for the very simple case
+//
+static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields) {
+ if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
+ !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
+ return; // TODO: add codegen
+ // push_back memSizes pointers-0 memSizes[2]
+ Value two = constantIndex(builder, loc, 2);
+ Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+ createPushback(builder, loc, fields, 2, size);
+}
+
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -325,12 +383,10 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
- ValueRange sparseElem(iterator_range<ResultRange::iterator>(
+ ValueRange fields(iterator_range<ResultRange::iterator>(
newCall.result_begin() + retOffset,
newCall.result_begin() + retOffset + flatSize));
- auto castOp = rewriter.create<UnrealizedConversionCastOp>(
- loc, TypeRange({retType}), sparseElem);
- castedRet.push_back(castOp.getResult(0));
+ castedRet.push_back(genTuple(rewriter, loc, retType, fields));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
@@ -404,8 +460,7 @@ class SparseTensorAllocConverter
Location loc = op.getLoc();
SmallVector<Value, 8> fields;
createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
- rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
- op, TypeRange{resType}, fields);
+ rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
return success();
}
};
@@ -424,8 +479,7 @@ class SparseTensorDeallocConverter
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
- auto tuple = llvm::cast<UnrealizedConversionCastOp>(
- adaptor.getTensor().getDefiningOp());
+ auto tuple = getTuple(adaptor.getTensor());
for (auto input : tuple.getInputs())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
@@ -442,11 +496,15 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
LogicalResult
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.getHasInserts()) {
- // Finalize any pending insertions.
- // TODO: implement
- }
- rewriter.replaceOp(op, adaptor.getOperands());
+ RankedTensorType srcType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ auto tuple = getTuple(adaptor.getTensor());
+ // Prepare fields.
+ SmallVector<Value, 8> fields(tuple.getInputs());
+ // Generate optional insertion finalization code.
+ if (op.getHasInserts())
+ genEndInsert(rewriter, op.getLoc(), srcType, fields);
+ rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
return success();
}
};
@@ -514,10 +572,14 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
RankedTensorType dstType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = dstType.getElementType();
+ auto tuple = getTuple(adaptor.getTensor());
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
+ // Prepare fields and indices.
+ SmallVector<Value, 8> fields(tuple.getInputs());
+ SmallVector<Value, 8> indices(adaptor.getIndices());
// If the innermost dimension is ordered, we need to sort the indices
// in the "added" array prior to applying the compression.
unsigned rank = dstType.getShape().size();
@@ -532,21 +594,20 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
// for (i = 0; i < count; i++) {
// index = added[i];
// value = values[index];
- //
- // TODO: insert prev_indices, index, value
- //
+ // insert({prev_indices, index}, value);
// values[index] = 0;
// filled[index] = false;
// }
Value i = createFor(rewriter, loc, count).getInductionVar();
Value index = rewriter.create<memref::LoadOp>(loc, added, i);
- rewriter.create<memref::LoadOp>(loc, values, index);
- // TODO: insert
+ Value value = rewriter.create<memref::LoadOp>(loc, values, index);
+ indices.push_back(index);
+ // TODO: generate yield cycle
+ genInsert(rewriter, loc, dstType, fields, indices, value);
rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
values, index);
rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
filled, index);
-
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = op;
for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -559,7 +620,28 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<memref::DeallocOp>(loc, values);
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
- rewriter.eraseOp(op);
+ rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+ return success();
+ }
+};
+
+/// Sparse codegen rule for the insert operator.
+class SparseInsertConverter : public OpConversionPattern<InsertOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ RankedTensorType dstType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ auto tuple = getTuple(adaptor.getTensor());
+ // Prepare fields and indices.
+ SmallVector<Value, 8> fields(tuple.getInputs());
+ SmallVector<Value, 8> indices(adaptor.getIndices());
+ // Generate insertion.
+ Value value = adaptor.getValue();
+ genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+ rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
return success();
}
};
@@ -576,8 +658,7 @@ class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- auto tuple = llvm::cast<UnrealizedConversionCastOp>(
- adaptor.getTensor().getDefiningOp());
+ auto tuple = getTuple(adaptor.getTensor());
unsigned idx = Base::getIndexForOp(tuple, op);
auto fields = tuple.getInputs();
assert(idx < fields.size());
@@ -648,6 +729,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToValuesConverter>(typeConverter, patterns.getContext());
+ SparseInsertConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToValuesConverter>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c7a6048a8503f..e54e58626cd50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1071,9 +1071,9 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
constantIndex(rewriter, loc, i));
rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
- replaceOpWithFuncCall(rewriter, op, name, {},
- {adaptor.getTensor(), mref, vref},
- EmitCInterface::On);
+ createFuncCall(rewriter, loc, name, {}, {adaptor.getTensor(), mref, vref},
+ EmitCInterface::On);
+ rewriter.replaceOp(op, adaptor.getTensor());
return success();
}
};
@@ -1149,9 +1149,10 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<memref::StoreOp>(loc, adaptor.getIndices()[i], mref,
constantIndex(rewriter, loc, i));
SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
- replaceOpWithFuncCall(rewriter, op, name, {},
- {tensor, mref, values, filled, added, count},
- EmitCInterface::On);
+ createFuncCall(rewriter, loc, name, {},
+ {tensor, mref, values, filled, added, count},
+ EmitCInterface::On);
+ rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
Operation *parent = op;
for (; isa<scf::ForOp>(parent->getParentOp()) ||
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index b227cfc50f6ca..ecdf0e45beb55 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = [ "compressed" ],
indexBitWidth = 64,
@@ -383,10 +385,10 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
- %i: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %i: index) -> tensor<8x8xf64, #CSR> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
- return
+ return %0 : tensor<8x8xf64, #CSR>
}
// CHECK-LABEL: func @sparse_compression_unordered(
@@ -420,8 +422,30 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
- %i: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %i: index) -> tensor<8x8xf64, #UCSR> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
- return
+ return %0 : tensor<8x8xf64, #UCSR>
+}
+
+// CHECK-LABEL: func @sparse_insert(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
+// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
+// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
+// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
+ %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
+ return %1 : tensor<128xf64, #SV>
}
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2694564b022cd..44fcd4219ec08 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -288,7 +288,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
// CHECK-SAME: %[[B:.*]]: index,
-// CHECK-SAME: %[[C:.*]]: f32) {
+// CHECK-SAME: %[[C:.*]]: f32) -> !llvm.ptr<i8> {
// CHECK-DAG: %[[M:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[V:.*]] = memref.alloca() : memref<f32>
// CHECK-DAG: %[[MC:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
@@ -296,12 +296,12 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
// CHECK-DAG: memref.store %[[B]], %[[M]][%[[C0]]] : memref<1xindex>
// CHECK-DAG: memref.store %[[C]], %[[V]][] : memref<f32>
// CHECK: call @lexInsertF32(%[[A]], %[[MC]], %[[V]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
-// CHECK: return
+// CHECK: return %[[A]] : !llvm.ptr<i8>
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
%arg1: index,
- %arg2: f32) {
- sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
- return
+ %arg2: f32) -> tensor<128xf32, #SparseVector> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
+ return %0 : tensor<128xf32, #SparseVector>
}
// CHECK-LABEL: func @sparse_expansion1()
@@ -359,7 +359,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-SAME: %[[C:.*2]]: memref<?xi1>,
// CHECK-SAME: %[[D:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[E:.*4]]: index,
-// CHECK-SAME: %[[F:.*5]]: index)
+// CHECK-SAME: %[[F:.*5]]: index) -> !llvm.ptr<i8> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[X:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Y:.*]] = memref.cast %[[X]] : memref<2xindex> to memref<?xindex>
@@ -368,16 +368,16 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: memref.dealloc %[[B]] : memref<?xf64>
// CHECK-DAG: memref.dealloc %[[C]] : memref<?xi1>
// CHECK-DAG: memref.dealloc %[[D]] : memref<?xindex>
-// CHECK: return
+// CHECK: return %[[A]] : !llvm.ptr<i8>
func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
- %i: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %i: index) -> tensor<8x8xf64, #CSR> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
- return
+ return %0 : tensor<8x8xf64, #CSR>
}
// CHECK-LABEL: func @sparse_out1(
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index fd850aacacae7..ca3f8843825e6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -122,12 +122,12 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[B:.*]]: index,
-// CHECK-SAME: %[[C:.*]]: f64) {
-// CHECK: sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
-// CHECK: return
-func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) {
- sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
- return
+// CHECK-SAME: %[[C:.*]]: f64)
+// CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
+// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
+func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+ return %0 : tensor<128xf64, #SparseVector>
}
// -----
@@ -181,17 +181,17 @@ func.func @sparse_expansion(%tensor: tensor<8x8xf64, #SparseMatrix>) -> index {
// CHECK-SAME: %[[A3:.*3]]: index
// CHECK-SAME: %[[A4:.*4]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[A5:.*5]]: index)
-// CHECK: sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
-// CHECK: return
+// CHECK: %[[T:.*]] = sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
+// CHECK: return %[[T]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
func.func @sparse_compression(%values: memref<?xf64>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
%tensor: tensor<8x8xf64, #SparseMatrix>,
- %index: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
+ %index: index) -> tensor<8x8xf64, #SparseMatrix> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #SparseMatrix>
- return
+ return %0 : tensor<8x8xf64, #SparseMatrix>
}
// -----