[Mlir-commits] [mlir] [mlir][sparse] move all COO related methods into SparseTensorType (PR #73881)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 29 16:47:19 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
This centralizes all COO-related methods in SparseTensorType and provides a cleaner API. Note that the "enc"-only constructor is a temporary workaround for the need for COO methods inside the "enc"-only storage specifier.
---
Full diff: https://github.com/llvm/llvm-project/pull/73881.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (-13)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+18-2)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+35-44)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+3-4)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+4-5)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 28dfdbdcf89b5bf..5e523ec428aefb9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,19 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
-bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
-
-/// Returns true iff the given type is a COO type where the last level
-/// is unique.
-bool isUniqueCOOType(Type tp);
-
-/// Returns the starting level for a trailing COO region that spans
-/// at least two levels. If no such COO region is found, then returns
-/// the level-rank.
-Level getCOOStart(SparseTensorEncodingAttr enc);
-
/// Returns true iff MLIR operand has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index dc520e390de293d..4c98129744bcd94 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -60,6 +60,12 @@ class SparseTensorType {
: SparseTensorType(
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
+ // TODO: remove?
+ SparseTensorType(SparseTensorEncodingAttr enc)
+ : SparseTensorType(RankedTensorType::get(
+ SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
+ Float32Type::get(enc.getContext()), enc)) {}
+
SparseTensorType &operator=(const SparseTensorType &) = delete;
SparseTensorType(const SparseTensorType &) = default;
@@ -234,9 +240,9 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}
+ /// Returns the type with an identity mapping.
RankedTensorType getDemappedType() const {
- auto lvlShape = getLvlShape();
- return RankedTensorType::get(lvlShape, rtp.getElementType(),
+ return RankedTensorType::get(getLvlShape(), getElementType(),
enc.withoutDimToLvl());
}
@@ -311,6 +317,16 @@ class SparseTensorType {
return IndexType::get(getContext());
}
+ /// Returns true iff this sparse tensor type has a trailing
+ /// COO region starting at the given level. By default, it
+ /// tests for a unique COO type at top level.
+ bool isCOOType(Level startLvl = 0, bool isUnique = true) const;
+
+ /// Returns the starting level of this sparse tensor type for a
+ /// trailing COO region that spans **at least** two levels. If
+ /// no such COO region is found, then returns the level-rank.
+ Level getCOOStart() const;
+
/// Returns [un]ordered COO type for this sparse tensor type.
RankedTensorType getCOOType(bool ordered) const;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d4f8afdd62f2383..7dc4fc4f8570d60 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -66,7 +66,7 @@ void StorageLayout::foreachField(
callback) const {
const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
- const Level cooStart = getCOOStart(enc);
+ const Level cooStart = SparseTensorType(enc).getCOOStart();
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
FieldIndex fieldIdx = kDataFieldStartingIdx;
// Per-level storage.
@@ -158,7 +158,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
unsigned stride = 1;
if (kind == SparseTensorFieldKind::CrdMemRef) {
assert(lvl.has_value());
- const Level cooStart = getCOOStart(enc);
+ const Level cooStart = SparseTensorType(enc).getCOOStart();
const Level lvlRank = enc.getLvlRank();
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
lvl = cooStart;
@@ -710,6 +710,28 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//
+bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
+ if (!hasEncoding())
+ return false;
+ if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
+ return false;
+ for (Level l = startLvl + 1; l < lvlRank; ++l)
+ if (!isSingletonLvl(l))
+ return false;
+ // If isUnique is true, then make sure that the last level is unique,
+ // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
+ // (unique on the last singleton).
+ return !isUnique || isUniqueLvl(lvlRank - 1);
+}
+
+Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
+ if (lvlRank > 1)
+ for (Level l = 0; l < lvlRank - 1; l++)
+ if (isCOOType(l, /*isUnique=*/false))
+ return l;
+ return lvlRank;
+}
+
RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
SmallVector<LevelType> lvlTypes;
@@ -859,25 +881,6 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
return !coeffientMap.empty();
}
-bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
- Level startLvl, bool isUnique) {
- if (!enc ||
- !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
- return false;
- const Level lvlRank = enc.getLvlRank();
- for (Level l = startLvl + 1; l < lvlRank; ++l)
- if (!enc.isSingletonLvl(l))
- return false;
- // If isUnique is true, then make sure that the last level is unique,
- // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
- // (unique on the last singleton).
- return !isUnique || enc.isUniqueLvl(lvlRank - 1);
-}
-
-bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
- return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
-}
-
bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
auto hasNonIdentityMap = [](Value v) {
auto stt = tryGetSparseTensorType(v);
@@ -888,17 +891,6 @@ bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
llvm::any_of(op->getResults(), hasNonIdentityMap);
}
-Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
- // We only consider COO region with at least two levels for the purpose
- // of AOS storage optimization.
- const Level lvlRank = enc.getLvlRank();
- if (lvlRank > 1)
- for (Level l = 0; l < lvlRank - 1; l++)
- if (isCOOType(enc, l, /*isUnique=*/false))
- return l;
- return lvlRank;
-}
-
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
if (enc) {
assert(enc.isPermutation() && "Non permutation map not supported");
@@ -1013,7 +1005,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
return op->emitError("the sparse-tensor must have the identity mapping");
// Verifies the trailing COO.
- Level cooStartLvl = getCOOStart(stt.getEncoding());
+ Level cooStartLvl = stt.getCOOStart();
if (cooStartLvl < stt.getLvlRank()) {
// We only supports trailing COO for now, must be the last input.
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
@@ -1309,34 +1301,34 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
}
LogicalResult ToPositionsOp::verify() {
- auto e = getSparseTensorEncoding(getTensor().getType());
+ auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
return emitError("requested level is out of bounds");
- if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
+ if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
return emitError("unexpected type for positions");
return success();
}
LogicalResult ToCoordinatesOp::verify() {
- auto e = getSparseTensorEncoding(getTensor().getType());
+ auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
return emitError("requested level is out of bounds");
- if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
+ if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
return emitError("unexpected type for coordinates");
return success();
}
LogicalResult ToCoordinatesBufferOp::verify() {
- auto e = getSparseTensorEncoding(getTensor().getType());
- if (getCOOStart(e) >= e.getLvlRank())
+ auto stt = getSparseTensorType(getTensor());
+ if (stt.getCOOStart() >= stt.getLvlRank())
return emitError("expected sparse tensor with a COO region");
return success();
}
LogicalResult ToValuesOp::verify() {
- auto ttp = getRankedTensorType(getTensor());
+ auto stt = getSparseTensorType(getTensor());
auto mtp = getMemRefType(getResult());
- if (ttp.getElementType() != mtp.getElementType())
+ if (stt.getElementType() != mtp.getElementType())
return emitError("unexpected mismatch in element types");
return success();
}
@@ -1660,9 +1652,8 @@ LogicalResult ReorderCOOOp::verify() {
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
SparseTensorType dstStt = getSparseTensorType(getResultCoo());
- if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
- !isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
- emitError("Unexpected non-COO sparse tensors");
+ if (!srcStt.isCOOType() || !dstStt.isCOOType())
+ emitError("Expected COO sparse tensors only");
if (!srcStt.hasSameDimToLvl(dstStt))
emitError("Unmatched dim2lvl map between input and result COO");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index a245344755f0404..26f015ce6ec64f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -412,8 +412,7 @@ void LoopEmitter::initializeLoopEmit(
auto stt = getSparseTensorType(tensor);
const Level lvlRank = stt.getLvlRank();
const auto shape = rtp.getShape();
- const auto enc = getSparseTensorEncoding(rtp);
- const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
+ const Level cooStart = stt.getCOOStart();
SmallVector<Value> lvlSzs;
for (Level l = 0; l < stt.getLvlRank(); l++) {
@@ -457,8 +456,8 @@ void LoopEmitter::initializeLoopEmit(
// values.
// Delegates extra output initialization to clients.
bool isOutput = isOutputTensor(t);
- Type elementType = rtp.getElementType();
- if (!enc) {
+ Type elementType = stt.getElementType();
+ if (!stt.hasEncoding()) {
// Non-annotated dense tensors.
BaseMemRefType denseTp = MemRefType::get(shape, elementType);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e9062b49435f5b7..18b2bb0819e2642 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
valHeuristic =
builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
} else if (sizeHint) {
- if (getCOOStart(stt.getEncoding()) == 0) {
+ if (stt.getCOOStart() == 0) {
posHeuristic = constantIndex(builder, loc, 2);
crdHeuristic = builder.create<arith::MulIOp>(
loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
@@ -657,8 +657,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
// Should have been verified.
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
- isUniqueCOOType(srcStt.getRankedTensorType()) &&
- isUniqueCOOType(dstStt.getRankedTensorType()));
+ dstStt.isCOOType() && srcStt.isCOOType());
assert(dstStt.hasSameDimToLvl(srcStt));
// We don't need a mutable descriptor here as we perform sorting in-place.
@@ -1317,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
Value posBack = c0; // index to the last value in the position array
Value memSize = c1; // memory size for current array
- Level trailCOOStart = getCOOStart(stt.getEncoding());
+ Level trailCOOStart = stt.getCOOStart();
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
// Sets up SparseTensorSpecifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
@@ -1454,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
const auto dstTp = getSparseTensorType(op.getResult());
// Creating COO with NewOp is handled by direct IR codegen. All other cases
// are handled by rewriting.
- if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
+ if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
return failure();
// Implement as follows:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
index 1c6d7bebe37e46c..3ab4157475cd4c2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
@@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
OpBuilder &builder, Location loc, Level lvl) const {
- const Level cooStart = getCOOStart(rType.getEncoding());
+ const Level cooStart = rType.getCOOStart();
if (lvl < cooStart)
return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index 4bd700eef522e04..5c7d8aa4c9d9678 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
}
Value getAOSMemRef() const {
- const Level cooStart = getCOOStart(rType.getEncoding());
+ const Level cooStart = rType.getCOOStart();
assert(cooStart < rType.getLvlRank());
return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2bd129b85ea5416..4fc692f2fe9ddc2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,8 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
- auto enc = stt.getEncoding();
- if (!stt.hasEncoding() || getCOOStart(enc) == 0)
+ if (!stt.hasEncoding() || stt.getCOOStart() == 0)
return failure();
// Implement the NewOp as follows:
@@ -1192,6 +1191,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = cooTensor;
+ auto enc = stt.getEncoding();
if (!stt.isPermutation()) { // demap coo, demap dstTp
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
``````````
</details>
https://github.com/llvm/llvm-project/pull/73881
More information about the Mlir-commits
mailing list