[Mlir-commits] [mlir] [sparse] allow unpack op to return any integer type. (PR #66161)

Peiming Liu llvmlistbot at llvm.org
Tue Sep 12 17:13:55 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/66161:

None

From 37d86769ceeca2e4f8ce88bc6d45535cfde046d2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 13 Sep 2023 00:09:51 +0000
Subject: [PATCH] [sparse] allow unpack op to return any integer type.

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td       | 4 ++--
 .../SparseTensor/Transforms/SparseTensorCodegen.cpp       | 7 +++++--
 .../Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir | 8 ++++----
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5395a0e088c90db..7d9f1d3b26c0678 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
                    Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  Index:$val_len,
-                  Variadic<Index>:$lvl_lens)> {
+                  AnySignlessIntegerOrIndex:$val_len,
+                  Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 20df5f43e897ecb..05e48e57cabd351 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // TODO: maybe change unpack/pack operation instead to be
         // consistent.
         retMem.insert(retMem.begin(), dst);
-        retLen.insert(retLen.begin(), sz);
+        retLen.insert(retLen.begin(),
+                      genCast(rewriter, loc, sz, op.getValLen().getType()));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         src = desc.getMemRefField(fid);
         dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
         retMem.push_back(dst);
-        retLen.push_back(sz);
+        // Retrieves the corresponding level length type.
+        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
+        retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 66e1632694cb4bc..cc8d538e6adfb83 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -178,7 +178,7 @@ module {
     %i_csr = tensor.empty() : tensor<3xi32>
     %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
                  outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
+                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
 
     // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
     %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
@@ -203,7 +203,7 @@ module {
     %oi = tensor.empty() : tensor<3x2xi32>
     %d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                  outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
-                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
+                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)
 
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -219,7 +219,7 @@ module {
     %boi = tensor.empty() : tensor<6x2xindex>
     %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                     outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
+                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,7 @@ module {
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
     // CHECK-NEXT: 10
-    vector.print %li : index
+    vector.print %li : i64
 
     return
   }
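
For reference, with this change the trailing length results of sparse_tensor.unpack no longer have to be `index`; any signless integer type (or `index`) is accepted, and the codegen lowering inserts the required casts via genCast. A minimal sketch, mirroring the updated CSR case in the integration test above (it assumes the #CSR encoding and the output tensors %d_csr, %p_csr, %i_csr are set up as in sparse_pack.mlir):

    %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
                 outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)

Here the value length %ld_csr is still returned as index, while the level lengths %lp_csr and %li_csr are returned as i32 and i64 respectively.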
