[Mlir-commits] [mlir] [mlir][sparse] move toCOOType into SparseTensorType class (PR #73708)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 28 14:48:35 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
Migrates a dangling convenience method into the proper SparseTensorType class. Also cleans up some details (picking the right dim2lvl/lvl2dim maps) and removes more dead code.
---
Full diff: https://github.com/llvm/llvm-project/pull/73708.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (-4)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+9-20)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+27-2)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp (+2-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 517c286e0206997..28dfdbdcf89b5bf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
/// the level-rank.
Level getCOOStart(SparseTensorEncodingAttr enc);
-/// Helper to setup a COO type.
-RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
- AffineMap ordering, bool ordered);
-
/// Returns true iff MLIR operand has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eb666d76cd2d6f..dc520e390de293d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -64,18 +64,14 @@ class SparseTensorType {
SparseTensorType(const SparseTensorType &) = default;
//
- // Factory methods.
+ // Factory methods to construct a new `SparseTensorType`
+ // with the same dimension-shape and element type.
//
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by the given encoding.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
return SparseTensorType(rtp, newEnc);
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withDimToLvl(dimToLvl)`.
SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
return withEncoding(enc.withDimToLvl(dimToLvl));
}
@@ -88,23 +84,14 @@ class SparseTensorType {
return withDimToLvl(dimToLvlSTT.getEncoding());
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withoutDimToLvl()`.
SparseTensorType withoutDimToLvl() const {
return withEncoding(enc.withoutDimToLvl());
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withBitWidths(posWidth, crdWidth)`.
SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
return withEncoding(enc.withBitWidths(posWidth, crdWidth));
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withoutBitWidths()`.
SparseTensorType withoutBitWidths() const {
return withEncoding(enc.withoutBitWidths());
}
@@ -118,10 +105,6 @@ class SparseTensorType {
return withEncoding(enc.withoutDimSlices());
}
- //
- // Other methods.
- //
-
/// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
/// and `Type`. These are implicit to help alleviate the impedance
/// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
Type getElementType() const { return rtp.getElementType(); }
- /// Returns the encoding (or the null-attribute for dense-tensors).
SparseTensorEncodingAttr getEncoding() const { return enc; }
//
@@ -204,6 +186,10 @@ class SparseTensorType {
/// (This is always true for dense-tensors.)
bool isIdentity() const { return enc.isIdentity(); }
+ //
+ // Other methods.
+ //
+
/// Returns the dimToLvl mapping (or the null-map for the identity).
/// If you intend to compare the results of this method for equality,
/// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
return IndexType::get(getContext());
}
+ /// Returns the [un]ordered COO type with the same shape, element type, maps, and bit widths.
+ RankedTensorType getCOOType(bool ordered) const;
+
private:
// These two must be const, to ensure coherence of the memoized fields.
const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ff2930008fa093f..edf7df3cfedabba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
-// Local convenience methods.
+// Local Convenience Methods.
//===----------------------------------------------------------------------===//
static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
}
//===----------------------------------------------------------------------===//
-// Convenience methods.
+// SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+RankedTensorType
+mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
+ SmallVector<LevelType> lvlTypes;
+ lvlTypes.reserve(lvlRank);
+ // A compressed level at beginning (ordered iff `ordered`, non-unique),
+ // If this is also the last level, then it is unique.
+ lvlTypes.push_back(
+ *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+ if (lvlRank > 1) {
+ // Followed by n-2 non-unique singleton levels (ordered iff `ordered`).
+ std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
+ *buildLevelType(LevelFormat::Singleton, ordered, false));
+ // Ends with a unique singleton level (lvlRank == 1 is handled above).
+ lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+ }
+ auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
+ getDimToLvl(), getLvlToDim(),
+ getPosWidth(), getCrdWidth());
+ return RankedTensorType::get(getDimShape(), getElementType(), enc);
+}
+
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
//===----------------------------------------------------------------------===//
SparseTensorEncodingAttr
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d8769eacc44f39b..c8e77f7de48300e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
Location loc = op.getLoc();
Type finalTp = op->getOpResult(0).getType();
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
-
- Type srcCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+ Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
// Clones the original operation but changing the output to an unordered COO.
Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
Value srcCOO = cloned->getOpResult(0);
// -> sort
- Type dstCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+ Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
Value dstCOO = rewriter.create<ReorderCOOOp>(
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
``````````
</details>
https://github.com/llvm/llvm-project/pull/73708
More information about the Mlir-commits mailing list