[Mlir-commits] [mlir] [mlir][sparse] move all COO related methods into SparseTensorType (PR #73881)

Aart Bik llvmlistbot at llvm.org
Wed Nov 29 16:53:07 PST 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/73881

>From 7d872d07cee1c34e72bb6f53e8a249c680dc20ae Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 29 Nov 2023 16:39:05 -0800
Subject: [PATCH 1/2] [mlir][sparse] move all COO related methods into
 SparseTensorType

This centralizes all COO methods, and provides a cleaner API.
Note that the "enc" only constructor is a temporary workaround for
the need for COO methods inside the "enc" only storage specifier.
---
 .../Dialect/SparseTensor/IR/SparseTensor.h    | 13 ---
 .../SparseTensor/IR/SparseTensorType.h        | 20 ++++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 79 ++++++++-----------
 .../SparseTensor/Transforms/LoopEmitter.cpp   |  7 +-
 .../Transforms/SparseTensorCodegen.cpp        |  9 +--
 .../Transforms/SparseTensorDescriptor.cpp     |  2 +-
 .../Transforms/SparseTensorDescriptor.h       |  2 +-
 .../Transforms/SparseTensorRewriting.cpp      |  4 +-
 8 files changed, 64 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 28dfdbdcf89b5bf..5e523ec428aefb9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,19 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
-bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
-
-/// Returns true iff the given type is a COO type where the last level
-/// is unique.
-bool isUniqueCOOType(Type tp);
-
-/// Returns the starting level for a trailing COO region that spans
-/// at least two levels.  If no such COO region is found, then returns
-/// the level-rank.
-Level getCOOStart(SparseTensorEncodingAttr enc);
-
 /// Returns true iff MLIR operand has any sparse operand.
 inline bool hasAnySparseOperand(Operation *op) {
   return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index dc520e390de293d..4c98129744bcd94 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -60,6 +60,12 @@ class SparseTensorType {
       : SparseTensorType(
             RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
 
+  // TODO: remove?
+  SparseTensorType(SparseTensorEncodingAttr enc)
+      : SparseTensorType(RankedTensorType::get(
+            SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
+            Float32Type::get(enc.getContext()), enc)) {}
+
   SparseTensorType &operator=(const SparseTensorType &) = delete;
   SparseTensorType(const SparseTensorType &) = default;
 
@@ -234,9 +240,9 @@ class SparseTensorType {
                                        CrdTransDirectionKind::dim2lvl);
   }
 
+  /// Returns the type with an identity mapping.
   RankedTensorType getDemappedType() const {
-    auto lvlShape = getLvlShape();
-    return RankedTensorType::get(lvlShape, rtp.getElementType(),
+    return RankedTensorType::get(getLvlShape(), getElementType(),
                                  enc.withoutDimToLvl());
   }
 
@@ -311,6 +317,16 @@ class SparseTensorType {
     return IndexType::get(getContext());
   }
 
+  /// Returns true iff this sparse tensor type has a trailing
+  /// COO region starting at the given level. By default, it
+  /// tests for a unique COO type at top level.
+  bool isCOOType(Level startLvl = 0, bool isUnique = true) const;
+
+  /// Returns the starting level of this sparse tensor type for a
+  /// trailing COO region that spans **at least** two levels. If
+  /// no such COO region is found, then returns the level-rank.
+  Level getCOOStart() const;
+
   /// Returns [un]ordered COO type for this sparse tensor type.
   RankedTensorType getCOOType(bool ordered) const;
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d4f8afdd62f2383..7dc4fc4f8570d60 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -66,7 +66,7 @@ void StorageLayout::foreachField(
         callback) const {
   const auto lvlTypes = enc.getLvlTypes();
   const Level lvlRank = enc.getLvlRank();
-  const Level cooStart = getCOOStart(enc);
+  const Level cooStart = SparseTensorType(enc).getCOOStart();
   const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
   FieldIndex fieldIdx = kDataFieldStartingIdx;
   // Per-level storage.
@@ -158,7 +158,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
   unsigned stride = 1;
   if (kind == SparseTensorFieldKind::CrdMemRef) {
     assert(lvl.has_value());
-    const Level cooStart = getCOOStart(enc);
+    const Level cooStart = SparseTensorType(enc).getCOOStart();
     const Level lvlRank = enc.getLvlRank();
     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
       lvl = cooStart;
@@ -710,6 +710,28 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 // SparseTensorType Methods.
 //===----------------------------------------------------------------------===//
 
+bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
+  if (!hasEncoding())
+    return false;
+  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
+    return false;
+  for (Level l = startLvl + 1; l < lvlRank; ++l)
+    if (!isSingletonLvl(l))
+      return false;
+  // If isUnique is true, then make sure that the last level is unique,
+  // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
+  // (unique on the last singleton).
+  return !isUnique || isUniqueLvl(lvlRank - 1);
+}
+
+Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
+  if (lvlRank > 1)
+    for (Level l = 0; l < lvlRank - 1; l++)
+      if (isCOOType(l, /*isUnique=*/false))
+        return l;
+  return lvlRank;
+}
+
 RankedTensorType
 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
   SmallVector<LevelType> lvlTypes;
@@ -859,25 +881,6 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
   return !coeffientMap.empty();
 }
 
-bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
-                                    Level startLvl, bool isUnique) {
-  if (!enc ||
-      !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
-    return false;
-  const Level lvlRank = enc.getLvlRank();
-  for (Level l = startLvl + 1; l < lvlRank; ++l)
-    if (!enc.isSingletonLvl(l))
-      return false;
-  // If isUnique is true, then make sure that the last level is unique,
-  // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
-  // (unique on the last singleton).
-  return !isUnique || enc.isUniqueLvl(lvlRank - 1);
-}
-
-bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
-  return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
-}
-
 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
   auto hasNonIdentityMap = [](Value v) {
     auto stt = tryGetSparseTensorType(v);
@@ -888,17 +891,6 @@ bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
          llvm::any_of(op->getResults(), hasNonIdentityMap);
 }
 
-Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
-  // We only consider COO region with at least two levels for the purpose
-  // of AOS storage optimization.
-  const Level lvlRank = enc.getLvlRank();
-  if (lvlRank > 1)
-    for (Level l = 0; l < lvlRank - 1; l++)
-      if (isCOOType(enc, l, /*isUnique=*/false))
-        return l;
-  return lvlRank;
-}
-
 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
     assert(enc.isPermutation() && "Non permutation map not supported");
@@ -1013,7 +1005,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     return op->emitError("the sparse-tensor must have the identity mapping");
 
   // Verifies the trailing COO.
-  Level cooStartLvl = getCOOStart(stt.getEncoding());
+  Level cooStartLvl = stt.getCOOStart();
   if (cooStartLvl < stt.getLvlRank()) {
     // We only supports trailing COO for now, must be the last input.
     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
@@ -1309,34 +1301,34 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult ToPositionsOp::verify() {
-  auto e = getSparseTensorEncoding(getTensor().getType());
+  auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
     return emitError("requested level is out of bounds");
-  if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
+  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
     return emitError("unexpected type for positions");
   return success();
 }
 
 LogicalResult ToCoordinatesOp::verify() {
-  auto e = getSparseTensorEncoding(getTensor().getType());
+  auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
     return emitError("requested level is out of bounds");
-  if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
+  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
     return emitError("unexpected type for coordinates");
   return success();
 }
 
 LogicalResult ToCoordinatesBufferOp::verify() {
-  auto e = getSparseTensorEncoding(getTensor().getType());
-  if (getCOOStart(e) >= e.getLvlRank())
+  auto stt = getSparseTensorType(getTensor());
+  if (stt.getCOOStart() >= stt.getLvlRank())
     return emitError("expected sparse tensor with a COO region");
   return success();
 }
 
 LogicalResult ToValuesOp::verify() {
-  auto ttp = getRankedTensorType(getTensor());
+  auto stt = getSparseTensorType(getTensor());
   auto mtp = getMemRefType(getResult());
-  if (ttp.getElementType() != mtp.getElementType())
+  if (stt.getElementType() != mtp.getElementType())
     return emitError("unexpected mismatch in element types");
   return success();
 }
@@ -1660,9 +1652,8 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
-  if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
-      !isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
-    emitError("Unexpected non-COO sparse tensors");
+  if (!srcStt.isCOOType() || !dstStt.isCOOType())
+    emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
     emitError("Unmatched dim2lvl map between input and result COO");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index a245344755f0404..26f015ce6ec64f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -412,8 +412,7 @@ void LoopEmitter::initializeLoopEmit(
     auto stt = getSparseTensorType(tensor);
     const Level lvlRank = stt.getLvlRank();
     const auto shape = rtp.getShape();
-    const auto enc = getSparseTensorEncoding(rtp);
-    const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
+    const Level cooStart = stt.getCOOStart();
 
     SmallVector<Value> lvlSzs;
     for (Level l = 0; l < stt.getLvlRank(); l++) {
@@ -457,8 +456,8 @@ void LoopEmitter::initializeLoopEmit(
     // values.
     // Delegates extra output initialization to clients.
     bool isOutput = isOutputTensor(t);
-    Type elementType = rtp.getElementType();
-    if (!enc) {
+    Type elementType = stt.getElementType();
+    if (!stt.hasEncoding()) {
       // Non-annotated dense tensors.
       BaseMemRefType denseTp = MemRefType::get(shape, elementType);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e9062b49435f5b7..18b2bb0819e2642 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
       valHeuristic =
           builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
   } else if (sizeHint) {
-    if (getCOOStart(stt.getEncoding()) == 0) {
+    if (stt.getCOOStart() == 0) {
       posHeuristic = constantIndex(builder, loc, 2);
       crdHeuristic = builder.create<arith::MulIOp>(
           loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
@@ -657,8 +657,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
 
     // Should have been verified.
     assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
-           isUniqueCOOType(srcStt.getRankedTensorType()) &&
-           isUniqueCOOType(dstStt.getRankedTensorType()));
+           dstStt.isCOOType() && srcStt.isCOOType());
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
@@ -1317,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     Value posBack = c0; // index to the last value in the position array
     Value memSize = c1; // memory size for current array
 
-    Level trailCOOStart = getCOOStart(stt.getEncoding());
+    Level trailCOOStart = stt.getCOOStart();
     Level trailCOORank = stt.getLvlRank() - trailCOOStart;
     // Sets up SparseTensorSpecifier.
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
@@ -1454,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
     const auto dstTp = getSparseTensorType(op.getResult());
     // Creating COO with NewOp is handled by direct IR codegen. All other cases
     // are handled by rewriting.
-    if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
+    if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
       return failure();
 
     // Implement as follows:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
index 1c6d7bebe37e46c..3ab4157475cd4c2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
@@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
     OpBuilder &builder, Location loc, Level lvl) const {
-  const Level cooStart = getCOOStart(rType.getEncoding());
+  const Level cooStart = rType.getCOOStart();
   if (lvl < cooStart)
     return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index 4bd700eef522e04..5c7d8aa4c9d9678 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
   }
 
   Value getAOSMemRef() const {
-    const Level cooStart = getCOOStart(rType.getEncoding());
+    const Level cooStart = rType.getCOOStart();
     assert(cooStart < rType.getLvlRank());
     return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2bd129b85ea5416..4fc692f2fe9ddc2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,8 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
-    auto enc = stt.getEncoding();
-    if (!stt.hasEncoding() || getCOOStart(enc) == 0)
+    if (!stt.hasEncoding() || stt.getCOOStart() == 0)
       return failure();
 
     // Implement the NewOp as follows:
@@ -1192,6 +1191,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
     Value convert = cooTensor;
+    auto enc = stt.getEncoding();
     if (!stt.isPermutation()) { // demap coo, demap dstTp
       auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
       convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);

>From c16256b1408cec8195d1f6dbe5674170a8f6a630 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 29 Nov 2023 16:52:41 -0800
Subject: [PATCH 2/2] clang-format

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7dc4fc4f8570d60..90ac9e58e60d96c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -710,7 +710,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 // SparseTensorType Methods.
 //===----------------------------------------------------------------------===//
 
-bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
+bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
+                                                      bool isUnique) const {
   if (!hasEncoding())
     return false;
   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))



More information about the Mlir-commits mailing list