[Mlir-commits] [mlir] 64df1c0 - [sparse] allow unpack op to return any integer type. (#66161)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 12 17:27:55 PDT 2023
Author: Peiming Liu
Date: 2023-09-12T17:27:51-07:00
New Revision: 64df1c08d04039369b657bc3a8ed606dfc3ba479
URL: https://github.com/llvm/llvm-project/commit/64df1c08d04039369b657bc3a8ed606dfc3ba479
DIFF: https://github.com/llvm/llvm-project/commit/64df1c08d04039369b657bc3a8ed606dfc3ba479.diff
LOG: [sparse] allow unpack op to return any integer type. (#66161)
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5395a0e088c90db..7d9f1d3b26c0678 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
Results<(outs TensorOf<[AnyType]>:$ret_values,
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- Index:$val_len,
- Variadic<Index>:$lvl_lens)> {
+ AnySignlessIntegerOrIndex:$val_len,
+ Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20df5f43e897ecb..0c8a304841c10d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// TODO: maybe change unpack/pack operation instead to be
// consistent.
retMem.insert(retMem.begin(), dst);
- retLen.insert(retLen.begin(), sz);
+ Type valLenTp = op.getValLen().getType();
+ retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
src = desc.getMemRefField(fid);
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
retMem.push_back(dst);
- retLen.push_back(sz);
+ // Retrieves the corresponding level length type.
+ Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
+ retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 66e1632694cb4bc..cc8d538e6adfb83 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -178,7 +178,7 @@ module {
%i_csr = tensor.empty() : tensor<3xi32>
%rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
- -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
+ -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
// CHECK-NEXT: ( 1, 2, 3, {{.*}} )
%vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
@@ -203,7 +203,7 @@ module {
%oi = tensor.empty() : tensor<3x2xi32>
%d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
- -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
+ -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)
// CHECK-NEXT: ( 1, 2, 3 )
%vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -219,7 +219,7 @@ module {
%boi = tensor.empty() : tensor<6x2xindex>
%bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
- -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
+ -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
// CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
%vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,7 @@ module {
%vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
vector.print %vbi : vector<6x2xindex>
// CHECK-NEXT: 10
- vector.print %li : index
+ vector.print %li : i64
return
}
More information about the Mlir-commits mailing list