[Mlir-commits] [mlir] [mlir][sparse] move toCOOType into SparseTensorType class (PR #73708)

Aart Bik llvmlistbot at llvm.org
Tue Nov 28 15:18:08 PST 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/73708

>From 472964073950e87c808460f1cc56954a3c6a648c Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:45:30 -0800
Subject: [PATCH 1/5] [mlir][sparse] move COO type helper into SparseTensorType class (getCOOType)

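Migrates the free-standing convenience function getCOOFromTypeWithOrdering()
into the SparseTensorType class as the method getCOOType(). Also cleans up
some details: both the dim2lvl and lvl2dim maps are now taken from the type
itself, rather than having the dim2lvl map passed in separately by the caller.
Removes more dead code.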
---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  4 ---
 .../SparseTensor/IR/SparseTensorType.h        | 29 ++++++-------------
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 29 +++++++++++++++++--
 .../IR/SparseTensorInterfaces.cpp             |  7 ++---
 4 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 517c286e0206997..28dfdbdcf89b5bf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
 /// the level-rank.
 Level getCOOStart(SparseTensorEncodingAttr enc);
 
-/// Helper to setup a COO type.
-RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
-                                            AffineMap ordering, bool ordered);
-
 /// Returns true iff MLIR operand has any sparse operand.
 inline bool hasAnySparseOperand(Operation *op) {
   return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eb666d76cd2d6f..dc520e390de293d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -64,18 +64,14 @@ class SparseTensorType {
   SparseTensorType(const SparseTensorType &) = default;
 
   //
-  // Factory methods.
+  // Factory methods to construct a new `SparseTensorType`
+  // with the same dimension-shape and element type.
   //
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by the given encoding.
   SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
     return SparseTensorType(rtp, newEnc);
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withDimToLvl(dimToLvl)`.
   SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
     return withEncoding(enc.withDimToLvl(dimToLvl));
   }
@@ -88,23 +84,14 @@ class SparseTensorType {
     return withDimToLvl(dimToLvlSTT.getEncoding());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutDimToLvl()`.
   SparseTensorType withoutDimToLvl() const {
     return withEncoding(enc.withoutDimToLvl());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withBitWidths(posWidth, crdWidth)`.
   SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
     return withEncoding(enc.withBitWidths(posWidth, crdWidth));
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutBitWidths()`.
   SparseTensorType withoutBitWidths() const {
     return withEncoding(enc.withoutBitWidths());
   }
@@ -118,10 +105,6 @@ class SparseTensorType {
     return withEncoding(enc.withoutDimSlices());
   }
 
-  //
-  // Other methods.
-  //
-
   /// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
   /// and `Type`.  These are implicit to help alleviate the impedance
   /// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
 
   Type getElementType() const { return rtp.getElementType(); }
 
-  /// Returns the encoding (or the null-attribute for dense-tensors).
   SparseTensorEncodingAttr getEncoding() const { return enc; }
 
   //
@@ -204,6 +186,10 @@ class SparseTensorType {
   /// (This is always true for dense-tensors.)
   bool isIdentity() const { return enc.isIdentity(); }
 
+  //
+  // Other methods.
+  //
+
   /// Returns the dimToLvl mapping (or the null-map for the identity).
   /// If you intend to compare the results of this method for equality,
   /// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
     return IndexType::get(getContext());
   }
 
+  /// Returns [un]ordered COO type for this sparse tensor type.
+  RankedTensorType getCOOType(bool ordered) const;
+
 private:
   // These two must be const, to ensure coherence of the memoized fields.
   const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ff2930008fa093f..edf7df3cfedabba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Local convenience methods.
+// Local Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 }
 
 //===----------------------------------------------------------------------===//
-// Convenience methods.
+// SparseTensorType SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+RankedTensorType
+mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
+  SmallVector<LevelType> lvlTypes;
+  lvlTypes.reserve(lvlRank);
+  // An unordered and non-unique compressed level at beginning.
+  // If this is also the last level, then it is unique.
+  lvlTypes.push_back(
+      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+  if (lvlRank > 1) {
+    // Followed by unordered non-unique n-2 singleton levels.
+    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
+                *buildLevelType(LevelFormat::Singleton, ordered, false));
+    // Ends by a unique singleton level unless the lvlRank is 1.
+    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+  }
+  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
+                                           getDimToLvl(), getLvlToDim(),
+                                           getPosWidth(), getCrdWidth());
+  return RankedTensorType::get(getDimShape(), getElementType(), enc);
+}
+
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 SparseTensorEncodingAttr
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d8769eacc44f39b..c8e77f7de48300e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Location loc = op.getLoc();
   Type finalTp = op->getOpResult(0).getType();
   SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
-
-  Type srcCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+  Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
 
   // Clones the original operation but changing the output to an unordered COO.
   Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Value srcCOO = cloned->getOpResult(0);
 
   // -> sort
-  Type dstCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+  Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
   Value dstCOO = rewriter.create<ReorderCOOOp>(
       loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
 

>From d126b4dee618bae40b4df0b9bf78ec8c26433473 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:52:52 -0800
Subject: [PATCH 2/5] [mlir][sparse] use SparseTensorType::getCOOType in SparseTensorRewriting.cpp too

---
 .../SparseTensor/Transforms/SparseTensorRewriting.cpp  | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 702666e9d40c31f..2bd129b85ea5416 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -132,15 +132,9 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
   }
 }
 
-// TODO: The dim level property of the COO type relies on input tensors, the
-// shape relies on the output tensor
-static RankedTensorType getCOOType(const SparseTensorType &stt, bool ordered) {
-  return getCOOFromTypeWithOrdering(stt, stt.getDimToLvl(), ordered);
-}
-
 static RankedTensorType getBufferType(const SparseTensorType &stt,
                                       bool needTmpCOO) {
-  return needTmpCOO ? getCOOType(stt, /*ordered=*/false)
+  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
                     : stt.getRankedTensorType();
 }
 
@@ -1195,7 +1189,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     //   %t = sparse_tensor.convert %orderedCoo
     // with enveloping reinterpreted_map ops for non-permutations.
     RankedTensorType dstTp = stt.getRankedTensorType();
-    RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
+    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
     Value convert = cooTensor;
     if (!stt.isPermutation()) { // demap coo, demap dstTp

>From 086bf81a004dff468bb88f50e3eba9ddeb655615 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 15:03:22 -0800
Subject: [PATCH 3/5] [mlir][sparse] remove now-dead getCOOFromTypeWithOrdering (DCE)

---
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 33 -------------------
 1 file changed, 33 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index edf7df3cfedabba..50c62bbba9d2437 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -903,39 +903,6 @@ Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
   return lvlRank;
 }
 
-// Helper to setup a COO type.
-RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
-                                                           AffineMap lvlPerm,
-                                                           bool ordered) {
-  const SparseTensorType src(rtt);
-  const Level lvlRank = src.getLvlRank();
-  SmallVector<LevelType> lvlTypes;
-  lvlTypes.reserve(lvlRank);
-
-  // An unordered and non-unique compressed level at beginning.
-  // If this is also the last level, then it is unique.
-  lvlTypes.push_back(
-      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
-  if (lvlRank > 1) {
-    // TODO: it is actually ordered at the level for ordered input.
-    // Followed by unordered non-unique n-2 singleton levels.
-    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
-                *buildLevelType(LevelFormat::Singleton, ordered, false));
-    // Ends by a unique singleton level unless the lvlRank is 1.
-    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
-  }
-
-  // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
-  // largest one among them) in the original operation instead of using the
-  // default value.
-  unsigned posWidth = src.getPosWidth();
-  unsigned crdWidth = src.getCrdWidth();
-  AffineMap invPerm = src.getLvlToDim();
-  auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
-                                           invPerm, posWidth, crdWidth);
-  return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
-}
-
 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
     assert(enc.isPermutation() && "Non permutation map not supported");

>From 6b2bb9b550c4b63a5b25c2e770ad0deae3bf32a9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 15:06:02 -0800
Subject: [PATCH 4/5] fix duplicated word in SparseTensorType section comment (review feedback)

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 50c62bbba9d2437..20a091a81a26c25 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -711,7 +711,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 }
 
 //===----------------------------------------------------------------------===//
-// SparseTensorType SparseTensorType Methods.
+// SparseTensorType Methods.
 //===----------------------------------------------------------------------===//
 
 RankedTensorType

>From 56f339ad734128afe9e648e3017b04e3477760ff Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 15:17:39 -0800
Subject: [PATCH 5/5] clarify level-type comments in SparseTensorType::getCOOType

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 20a091a81a26c25..74d2fd5fd9f829c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -718,15 +718,15 @@ RankedTensorType
 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
   SmallVector<LevelType> lvlTypes;
   lvlTypes.reserve(lvlRank);
-  // An unordered and non-unique compressed level at beginning.
-  // If this is also the last level, then it is unique.
+  // A non-unique compressed level at beginning (unless this is
+  // also the last level, then it is unique).
   lvlTypes.push_back(
       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
   if (lvlRank > 1) {
-    // Followed by unordered non-unique n-2 singleton levels.
+    // Followed by n-2 non-unique singleton levels.
     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                 *buildLevelType(LevelFormat::Singleton, ordered, false));
-    // Ends by a unique singleton level unless the lvlRank is 1.
+    // Ends by a unique singleton level.
     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
   }
   auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,



More information about the Mlir-commits mailing list