[Mlir-commits] [mlir] ae7942e - [mlir][sparse] adding `SparseTensorType::get{Pointer,Index}Type` methods
wren romano
llvmlistbot at llvm.org
Wed Feb 15 14:38:02 PST 2023
Author: wren romano
Date: 2023-02-15T14:37:55-08:00
New Revision: ae7942e2960e73bd4e568b8b15d1ace35303ae10
URL: https://github.com/llvm/llvm-project/commit/ae7942e2960e73bd4e568b8b15d1ace35303ae10
DIFF: https://github.com/llvm/llvm-project/commit/ae7942e2960e73bd4e568b8b15d1ace35303ae10.diff
LOG: [mlir][sparse] adding `SparseTensorType::get{Pointer,Index}Type` methods
Depends On D143800
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143946
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index ad5501cfa24ca..a52a0dadd42b9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -144,6 +144,10 @@ DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d);
#undef DEPRECATED
+namespace detail {
+Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth);
+} // namespace detail
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eeaa39e84236..c2adc2694e3ef 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -234,6 +234,16 @@ class SparseTensorType {
return enc ? enc.getPointerBitWidth() : 0;
}
+ /// Returns the index-overhead MLIR type, defaulting to `IndexType`.
+ Type getIndexType() const {
+ return detail::getIntegerOrIndexType(getContext(), getIndexBitWidth());
+ }
+
+ /// Returns the pointer-overhead MLIR type, defaulting to `IndexType`.
+ Type getPointerType() const {
+ return detail::getIntegerOrIndexType(getContext(), getPointerBitWidth());
+ }
+
private:
// These two must be const, to ensure coherence of the memoized fields.
const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 32d998f7388c9..8ff474c4505d5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -114,18 +114,19 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
<< "expect positive value or ? for slice offset/size/stride";
}
-static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) {
+Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
+ unsigned bitwidth) {
if (bitwidth)
return IntegerType::get(ctx, bitwidth);
return IndexType::get(ctx);
}
Type SparseTensorEncodingAttr::getPointerType() const {
- return getIntegerOrIndexType(getContext(), getPointerBitWidth());
+ return detail::getIntegerOrIndexType(getContext(), getPointerBitWidth());
}
Type SparseTensorEncodingAttr::getIndexType() const {
- return getIntegerOrIndexType(getContext(), getIndexBitWidth());
+ return detail::getIntegerOrIndexType(getContext(), getIndexBitWidth());
}
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index ceee541f2a7a8..335f743e2db3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -160,7 +160,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// Append linear x pointers, initialized to zero. Since each compressed
// dimension initially already has a single zero entry, this maintains
// the desired "linear + 1" length property at all times.
- Type ptrType = stt.getEncoding().getPointerType();
+ Type ptrType = stt.getPointerType();
Value ptrZero = constantZero(builder, loc, ptrType);
createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l,
ptrZero, linear);
@@ -279,8 +279,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
// to all zeros, sets the dimSizes to known values and gives all pointer
// fields an initial zero entry, so that it is easier to maintain the
// "linear + 1" length property.
- Value ptrZero =
- constantZero(builder, loc, stt.getEncoding().getPointerType());
+ Value ptrZero = constantZero(builder, loc, stt.getPointerType());
for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
// Fills dim sizes array.
// FIXME: this method seems to set *level* sizes, but the name is confusing
@@ -546,7 +545,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
// times?
//
if (l > 0) {
- Type ptrType = stt.getEncoding().getPointerType();
+ Type ptrType = stt.getPointerType();
Value ptrMemRef = desc.getPtrMemRef(l);
Value hi = desc.getPtrMemSize(builder, loc, l);
Value zero = constantIndex(builder, loc, 0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index 8440630c2eefa..be59ba83f0f4b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -179,14 +179,13 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
DimLevelType)>
callback) {
- const auto enc = stt.getEncoding();
- assert(enc);
+ assert(stt.hasEncoding());
// Construct the basic types.
- Type idxType = enc.getIndexType();
- Type ptrType = enc.getPointerType();
+ Type idxType = stt.getIndexType();
+ Type ptrType = stt.getPointerType();
Type eltType = stt.getElementType();
- Type metaDataType = StorageSpecifierType::get(enc);
+ Type metaDataType = StorageSpecifierType::get(stt.getEncoding());
// memref<? x ptr> pointers
Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
// memref<? x idx> indices
@@ -195,7 +194,7 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
foreachFieldInSparseTensor(
- enc,
+ stt.getEncoding(),
[metaDataType, ptrMemType, idxMemType, valMemType,
callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
Level lvl, DimLevelType dlt) -> bool {
More information about the Mlir-commits mailing list