[Mlir-commits] [mlir] [mlir][sparse] move toCOOType into SparseTensorType class (PR #73708)
Aart Bik
llvmlistbot at llvm.org
Tue Nov 28 15:06:32 PST 2023
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/73708
>From 472964073950e87c808460f1cc56954a3c6a648c Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:45:30 -0800
Subject: [PATCH 1/4] [mlir][sparse] move toCOOType into SparseTensorType class
Migrates a dangling convenience method into the proper SparseTensorType
class. Also cleans up some details (picking the right dim2lvl/lvl2dim
maps). Removes more dead code.
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 4 ---
.../SparseTensor/IR/SparseTensorType.h | 29 ++++++-------------
.../SparseTensor/IR/SparseTensorDialect.cpp | 29 +++++++++++++++++--
.../IR/SparseTensorInterfaces.cpp | 7 ++---
4 files changed, 38 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 517c286e0206997..28dfdbdcf89b5bf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
/// the level-rank.
Level getCOOStart(SparseTensorEncodingAttr enc);
-/// Helper to setup a COO type.
-RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
- AffineMap ordering, bool ordered);
-
/// Returns true iff MLIR operand has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eb666d76cd2d6f..dc520e390de293d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -64,18 +64,14 @@ class SparseTensorType {
SparseTensorType(const SparseTensorType &) = default;
//
- // Factory methods.
+ // Factory methods to construct a new `SparseTensorType`
+ // with the same dimension-shape and element type.
//
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by the given encoding.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
return SparseTensorType(rtp, newEnc);
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withDimToLvl(dimToLvl)`.
SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
return withEncoding(enc.withDimToLvl(dimToLvl));
}
@@ -88,23 +84,14 @@ class SparseTensorType {
return withDimToLvl(dimToLvlSTT.getEncoding());
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withoutDimToLvl()`.
SparseTensorType withoutDimToLvl() const {
return withEncoding(enc.withoutDimToLvl());
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withBitWidths(posWidth, crdWidth)`.
SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
return withEncoding(enc.withBitWidths(posWidth, crdWidth));
}
- /// Constructs a new `SparseTensorType` with the same dimension-shape
- /// and element type, but with the encoding replaced by
- /// `getEncoding().withoutBitWidths()`.
SparseTensorType withoutBitWidths() const {
return withEncoding(enc.withoutBitWidths());
}
@@ -118,10 +105,6 @@ class SparseTensorType {
return withEncoding(enc.withoutDimSlices());
}
- //
- // Other methods.
- //
-
/// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
/// and `Type`. These are implicit to help alleviate the impedance
/// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
Type getElementType() const { return rtp.getElementType(); }
- /// Returns the encoding (or the null-attribute for dense-tensors).
SparseTensorEncodingAttr getEncoding() const { return enc; }
//
@@ -204,6 +186,10 @@ class SparseTensorType {
/// (This is always true for dense-tensors.)
bool isIdentity() const { return enc.isIdentity(); }
+ //
+ // Other methods.
+ //
+
/// Returns the dimToLvl mapping (or the null-map for the identity).
/// If you intend to compare the results of this method for equality,
/// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
return IndexType::get(getContext());
}
+ /// Returns [un]ordered COO type for this sparse tensor type.
+ RankedTensorType getCOOType(bool ordered) const;
+
private:
// These two must be const, to ensure coherence of the memoized fields.
const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ff2930008fa093f..edf7df3cfedabba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
-// Local convenience methods.
+// Local Convenience Methods.
//===----------------------------------------------------------------------===//
static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
}
//===----------------------------------------------------------------------===//
-// Convenience methods.
+// SparseTensorType SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+RankedTensorType
+mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
+ SmallVector<LevelType> lvlTypes;
+ lvlTypes.reserve(lvlRank);
+ // First level: compressed, non-unique, [un]ordered per `ordered`.
+ // If this is also the last level (lvlRank == 1), then it is unique.
+ lvlTypes.push_back(
+ *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+ if (lvlRank > 1) {
+ // Followed by lvlRank-2 non-unique singleton levels ([un]ordered).
+ std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
+ *buildLevelType(LevelFormat::Singleton, ordered, false));
+ // Ends with a unique singleton level (only reached when lvlRank > 1).
+ lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+ }
+ auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
+ getDimToLvl(), getLvlToDim(),
+ getPosWidth(), getCrdWidth());
+ return RankedTensorType::get(getDimShape(), getElementType(), enc);
+}
+
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
//===----------------------------------------------------------------------===//
SparseTensorEncodingAttr
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d8769eacc44f39b..c8e77f7de48300e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
Location loc = op.getLoc();
Type finalTp = op->getOpResult(0).getType();
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
-
- Type srcCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+ Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
// Clones the original operation but changing the output to an unordered COO.
Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
Value srcCOO = cloned->getOpResult(0);
// -> sort
- Type dstCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+ Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
Value dstCOO = rewriter.create<ReorderCOOOp>(
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
>From d126b4dee618bae40b4df0b9bf78ec8c26433473 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 14:52:52 -0800
Subject: [PATCH 2/4] rewriting file too (use getCOOType there as well)
---
.../SparseTensor/Transforms/SparseTensorRewriting.cpp | 10 ++--------
1 file changed, 2 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 702666e9d40c31f..2bd129b85ea5416 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -132,15 +132,9 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
}
}
-// TODO: The dim level property of the COO type relies on input tensors, the
-// shape relies on the output tensor
-static RankedTensorType getCOOType(const SparseTensorType &stt, bool ordered) {
- return getCOOFromTypeWithOrdering(stt, stt.getDimToLvl(), ordered);
-}
-
static RankedTensorType getBufferType(const SparseTensorType &stt,
bool needTmpCOO) {
- return needTmpCOO ? getCOOType(stt, /*ordered=*/false)
+ return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
: stt.getRankedTensorType();
}
@@ -1195,7 +1189,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
// %t = sparse_tensor.convert %orderedCoo
// with enveloping reinterpreted_map ops for non-permutations.
RankedTensorType dstTp = stt.getRankedTensorType();
- RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
+ RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = cooTensor;
if (!stt.isPermutation()) { // demap coo, demap dstTp
>From 086bf81a004dff468bb88f50e3eba9ddeb655615 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 15:03:22 -0800
Subject: [PATCH 3/4] DCE (remove unused getCOOFromTypeWithOrdering)
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 33 -------------------
1 file changed, 33 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index edf7df3cfedabba..50c62bbba9d2437 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -903,39 +903,6 @@ Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
return lvlRank;
}
-// Helper to setup a COO type.
-RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
- AffineMap lvlPerm,
- bool ordered) {
- const SparseTensorType src(rtt);
- const Level lvlRank = src.getLvlRank();
- SmallVector<LevelType> lvlTypes;
- lvlTypes.reserve(lvlRank);
-
- // An unordered and non-unique compressed level at beginning.
- // If this is also the last level, then it is unique.
- lvlTypes.push_back(
- *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
- if (lvlRank > 1) {
- // TODO: it is actually ordered at the level for ordered input.
- // Followed by unordered non-unique n-2 singleton levels.
- std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
- *buildLevelType(LevelFormat::Singleton, ordered, false));
- // Ends by a unique singleton level unless the lvlRank is 1.
- lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
- }
-
- // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
- // largest one among them) in the original operation instead of using the
- // default value.
- unsigned posWidth = src.getPosWidth();
- unsigned crdWidth = src.getCrdWidth();
- AffineMap invPerm = src.getLvlToDim();
- auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
- invPerm, posWidth, crdWidth);
- return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
-}
-
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
if (enc) {
assert(enc.isPermutation() && "Non permutation map not supported");
>From 6b2bb9b550c4b63a5b25c2e770ad0deae3bf32a9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 28 Nov 2023 15:06:02 -0800
Subject: [PATCH 4/4] feedback (fix duplicated section header)
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 50c62bbba9d2437..20a091a81a26c25 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -711,7 +711,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
}
//===----------------------------------------------------------------------===//
-// SparseTensorType SparseTensorType Methods.
+// SparseTensorType Methods.
//===----------------------------------------------------------------------===//
RankedTensorType
More information about the Mlir-commits
mailing list