[Mlir-commits] [mlir] [sparse] allow unpack op to return 0-ranked tensor type. (PR #66269)

Peiming Liu llvmlistbot at llvm.org
Wed Sep 13 11:17:47 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/66269:

Many frontends canonicalize scalars into 0-ranked tensors; this change will hopefully make the operation easier to use in those cases.
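
As a rough sketch of what this enables (adapted from the sparse_pack.mlir integration test updated below; %bs, %bod, %bop, %boi and the #BCOO encoding are assumed to be defined as in that test), a level-length result can now be declared as a 0-ranked tensor and read back with tensor.extract:

  %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                  outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
                  -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
  // The 0-ranked tensor result carries the level length as its single element.
  %si = tensor.extract %li[] : tensor<i64>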

>From c08bc442a8fbfb21262b6a166c3ced9a21c280fd Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 13 Sep 2023 18:09:40 +0000
Subject: [PATCH] [sparse] allow unpack op to return 0-ranked tensor type.

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td     |  4 ++++
 .../Dialect/SparseTensor/IR/SparseTensorOps.td  |  4 ++--
 .../Transforms/SparseTensorCodegen.cpp          | 17 +++++++++++++++--
 .../Transforms/SparseTensorPasses.cpp           |  2 ++
 .../Dialect/SparseTensor/CPU/sparse_pack.mlir   |  5 +++--
 5 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index e2f3df005b70d69..bf077db43ec10e9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
 
 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
 
+class ScalarLikeOf<list<Type> allowedTypes>
+  : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>]>;
+
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Sorting Algorithm Attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7d9f1d3b26c0678..7430a3c6118cef4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
                    Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnySignlessIntegerOrIndex:$val_len,
-                  Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
+                  ScalarLikeOf<[AnySignlessIntegerOrIndex]>:$val_len,
+                  Variadic<ScalarLikeOf<[AnySignlessIntegerOrIndex]>>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0c8a304841c10d5..557c5c471c4a77c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -559,6 +559,18 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
   return reassociation;
 }
 
+static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+                               Type dstTp) {
+  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+    // Scalars can only be converted to 0-ranked tensors.
+    if (rtp.getRank() != 0)
+      return nullptr;
+    elem = genCast(builder, loc, elem, rtp.getElementType());
+    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+  }
+  return genCast(builder, loc, elem, dstTp);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // consistent.
         retMem.insert(retMem.begin(), dst);
         Type valLenTp = op.getValLen().getType();
-        retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
+        retLen.insert(retLen.begin(),
+                      genScalarToTensor(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         retMem.push_back(dst);
         // Retrieves the corresponding level length type.
         Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
-        retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
+        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index cce26bc603eeb3c..2956cf57ade0290 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -214,6 +214,8 @@ struct SparseTensorCodegenPass
     target.addLegalOp<GetStorageSpecifierOp>();
     target.addLegalOp<SetStorageSpecifierOp>();
     target.addLegalOp<StorageSpecifierInitOp>();
+    // tensor::FromElementsOp might be generated after lowering unpack.
+    target.addLegalOp<tensor::FromElementsOp>();
     // All dynamic rules below accept new function, call, return, and
     // various tensor and bufferization operations as legal output of the
     // rewriting provided that all sparse tensor types have been fully
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index cc8d538e6adfb83..d95efb507765403 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -219,7 +219,7 @@ module {
     %boi = tensor.empty() : tensor<6x2xindex>
     %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                     outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
+                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,8 @@ module {
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
     // CHECK-NEXT: 10
-    vector.print %li : i64
+    %si = tensor.extract %li[] : tensor<i64>
+    vector.print %si : i64
 
     return
   }


