[Mlir-commits] [mlir] [mlir][sparse] move toCOOType into SparseTensorType class (PR #73708)

Aart Bik llvmlistbot at llvm.org
Tue Nov 28 14:53:26 PST 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/73708

From 472964073950e87c808460f1cc56954a3c6a648c Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:45:30 -0800
Subject: [PATCH 1/2] [mlir][sparse] move toCOOType into SparseTensorType class

Migrates the dangling convenience method into the proper SparseTensorType
class. Also cleans up some details (picking the right dim2lvl/lvl2dim).
Removes more dead code.
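
As a minimal sketch of the call-site change (the helper `demoCOOType`, its
input, and the include set are illustrative assumptions; only SparseTensorType,
getCOOType, and the removed getCOOFromTypeWithOrdering come from this patch):

  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
  #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

  using namespace mlir;
  using namespace mlir::sparse_tensor;

  // Hypothetical helper: derive an unordered COO type from a ranked tensor
  // type, the way the rewritten call sites below now do.
  static RankedTensorType demoCOOType(RankedTensorType rtp) {
    SparseTensorType stt(rtp);
    // Before: a free-standing helper that required the caller to re-pass
    // the type details:
    //   getCOOFromTypeWithOrdering(stt.getRankedTensorType(),
    //                              stt.getDimToLvl(), /*ordered=*/false);
    // After: a member method that picks up dim2lvl/lvl2dim, bit widths,
    // dimension-shape, and element type from `stt` itself.
    return stt.getCOOType(/*ordered=*/false);
  }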
---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  4 ---
 .../SparseTensor/IR/SparseTensorType.h        | 29 ++++++-------------
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 29 +++++++++++++++++--
 .../IR/SparseTensorInterfaces.cpp             |  7 ++---
 4 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 517c286e0206997..28dfdbdcf89b5bf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
 /// the level-rank.
 Level getCOOStart(SparseTensorEncodingAttr enc);
 
-/// Helper to setup a COO type.
-RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
-                                            AffineMap ordering, bool ordered);
-
 /// Returns true iff MLIR operand has any sparse operand.
 inline bool hasAnySparseOperand(Operation *op) {
   return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eb666d76cd2d6f..dc520e390de293d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -64,18 +64,14 @@ class SparseTensorType {
   SparseTensorType(const SparseTensorType &) = default;
 
   //
-  // Factory methods.
+  // Factory methods to construct a new `SparseTensorType`
+  // with the same dimension-shape and element type.
   //
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by the given encoding.
   SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
     return SparseTensorType(rtp, newEnc);
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withDimToLvl(dimToLvl)`.
   SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
     return withEncoding(enc.withDimToLvl(dimToLvl));
   }
@@ -88,23 +84,14 @@ class SparseTensorType {
     return withDimToLvl(dimToLvlSTT.getEncoding());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutDimToLvl()`.
   SparseTensorType withoutDimToLvl() const {
     return withEncoding(enc.withoutDimToLvl());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withBitWidths(posWidth, crdWidth)`.
   SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
     return withEncoding(enc.withBitWidths(posWidth, crdWidth));
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutBitWidths()`.
   SparseTensorType withoutBitWidths() const {
     return withEncoding(enc.withoutBitWidths());
   }
@@ -118,10 +105,6 @@ class SparseTensorType {
     return withEncoding(enc.withoutDimSlices());
   }
 
-  //
-  // Other methods.
-  //
-
   /// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
   /// and `Type`.  These are implicit to help alleviate the impedance
   /// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
 
   Type getElementType() const { return rtp.getElementType(); }
 
-  /// Returns the encoding (or the null-attribute for dense-tensors).
   SparseTensorEncodingAttr getEncoding() const { return enc; }
 
   //
@@ -204,6 +186,10 @@ class SparseTensorType {
   /// (This is always true for dense-tensors.)
   bool isIdentity() const { return enc.isIdentity(); }
 
+  //
+  // Other methods.
+  //
+
   /// Returns the dimToLvl mapping (or the null-map for the identity).
   /// If you intend to compare the results of this method for equality,
   /// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
     return IndexType::get(getContext());
   }
 
+  /// Returns the [un]ordered COO type for this sparse tensor type.
+  RankedTensorType getCOOType(bool ordered) const;
+
 private:
   // These two must be const, to ensure coherence of the memoized fields.
   const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ff2930008fa093f..edf7df3cfedabba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Local convenience methods.
+// Local Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 }
 
 //===----------------------------------------------------------------------===//
-// Convenience methods.
+// SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+RankedTensorType
+mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
+  SmallVector<LevelType> lvlTypes;
+  lvlTypes.reserve(lvlRank);
+  // A non-unique compressed level at the beginning; it is unique if this
+  // is also the last level. Ordering follows the `ordered` flag.
+  lvlTypes.push_back(
+      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+  if (lvlRank > 1) {
+    // Followed by n-2 non-unique singleton levels.
+    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
+                *buildLevelType(LevelFormat::Singleton, ordered, false));
+    // Ends with a unique singleton level.
+    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+  }
+  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
+                                           getDimToLvl(), getLvlToDim(),
+                                           getPosWidth(), getCrdWidth());
+  return RankedTensorType::get(getDimShape(), getElementType(), enc);
+}
+
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 SparseTensorEncodingAttr
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d8769eacc44f39b..c8e77f7de48300e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Location loc = op.getLoc();
   Type finalTp = op->getOpResult(0).getType();
   SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
-
-  Type srcCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+  Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
 
   // Clones the original operation but changing the output to an unordered COO.
   Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Value srcCOO = cloned->getOpResult(0);
 
   // -> sort
-  Type dstCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+  Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
   Value dstCOO = rewriter.create<ReorderCOOOp>(
       loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
 

From d126b4dee618bae40b4df0b9bf78ec8c26433473 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:52:52 -0800
Subject: [PATCH 2/2] migrate uses in SparseTensorRewriting.cpp too

---
 .../SparseTensor/Transforms/SparseTensorRewriting.cpp  | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 702666e9d40c31f..2bd129b85ea5416 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -132,15 +132,9 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
   }
 }
 
-// TODO: The dim level property of the COO type relies on input tensors, the
-// shape relies on the output tensor
-static RankedTensorType getCOOType(const SparseTensorType &stt, bool ordered) {
-  return getCOOFromTypeWithOrdering(stt, stt.getDimToLvl(), ordered);
-}
-
 static RankedTensorType getBufferType(const SparseTensorType &stt,
                                       bool needTmpCOO) {
-  return needTmpCOO ? getCOOType(stt, /*ordered=*/false)
+  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
                     : stt.getRankedTensorType();
 }
 
@@ -1195,7 +1189,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     //   %t = sparse_tensor.convert %orderedCoo
     // with enveloping reinterpreted_map ops for non-permutations.
     RankedTensorType dstTp = stt.getRankedTensorType();
-    RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
+    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
     Value convert = cooTensor;
     if (!stt.isPermutation()) { // demap coo, demap dstTp


