[Mlir-commits] [mlir] [mlir][sparse] move toCOOType into SparseTensorType class (PR #73708)

Aart Bik llvmlistbot at llvm.org
Tue Nov 28 14:48:04 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/73708

Migrates the dangling convenience method `getCOOFromTypeWithOrdering` into the SparseTensorType class as `getCOOType`. Also cleans up some details (picking the right dimToLvl/lvlToDim maps) and removes more dead code.

>From 472964073950e87c808460f1cc56954a3c6a648c Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:45:30 -0800
Subject: [PATCH] [mlir][sparse] move toCOOType into SparseTensorType class

Migrates dangling convenience method into proper SparseTensorType
class. Also cleans up some details (picking right dim2lvl/lvl2dim).
Removes more dead code.
---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  4 ---
 .../SparseTensor/IR/SparseTensorType.h        | 29 ++++++-------------
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 29 +++++++++++++++++--
 .../IR/SparseTensorInterfaces.cpp             |  7 ++---
 4 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 517c286e0206997..28dfdbdcf89b5bf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
 /// the level-rank.
 Level getCOOStart(SparseTensorEncodingAttr enc);
 
-/// Helper to setup a COO type.
-RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
-                                            AffineMap ordering, bool ordered);
-
 /// Returns true iff MLIR operand has any sparse operand.
 inline bool hasAnySparseOperand(Operation *op) {
   return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eb666d76cd2d6f..dc520e390de293d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -64,18 +64,14 @@ class SparseTensorType {
   SparseTensorType(const SparseTensorType &) = default;
 
   //
-  // Factory methods.
+  // Factory methods to construct a new `SparseTensorType`
+  // with the same dimension-shape and element type.
   //
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by the given encoding.
   SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
     return SparseTensorType(rtp, newEnc);
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withDimToLvl(dimToLvl)`.
   SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
     return withEncoding(enc.withDimToLvl(dimToLvl));
   }
@@ -88,23 +84,14 @@ class SparseTensorType {
     return withDimToLvl(dimToLvlSTT.getEncoding());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutDimToLvl()`.
   SparseTensorType withoutDimToLvl() const {
     return withEncoding(enc.withoutDimToLvl());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withBitWidths(posWidth, crdWidth)`.
   SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
     return withEncoding(enc.withBitWidths(posWidth, crdWidth));
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutBitWidths()`.
   SparseTensorType withoutBitWidths() const {
     return withEncoding(enc.withoutBitWidths());
   }
@@ -118,10 +105,6 @@ class SparseTensorType {
     return withEncoding(enc.withoutDimSlices());
   }
 
-  //
-  // Other methods.
-  //
-
   /// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
   /// and `Type`.  These are implicit to help alleviate the impedance
   /// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
 
   Type getElementType() const { return rtp.getElementType(); }
 
-  /// Returns the encoding (or the null-attribute for dense-tensors).
   SparseTensorEncodingAttr getEncoding() const { return enc; }
 
   //
@@ -204,6 +186,10 @@ class SparseTensorType {
   /// (This is always true for dense-tensors.)
   bool isIdentity() const { return enc.isIdentity(); }
 
+  //
+  // Other methods.
+  //
+
   /// Returns the dimToLvl mapping (or the null-map for the identity).
   /// If you intend to compare the results of this method for equality,
   /// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
     return IndexType::get(getContext());
   }
 
+  /// Returns [un]ordered COO type for this sparse tensor type.
+  RankedTensorType getCOOType(bool ordered) const;
+
 private:
   // These two must be const, to ensure coherence of the memoized fields.
   const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ff2930008fa093f..edf7df3cfedabba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Local convenience methods.
+// Local Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 }
 
 //===----------------------------------------------------------------------===//
-// Convenience methods.
+// SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+RankedTensorType
+mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
+  SmallVector<LevelType> lvlTypes;
+  lvlTypes.reserve(lvlRank);
+  // A non-unique compressed level at the beginning; ordering follows `ordered`.
+  // If this is also the last level, then it is unique.
+  lvlTypes.push_back(
+      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+  if (lvlRank > 1) {
+    // Followed by n-2 non-unique singleton levels; ordering follows `ordered`.
+    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
+                *buildLevelType(LevelFormat::Singleton, ordered, false));
+    // Ends with a unique singleton level (this branch implies lvlRank > 1).
+    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+  }
+  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
+                                           getDimToLvl(), getLvlToDim(),
+                                           getPosWidth(), getCrdWidth());
+  return RankedTensorType::get(getDimShape(), getElementType(), enc);
+}
+
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 SparseTensorEncodingAttr
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d8769eacc44f39b..c8e77f7de48300e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Location loc = op.getLoc();
   Type finalTp = op->getOpResult(0).getType();
   SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
-
-  Type srcCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+  Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
 
   // Clones the original operation but changing the output to an unordered COO.
   Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Value srcCOO = cloned->getOpResult(0);
 
   // -> sort
-  Type dstCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+  Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
   Value dstCOO = rewriter.create<ReorderCOOOp>(
       loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
 



More information about the Mlir-commits mailing list