[Mlir-commits] [mlir] bb4fc6b - [mlir][sparse] Adding `SparseTensorType::{operator==, hasSameDimToLvlMap}`
wren romano
llvmlistbot at llvm.org
Wed Feb 15 12:05:38 PST 2023
Author: wren romano
Date: 2023-02-15T12:05:29-08:00
New Revision: bb4fc6b6d6b41da9985db0f9b294189e25da4a72
URL: https://github.com/llvm/llvm-project/commit/bb4fc6b6d6b41da9985db0f9b294189e25da4a72
DIFF: https://github.com/llvm/llvm-project/commit/bb4fc6b6d6b41da9985db0f9b294189e25da4a72.diff
LOG: [mlir][sparse] Adding `SparseTensorType::{operator==, hasSameDimToLvlMap}`
Depends On D143800
Reviewed By: aartbik, Peiming
Differential Revision: https://reviews.llvm.org/D144052
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index ba31fa72cfd82..4eeaa39e84236 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -105,6 +105,16 @@ class SparseTensorType {
/// implicit conversion.
RankedTensorType getRankedTensorType() const { return rtp; }
+ bool operator==(const SparseTensorType &other) const {
+ // All other fields are derived from `rtp` and therefore don't need
+ // to be checked.
+ return rtp == other.rtp;
+ }
+
+ bool operator!=(const SparseTensorType &other) const {
+ return !(*this == other);
+ }
+
MLIRContext *getContext() const { return rtp.getContext(); }
Type getElementType() const { return rtp.getElementType(); }
@@ -130,6 +140,8 @@ class SparseTensorType {
bool isIdentity() const { return !dim2lvl; }
/// Returns the dimToLvl mapping (or the null-map for the identity).
+ /// If you intend to compare the results of this method for equality,
+ /// see `hasSameDimToLvlMap` instead.
AffineMap getDimToLvlMap() const { return dim2lvl; }
/// Returns the dimToLvl mapping, where the identity map is expanded out
@@ -142,6 +154,17 @@ class SparseTensorType {
: AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
}
+ /// Returns true iff the two types have the same mapping. This method
+ /// takes care to handle identity maps properly, so it should be preferred
+ /// over using `getDimToLvlMap` followed by `AffineMap::operator==`.
+ bool hasSameDimToLvlMap(const SparseTensorType &other) const {
+ // If the maps are the identity, then we need to check the rank
+ // to be sure they're the same size identity. (And since identity
+ // means dimRank==lvlRank, we use lvlRank as a minor optimization.)
+ return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank)
+ : (dim2lvl == other.dim2lvl);
+ }
+
/// Returns the dimension-rank.
Dimension getDimRank() const { return rtp.getRank(); }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9ec1d7c22c801..7046306ef17ef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -48,12 +48,6 @@ static bool isSparseTensor(OpOperand *op) {
llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed);
}
-static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) {
- assert(rtp1.getRank() == rtp2.getRank());
- return SparseTensorType(rtp1).getDimToLvlMap() ==
- SparseTensorType(rtp2).getDimToLvlMap();
-}
-
// Helper method to find zero/uninitialized allocation.
static bool isAlloc(OpOperand *op, bool isZero) {
Value val = op->get();
@@ -796,8 +790,9 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// 2. the src tensor is not ordered in the same way as the target
// tensor (e.g., src tensor is not ordered or src tensor haves a different
// dimOrdering).
- if (!isUniqueCOOType(srcRTT) && !(SparseTensorType(srcRTT).isAllOrdered() &&
- hasSameDimOrdering(srcRTT, dstTp))) {
+ if (const SparseTensorType srcTp(srcRTT);
+ !isUniqueCOOType(srcRTT) &&
+ !(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) {
// Construct a COO tensor from the src tensor.
// TODO: there may be cases for which more efficiently without
// going through an intermediate COO, such as cases that only change
@@ -841,7 +836,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// Sort the COO tensor so that its elements are ordered via increasing
// indices for the storage ordering of the dst tensor. Use SortCoo if the
// COO tensor has the same dim ordering as the dst tensor.
- if (dimRank > 1 && hasSameDimOrdering(srcTp, dstTp)) {
+ if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
MemRefType indTp =
get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
/*withLayout=*/false);
More information about the Mlir-commits
mailing list