[Mlir-commits] [mlir] 867e196 - [mlir][sparse] Adding SparseTensorEncodingAttr::getDimSlice

wren romano llvmlistbot at llvm.org
Thu May 18 17:06:10 PDT 2023


Author: wren romano
Date: 2023-05-18T17:06:01-07:00
New Revision: 867e19648df234091d0241038d13684d912fb184

URL: https://github.com/llvm/llvm-project/commit/867e19648df234091d0241038d13684d912fb184
DIFF: https://github.com/llvm/llvm-project/commit/867e19648df234091d0241038d13684d912fb184.diff

LOG: [mlir][sparse] Adding SparseTensorEncodingAttr::getDimSlice

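The new accessor asserts that the encoding is a slice and that the requested dimension is in bounds, which helps catch segfaults and out-of-bounds (OOB) accesses.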

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D150917

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 57edaf9aa24b..505231045aa8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -342,9 +342,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
 
-    bool isSlice() const {
-      return !getDimSlices().empty();
-    }
+    bool isSlice() const;
+
+    ::mlir::sparse_tensor::SparseTensorDimSliceAttr getDimSlice(::mlir::sparse_tensor::Dimension dim) const;
 
     std::optional<uint64_t> getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const;
     std::optional<uint64_t> getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const;

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 93845f17a77d..16f270fc464e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -298,19 +298,32 @@ DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
   return getLvlTypes()[l];
 }
 
+bool SparseTensorEncodingAttr::isSlice() const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  return !getDimSlices().empty();
+}
+
+SparseTensorDimSliceAttr
+SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
+  assert(isSlice() && "Is not a slice");
+  const auto dimSlices = getDimSlices();
+  assert(dim < dimSlices.size() && "Dimension is out of bounds");
+  return dimSlices[dim];
+}
+
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
-  return getDimSlices()[dim].getStaticOffset();
+  return getDimSlice(dim).getStaticOffset();
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
-  return getDimSlices()[dim].getStaticSize();
+  return getDimSlice(dim).getStaticSize();
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
-  return getDimSlices()[dim].getStaticStride();
+  return getDimSlice(dim).getStaticStride();
 }
 
 std::optional<uint64_t>


        


More information about the Mlir-commits mailing list