[Mlir-commits] [mlir] [mlir][SparseTensor][NFC] Pass tensor type to descriptor helper (PR #116468)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 18 03:56:02 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116468
>From 66fb6bb37e32eac9838e6913bc5f8f9ffd089f8c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 16 Nov 2024 05:11:48 +0100
Subject: [PATCH] [mlir][SparseTensor][NFC] Pass tensor type to descriptor helper
---
.../Transforms/SparseTensorCodegen.cpp | 58 ++++++++++++-------
.../Transforms/Utils/CodegenUtils.cpp | 5 --
.../Transforms/Utils/CodegenUtils.h | 3 -
.../Transforms/Utils/SparseTensorDescriptor.h | 12 ++--
4 files changed, 44 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index bf7b3f9bec5586..25fca49cb0154a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> lvl = op.getConstantLvlIndex();
- if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+ RankedTensorType srcType = op.getSource().getType();
+ if (!lvl || !getSparseTensorEncoding(srcType))
return failure();
- auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
rewriter.replaceOp(op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
assert(dstStt.hasSameDimToLvl(srcStt));
// We don't need a mutable descriptor here as we perform sorting in-place.
- auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
- auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
+ op.getInputCoo().getType());
+ auto nnz = desc.getValMemSize(rewriter, op.getLoc());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply lowers to specifer.get <field> operation.
- auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
+ op.getSlice().getType());
auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
op.getDim().getZExtValue());
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
Location loc = op.getLoc();
// Deal with copy.
if (op.getCopy()) {
- auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+ auto desc = getDescriptorFromTensorTuple(
+ adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
SmallVector<Value> fields;
fields.reserve(desc.getNumFields());
// Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
if (createDeallocs) {
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(
+ adaptor.getTensor(),
+ cast<RankedTensorType>(op.getTensor().getType()));
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Prepare descriptor.
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
// Generate optional insertion finalization code.
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
Location loc = op->getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
const auto srcType = getSparseTensorType(op.getTensor());
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
+ op.getTensor().getType());
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+ auto desc =
+ getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getPosMemRef(lvl);
auto size = desc.getPosMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getAOSMemRef();
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getValMemRef();
auto size = desc.getValMemSize(rewriter, loc);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
// else:
// dst = memref.copy(src)
Location loc = op.getLoc();
- auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
+ op.getSource().getType());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+ auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
+ op.getSource().getType());
auto newSpec = rewriter.create<StorageSpecifierInitOp>(
loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
// Query memSizes for the actually stored values.
// FIXME: the nse value computed in this way might be wrong when there is
// any "loose_compressed" level.
- rewriter.replaceOp(
- op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
+ rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
return success();
}
};
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
LogicalResult
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
Location loc = op.getLoc();
SmallVector<Value> retMem;
SmallVector<Value> retLen;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index de553a5f9bf08c..f92382472b4780 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
.getResult();
}
-Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
- Value tensor) {
- return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
-}
-
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
Value tensor, Dimension dim) {
auto enc = getSparseTensorEncoding(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index d0ef8a6860bb2d..dc017e6baa6dc3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);
-/// Generates code to retrieve the values size for the sparse tensor.
-Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
-
/// Generates code to retrieve the slice offset for the sparse tensor slice,
/// return a constant if the offset is statically known.
Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index c2f631605bf4b2..89858546e37e1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
}
-inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+inline SparseTensorDescriptor
+getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
auto tuple = getTuple(tensor);
- SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
- return SparseTensorDescriptor(stt, tuple.getInputs());
+ return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
}
inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+ RankedTensorType type) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
- SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
- return MutSparseTensorDescriptor(stt, fields);
+ return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}
} // namespace sparse_tensor
More information about the Mlir-commits mailing list