[Mlir-commits] [mlir] [sparse] allow unpack op to return any integer type. (PR #66161)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 12 17:14:37 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse
            
<details>
<summary>Changes</summary>
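None (the PR has no description; per the title and the diff below, it relaxes the `val_len` and `lvl_lens` results of `sparse_tensor.unpack` from `Index` to `AnySignlessIntegerOrIndex` and casts the computed lengths to the requested result types during codegen)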
--
Full diff: https://github.com/llvm/llvm-project/pull/66161.diff

3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+5-2) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+4-4) 


<pre>
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5395a0e088c90db..7d9f1d3b26c0678 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
                    Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  Index:$val_len,
-                  Variadic<Index>:$lvl_lens)> {
+                  AnySignlessIntegerOrIndex:$val_len,
+                  Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20df5f43e897ecb..05e48e57cabd351 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // TODO: maybe change unpack/pack operation instead to be
         // consistent.
         retMem.insert(retMem.begin(), dst);
-        retLen.insert(retLen.begin(), sz);
+        retLen.insert(retLen.begin(),
+                      genCast(rewriter, loc, sz, op.getValLen().getType()));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         src = desc.getMemRefField(fid);
         dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
         retMem.push_back(dst);
-        retLen.push_back(sz);
+        // Retrieves the corresponding level length type.
+        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
+        retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 66e1632694cb4bc..cc8d538e6adfb83 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -178,7 +178,7 @@ module {
     %i_csr = tensor.empty() : tensor<3xi32>
     %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
                  outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
+                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
 
     // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
     %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
@@ -203,7 +203,7 @@ module {
     %oi = tensor.empty() : tensor<3x2xi32>
     %d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                  outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
-                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
+                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)
 
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -219,7 +219,7 @@ module {
     %boi = tensor.empty() : tensor<6x2xindex>
     %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                     outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
+                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,7 @@ module {
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
     // CHECK-NEXT: 10
-    vector.print %li : index
+    vector.print %li : i64
 
     return
   }
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66161


More information about the Mlir-commits mailing list