[Mlir-commits] [mlir] [mlir][SparseTensor][NFC] Pass tensor type to descriptor helper (PR #116468)

Matthias Springer llvmlistbot at llvm.org
Mon Nov 18 03:56:02 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116468

From 66fb6bb37e32eac9838e6913bc5f8f9ffd089f8c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 16 Nov 2024 05:11:48 +0100
Subject: [PATCH] [mlir][SparseTensor][NFC] Pass tensor type to descriptor
 helper

---
 .../Transforms/SparseTensorCodegen.cpp        | 58 ++++++++++++-------
 .../Transforms/Utils/CodegenUtils.cpp         |  5 --
 .../Transforms/Utils/CodegenUtils.h           |  3 -
 .../Transforms/Utils/SparseTensorDescriptor.h | 12 ++--
 4 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index bf7b3f9bec5586..25fca49cb0154a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
   matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     std::optional<int64_t> lvl = op.getConstantLvlIndex();
-    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    RankedTensorType srcType = op.getSource().getType();
+    if (!lvl || !getSparseTensorEncoding(srcType))
       return failure();
 
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
 
     rewriter.replaceOp(op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
-    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
-    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
+                                             op.getInputCoo().getType());
+    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
     auto crd = desc.getAOSMemRef();
     auto val = desc.getValMemRef();
 
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply lowers to specifer.get <field> operation.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
+                                             op.getSlice().getType());
     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                     op.getDim().getZExtValue());
 
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
     Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
-      auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
       SmallVector<Value> fields;
       fields.reserve(desc.getNumFields());
       // Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
     if (createDeallocs) {
       // Replace the sparse tensor deallocation with field deallocations.
       Location loc = op.getLoc();
-      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getTensor(),
+          cast<RankedTensorType>(op.getTensor().getType()));
       for (auto input : desc.getMemRefFields())
         // Deallocate every buffer used to store the sparse tensor handler.
         rewriter.create<memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     const auto srcType = getSparseTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
+                                                op.getTensor().getType());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
 
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+    auto desc =
+        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
     params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getPosMemRef(lvl);
     auto size = desc.getPosMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getAOSMemRef();
     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getValMemRef();
     auto size = desc.getValMemSize(rewriter, loc);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     //   else:
     //     dst = memref.copy(src)
     Location loc = op.getLoc();
-    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
+                                                op.getSource().getType());
     SmallVector<Value> fields;
     foreachFieldAndTypeInSparseTensor(
         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
     assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
+                                                op.getSource().getType());
 
     auto newSpec = rewriter.create<StorageSpecifierInitOp>(
         loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
     // Query memSizes for the actually stored values.
     // FIXME: the nse value computed in this way might be wrong when there is
     // any "loose_compressed" level.
-    rewriter.replaceOp(
-        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
+    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
     return success();
   }
 };
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     Location loc = op.getLoc();
     SmallVector<Value> retMem;
     SmallVector<Value> retLen;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index de553a5f9bf08c..f92382472b4780 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
       .getResult();
 }
 
-Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
-                                   Value tensor) {
-  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
-}
-
 Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                                Value tensor, Dimension dim) {
   auto enc = getSparseTensorEncoding(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index d0ef8a6860bb2d..dc017e6baa6dc3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
 
-/// Generates code to retrieve the values size for the sparse tensor.
-Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
-
 /// Generates code to retrieve the slice offset for the sparse tensor slice,
 /// return a constant if the offset is statically known.
 Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index c2f631605bf4b2..89858546e37e1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
   return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
 }
 
-inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+inline SparseTensorDescriptor
+getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
   auto tuple = getTuple(tensor);
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return SparseTensorDescriptor(stt, tuple.getInputs());
+  return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
 }
 
 inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+                                RankedTensorType type) {
   auto tuple = getTuple(tensor);
   fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return MutSparseTensorDescriptor(stt, fields);
+  return MutSparseTensorDescriptor(SparseTensorType(type), fields);
 }
 
 } // namespace sparse_tensor



More information about the Mlir-commits mailing list