[Mlir-commits] [mlir] [sparse] allow unpack op to return 0-ranked tensor type. (PR #66269)
Peiming Liu
llvmlistbot at llvm.org
Wed Sep 13 11:17:47 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/66269:
Many frontends canonicalize scalars into 0-ranked tensors; this change will hopefully make the operation easier to use in those cases.
>From c08bc442a8fbfb21262b6a166c3ced9a21c280fd Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 13 Sep 2023 18:09:40 +0000
Subject: [PATCH] [sparse] allow unpack op to return 0-ranked tensor type.
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 4 ++++
.../Dialect/SparseTensor/IR/SparseTensorOps.td | 4 ++--
.../Transforms/SparseTensorCodegen.cpp | 17 +++++++++++++++--
.../Transforms/SparseTensorPasses.cpp | 2 ++
.../Dialect/SparseTensor/CPU/sparse_pack.mlir | 5 +++--
5 files changed, 26 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index e2f3df005b70d69..bf077db43ec10e9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
+// Type constraint matching either a 0-ranked tensor of `allowedTypes` or a value of one of `allowedTypes` itself.
+class ScalarLikeOf<list<Type> allowedTypes>
+ : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>]>;
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Sorting Algorithm Attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7d9f1d3b26c0678..7430a3c6118cef4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
Results<(outs TensorOf<[AnyType]>:$ret_values,
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- AnySignlessIntegerOrIndex:$val_len,
- Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
+ ScalarLikeOf<[AnySignlessIntegerOrIndex]>:$val_len,
+ Variadic<ScalarLikeOf<[AnySignlessIntegerOrIndex]>>:$lvl_lens)> {
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0c8a304841c10d5..557c5c471c4a77c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -559,6 +559,24 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
return reassociation;
}
+/// Converts the scalar `elem` into a value of type `dstTp`.
+///
+/// If `dstTp` is a ranked tensor type, it must be 0-ranked: `elem` is cast to
+/// the tensor's element type and wrapped in a `tensor.from_elements` op.
+/// Any other `dstTp` is handled by a plain `genCast`. Returns a null Value
+/// when `dstTp` is a ranked tensor of nonzero rank (callers must not pass one).
+static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+ Type dstTp) {
+ if (auto rtp = llvm::dyn_cast<RankedTensorType>(dstTp)) {
+ // Scalars can only be converted to 0-ranked tensors.
+ if (rtp.getRank() != 0)
+ return nullptr;
+ elem = genCast(builder, loc, elem, rtp.getElementType());
+ return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+ }
+ return genCast(builder, loc, elem, dstTp);
+}
+
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// consistent.
retMem.insert(retMem.begin(), dst);
Type valLenTp = op.getValLen().getType();
- retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
+ retLen.insert(retLen.begin(),
+ genScalarToTensor(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
retMem.push_back(dst);
// Retrieves the corresponding level length type.
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
- retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
+ retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index cce26bc603eeb3c..2956cf57ade0290 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -214,6 +214,8 @@ struct SparseTensorCodegenPass
target.addLegalOp<GetStorageSpecifierOp>();
target.addLegalOp<SetStorageSpecifierOp>();
target.addLegalOp<StorageSpecifierInitOp>();
+ // tensor::FromElementsOp may be generated when lowering unpack.
+ target.addLegalOp<tensor::FromElementsOp>();
// All dynamic rules below accept new function, call, return, and
// various tensor and bufferization operations as legal output of the
// rewriting provided that all sparse tensor types have been fully
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index cc8d538e6adfb83..d95efb507765403 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -219,7 +219,7 @@ module {
%boi = tensor.empty() : tensor<6x2xindex>
%bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
- -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
+ -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
// CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
%vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,8 @@ module {
%vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
vector.print %vbi : vector<6x2xindex>
// CHECK-NEXT: 10
- vector.print %li : i64
+ %si = tensor.extract %li[] : tensor<i64>
+ vector.print %si : i64
return
}
More information about the Mlir-commits
mailing list