[Mlir-commits] [mlir] 867e196 - [mlir][sparse] Adding SparseTensorEncodingAttr::getDimSlice
wren romano
llvmlistbot at llvm.org
Thu May 18 17:06:10 PDT 2023
Author: wren romano
Date: 2023-05-18T17:06:01-07:00
New Revision: 867e19648df234091d0241038d13684d912fb184
URL: https://github.com/llvm/llvm-project/commit/867e19648df234091d0241038d13684d912fb184
DIFF: https://github.com/llvm/llvm-project/commit/867e19648df234091d0241038d13684d912fb184.diff
LOG: [mlir][sparse] Adding SparseTensorEncodingAttr::getDimSlice
This helps catch segfaults and out-of-bounds (OOB) accesses.
Reviewed By: aartbik, Peiming
Differential Revision: https://reviews.llvm.org/D150917
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 57edaf9aa24b..505231045aa8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -342,9 +342,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
- bool isSlice() const {
- return !getDimSlices().empty();
- }
+ bool isSlice() const;
+
+ ::mlir::sparse_tensor::SparseTensorDimSliceAttr getDimSlice(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 93845f17a77d..16f270fc464e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -298,19 +298,32 @@ DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
return getLvlTypes()[l];
}
+bool SparseTensorEncodingAttr::isSlice() const { // True iff this encoding's getDimSlices() list is non-empty.
+ assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); // rejects a null attribute; no-op under NDEBUG
+ return !getDimSlices().empty();
+}
+
+SparseTensorDimSliceAttr // Returns the slice attribute stored for dimension `dim`.
+SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { // Preconditions: isSlice() and dim < getDimSlices().size().
+ assert(isSlice() && "Is not a slice");
+ const auto dimSlices = getDimSlices();
+ assert(dim < dimSlices.size() && "Dimension is out of bounds"); // NOTE(review): both asserts vanish under NDEBUG; release builds rely solely on dimSlices[dim] — confirm whether that is checked
+ return dimSlices[dim];
+}
+
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
- return getDimSlices()[dim].getStaticOffset();
+ return getDimSlice(dim).getStaticOffset(); // routed through getDimSlice() so isSlice() and the bounds of `dim` are asserted
}
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
- return getDimSlices()[dim].getStaticSize();
+ return getDimSlice(dim).getStaticSize(); // routed through getDimSlice() so isSlice() and the bounds of `dim` are asserted
}
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
- return getDimSlices()[dim].getStaticStride();
+ return getDimSlice(dim).getStaticStride(); // routed through getDimSlice() so isSlice() and the bounds of `dim` are asserted
}
std::optional<uint64_t>
More information about the Mlir-commits mailing list