[Mlir-commits] [mlir] bb4fc6b - [mlir][sparse] Adding `SparseTensorType::{operator==, hasSameDimToLvlMap}`

wren romano llvmlistbot at llvm.org
Wed Feb 15 12:05:38 PST 2023


Author: wren romano
Date: 2023-02-15T12:05:29-08:00
New Revision: bb4fc6b6d6b41da9985db0f9b294189e25da4a72

URL: https://github.com/llvm/llvm-project/commit/bb4fc6b6d6b41da9985db0f9b294189e25da4a72
DIFF: https://github.com/llvm/llvm-project/commit/bb4fc6b6d6b41da9985db0f9b294189e25da4a72.diff

LOG: [mlir][sparse] Adding `SparseTensorType::{operator==, hasSameDimToLvlMap}`

Depends On D143800

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D144052

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index ba31fa72cfd82..4eeaa39e84236 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -105,6 +105,16 @@ class SparseTensorType {
   /// implicit conversion.
   RankedTensorType getRankedTensorType() const { return rtp; }
 
+  bool operator==(const SparseTensorType &other) const {
+    // All other fields are derived from `rtp` and therefore don't need
+    // to be checked.
+    return rtp == other.rtp;
+  }
+
+  bool operator!=(const SparseTensorType &other) const {
+    return !(*this == other);
+  }
+
   MLIRContext *getContext() const { return rtp.getContext(); }
 
   Type getElementType() const { return rtp.getElementType(); }
@@ -130,6 +140,8 @@ class SparseTensorType {
   bool isIdentity() const { return !dim2lvl; }
 
   /// Returns the dimToLvl mapping (or the null-map for the identity).
+  /// If you intend to compare the results of this method for equality,
+  /// see `hasSameDimToLvlMap` instead.
   AffineMap getDimToLvlMap() const { return dim2lvl; }
 
   /// Returns the dimToLvl mapping, where the identity map is expanded out
@@ -142,6 +154,17 @@ class SparseTensorType {
                : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
   }
 
+  /// Returns true iff the two types have the same mapping.  This method
+  /// takes care to handle identity maps properly, so it should be preferred
+  /// over using `getDimToLvlMap` followed by `AffineMap::operator==`.
+  bool hasSameDimToLvlMap(const SparseTensorType &other) const {
+    // If the maps are the identity, then we need to check the rank
+    // to be sure they're the same size identity.  (And since identity
+    // means dimRank==lvlRank, we use lvlRank as a minor optimization.)
+    return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank)
+                        : (dim2lvl == other.dim2lvl);
+  }
+
   /// Returns the dimension-rank.
   Dimension getDimRank() const { return rtp.getRank(); }
 

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9ec1d7c22c801..7046306ef17ef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -48,12 +48,6 @@ static bool isSparseTensor(OpOperand *op) {
          llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed);
 }
 
-static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) {
-  assert(rtp1.getRank() == rtp2.getRank());
-  return SparseTensorType(rtp1).getDimToLvlMap() ==
-         SparseTensorType(rtp2).getDimToLvlMap();
-}
-
 // Helper method to find zero/uninitialized allocation.
 static bool isAlloc(OpOperand *op, bool isZero) {
   Value val = op->get();
@@ -796,8 +790,9 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
     // 2. the src tensor is not ordered in the same way as the target
     // tensor (e.g., src tensor is not ordered or src tensor haves a different
     // dimOrdering).
-    if (!isUniqueCOOType(srcRTT) && !(SparseTensorType(srcRTT).isAllOrdered() &&
-                                      hasSameDimOrdering(srcRTT, dstTp))) {
+    if (const SparseTensorType srcTp(srcRTT);
+        !isUniqueCOOType(srcRTT) &&
+        !(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) {
       // Construct a COO tensor from the src tensor.
       // TODO: there may be cases for which more efficiently without
       // going through an intermediate COO, such as cases that only change
@@ -841,7 +836,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       // Sort the COO tensor so that its elements are ordered via increasing
       // indices for the storage ordering of the dst tensor. Use SortCoo if the
       // COO tensor has the same dim ordering as the dst tensor.
-      if (dimRank > 1 && hasSameDimOrdering(srcTp, dstTp)) {
+      if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
         MemRefType indTp =
             get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
                             /*withLayout=*/false);


        


More information about the Mlir-commits mailing list