[Mlir-commits] [mlir] a63d6a0 - [mlir][sparse] make UnpackOp return the actual filled length of unpacked memory
Peiming Liu
llvmlistbot@llvm.org
Fri Jun 30 14:35:21 PDT 2023
Author: Peiming Liu
Date: 2023-06-30T21:35:15Z
New Revision: a63d6a00140d25289e0cba294c9228dbb74b06fa
URL: https://github.com/llvm/llvm-project/commit/a63d6a00140d25289e0cba294c9228dbb74b06fa
DIFF: https://github.com/llvm/llvm-project/commit/a63d6a00140d25289e0cba294c9228dbb74b06fa.diff
LOG: [mlir][sparse] make UnpackOp return the actual filled length of unpacked memory
This might simplify frontend implementations by avoiding recomputation of the same values.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D154244
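A minimal frontend-side sketch of the intended use, borrowing the `#SparseVector` example from the op documentation below (all SSA names here are illustrative, not part of this commit): the filled lengths come back as `index` results, so a caller can trim the over-allocated buffers directly instead of re-deriving the sizes from the sparse tensor.

```mlir
%v, %p, %c, %v_len, %p_len, %c_len =
    sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector>
    outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
    -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
// Trim the value buffer to its filled length; no recomputation needed.
%valid = tensor.extract_slice %v[0] [%v_len] [1] : tensor<3xf64> to tensor<?xf64>
```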
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9fd81b98a2dcd6..01320a2e4c5acb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -102,18 +102,21 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
let hasVerifier = 1;
}
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultSize]>,
Arguments<(ins AnySparseTensor:$tensor,
TensorOf<[AnyType]>:$out_values,
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
Results<(outs TensorOf<[AnyType]>:$ret_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels)> {
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+ Index:$val_len,
+ Variadic<Index>:$lvl_lens)> {
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
let description = [{
The unpack operation is the inverse of `sparse_tensor::pack`. It returns
the values and per-level position and coordinate arrays to the user
- from the sparse tensor. This operation can be used for returning an
+ from the sparse tensor along with the actual length of the memory used in
+ each returned buffer. This operation can be used for returning an
unpacked MLIR sparse tensor to a frontend; e.g., returning two numpy arrays to Python.
Disclaimer: It is the user's responsibility to allocate large enough buffers
@@ -128,18 +131,22 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
// input COO format |1.1, 0.0, 0.0, 0.0|
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
// |0.0, 0.0, 0.0, 0.0|
- %values, %pos, %coords = sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector>
- outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
- -> tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
- // %values = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
- // %pos = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
- // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
+ %v, %p, %c, %v_len, %p_len, %c_len = sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector>
+ outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
+ -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+ // %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
+ // %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
+ // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
+ // %v_len = 3
+ // %p_len = 2
+ // %c_len = 6 (3x2)
```
}];
let assemblyFormat =
- "$tensor `:` type($tensor) `outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)`"
- "attr-dict `->` type($ret_values) `,` type($ret_levels)";
+ "$tensor `:` type($tensor) "
+ "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
+ "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
let hasVerifier = 1;
}
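Because `SameVariadicResultSize` ties `$ret_levels` and `$lvl_lens` together, the two parenthesized groups in the assembly format always have the same arity. A hedged sketch for a CSR input, where each group carries two entries (types and names are illustrative, not from this commit):

```mlir
%v, %p, %i, %v_len, %p_len, %i_len =
    sparse_tensor.unpack %csr : tensor<2x100xf64, #CSR>
    outs(%ov, %op, %oi : tensor<6xf64>, tensor<3xindex>, tensor<6xindex>)
    -> tensor<6xf64>, (tensor<3xindex>, tensor<6xindex>), index, (index, index)
```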
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e712c9396466bb..cd49205627ea7c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -166,13 +166,13 @@ struct UnpackOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// We write into the output operand.
- assert(op->getNumOperands() == op->getNumResults() + 1);
+ assert(2 * (op->getNumOperands() - 1) == op->getNumResults());
return opOperand.getOperandNumber() > 0;
}
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(op->getNumOperands() == op->getNumResults() + 1);
+ assert(2 * (op->getNumOperands() - 1) == op->getNumResults());
if (opOperand.getOperandNumber() == 0)
return {};
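The updated assertion encodes the new result layout: one sparse tensor operand plus n output buffers on the input side, and n filled tensors plus n filled lengths on the result side. A small sketch of the arithmetic, as a hypothetical helper that is not part of the patch:

```c++
// For n output buffers (1 value buffer + (n - 1) level buffers):
//   numOperands = n + 1  (the sparse tensor + the out buffers)
//   numResults  = 2 * n  (filled tensors + their filled lengths)
static bool unpackCountsConsistent(unsigned numOperands, unsigned numResults) {
  unsigned numOutBuffers = numOperands - 1; // drop the sparse tensor operand
  return 2 * numOutBuffers == numResults;
}
```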
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index a7f37e8189ea03..73840ba1327572 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1311,7 +1311,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Location loc = op.getLoc();
SmallVector<Value> retMem;
- desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem](
+ SmallVector<Value> retLen;
+ desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
FieldIndex fid,
SparseTensorFieldKind fKind, Level lvl,
DimLevelType dlt) -> bool {
@@ -1329,6 +1330,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// TODO: maybe change unpack/pack operation instead to be
// consistent.
retMem.insert(retMem.begin(), dst);
+ retLen.insert(retLen.begin(), sz);
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1339,6 +1341,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
src = desc.getMemRefField(fid);
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
retMem.push_back(dst);
+ retLen.push_back(sz);
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {
@@ -1352,12 +1355,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
});
// Converts MemRefs back to Tensors.
- SmallVector<Value> retTensor = llvm::to_vector(
+ SmallVector<Value> retValues = llvm::to_vector(
llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
return rewriter.create<bufferization::ToTensorOp>(loc, v);
}));
-
- rewriter.replaceOp(op, retTensor);
+ // Appends the actual length of the memory used in each returned buffer.
+ retValues.append(retLen.begin(), retLen.end());
+ rewriter.replaceOp(op, retValues);
return success();
}
};
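`retLen` is collected in the same order as `retMem` (values first, then the per-level position/coordinate buffers), so appending it after the tensor conversions yields results in the op's declared order `(ret_values, ret_levels..., val_len, lvl_lens...)`. A hedged illustration for a 2-D COO tensor, as descriptive comments only (the names are made up):

```c++
// retValues after the append, for one value buffer and two level buffers:
//   { %values, %positions, %coordinates,   // bufferization.to_tensor results
//     %val_len, %pos_len, %crd_len }       // index-typed filled lengths
```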
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7a6c4824aabed6..1510c3d65b57a5 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -60,9 +60,9 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
// expected-error @+1 {{input/output element-types don't match}}
- %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
+ %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
- -> tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+ -> tensor<6xf64>, (tensor<2xi32>, tensor<6x1xi32>), index, (index, index)
return
}
@@ -72,9 +72,9 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: ten
func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
// expected-error @+1 {{input/output trailing COO level-ranks don't match}}
- %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
+ %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
- -> tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>
+ -> tensor<6xf64>, (tensor<2xi32>, tensor<6x3xi32>), index, (index, index)
return
}
@@ -84,9 +84,9 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: t
func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
// expected-error @+1 {{inconsistent number of fields between input/output}}
- %rv, %rc = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
+ %rv, %rc, %vl, %pl = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
- -> tensor<6xf64>, tensor<6xi32>
+ -> tensor<6xf64>, (tensor<6xi32>), index, (index)
return
}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 43429f454e1225..ace180369da7db 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -36,16 +36,16 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
// CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
// CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
// CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
-// CHECK: %[[D:.*]], %[[P:.*]]:2 = sparse_tensor.unpack %[[T]]
+// CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.unpack %[[T]]
// CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
%od : tensor<6xf64>,
%op : tensor<2xindex>,
%oi : tensor<6x1xi32>)
-> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
- %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
+ %rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
- -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
+ -> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xi32>), index, (index, index)
return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 5d7305dac54cc6..0377ecc3090a17 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -70,8 +70,8 @@ func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
%op : tensor<2xindex>,
%oi : tensor<6x2xi32>)
-> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
- %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
- outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
- -> tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
+ %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
+ outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
+ -> tensor<6xf64>, (tensor<2xindex>, tensor<6x2xi32>), index, (index, index)
return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 3014407a95a1a6..acb5dca06863d9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -171,9 +171,9 @@ module {
%d_csr = tensor.empty() : tensor<4xf64>
%p_csr = tensor.empty() : tensor<3xi32>
%i_csr = tensor.empty() : tensor<3xi32>
- %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
+ %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
- -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+ -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
// CHECK-NEXT: ( 1, 2, 3, {{.*}} )
%vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
@@ -196,9 +196,9 @@ module {
%od = tensor.empty() : tensor<3xf64>
%op = tensor.empty() : tensor<2xi32>
%oi = tensor.empty() : tensor<3x2xi32>
- %d, %p, %i = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+ %d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
- -> tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
+ -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
// CHECK-NEXT: ( 1, 2, 3 )
%vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -212,17 +212,21 @@ module {
%bod = tensor.empty() : tensor<6xf64>
%bop = tensor.empty() : tensor<4xindex>
%boi = tensor.empty() : tensor<6x2xindex>
- %bd, %bp, %bi = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
+ %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
- -> tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>
+ -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
// CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
%vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
vector.print %vbd : vector<6xf64>
+ // CHECK-NEXT: 5
+ vector.print %ld : index
// CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) )
%vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
vector.print %vbi : vector<6x2xindex>
+ // CHECK-NEXT: 10
+ vector.print %li : index
return
}
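The two new CHECK lines follow from the BCOO contents verified just above: five entries are filled. A hedged reading of the printed lengths, based only on the CHECK lines in this test:

```mlir
// %ld = 5          : five of the six value slots are filled
// %li = 5 * 2 = 10 : the coordinates buffer holds five rank-2 entries
```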