[Mlir-commits] [mlir] 3da83fb - [mlir][sparse] Improving error detection/messages for `get{RankedTensor,MemRef}Type` wrappers
wren romano
llvmlistbot at llvm.org
Wed May 3 18:47:52 PDT 2023
Author: wren romano
Date: 2023-05-03T18:47:44-07:00
New Revision: 3da83fbafef1689de1fc45c2c3fa3d258edda09d
URL: https://github.com/llvm/llvm-project/commit/3da83fbafef1689de1fc45c2c3fa3d258edda09d
DIFF: https://github.com/llvm/llvm-project/commit/3da83fbafef1689de1fc45c2c3fa3d258edda09d.diff
LOG: [mlir][sparse] Improving error detection/messages for `get{RankedTensor,MemRef}Type` wrappers
This helps catch some otherwise hard-to-track-down segfaults. N.B., even though `getSparseTensorType` is not touched in this patch, it also gains the new error checking (via `getRankedTensorType`).
Depends On D149805
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D149806
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 6aa4f341bba96..481c2e629198c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -85,16 +85,25 @@ using StaticSize = int64_t;
namespace mlir {
namespace sparse_tensor {
+// NOTE: `Value::getType` doesn't check for null before trying to
+// dereference things. Therefore we check, because an assertion-failure
+// is easier to debug than a segfault. Presumably other `T::getType`
+// methods are similarly susceptible.
+
/// Convenience method to abbreviate casting `getType()`.
template <typename T>
-inline RankedTensorType getRankedTensorType(T t) {
-  return t.getType().template cast<RankedTensorType>();
+inline RankedTensorType getRankedTensorType(T &&t) {
+  assert(static_cast<bool>(std::forward<T>(t)) &&
+         "getRankedTensorType got null argument");
+  return std::forward<T>(t).getType().template cast<RankedTensorType>();
}
/// Convenience method to abbreviate casting `getType()`.
template <typename T>
-inline MemRefType getMemRefType(T t) {
-  return t.getType().template cast<MemRefType>();
+inline MemRefType getMemRefType(T &&t) {
+  assert(static_cast<bool>(std::forward<T>(t)) &&
+         "getMemRefType got null argument");
+  return std::forward<T>(t).getType().template cast<MemRefType>();
}
/// Convenience method to get a sparse encoding attribute from a type.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index bdcd3632b5d38..4c4f1f25edfd5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -51,6 +51,7 @@ class SparseTensorType {
: rtp(rtp), enc(getSparseTensorEncoding(rtp)),
lvlRank(enc ? enc.getLvlRank() : getDimRank()),
dim2lvl(enc.hasIdDimOrdering() ? AffineMap() : enc.getDimOrdering()) {
+ assert(rtp && "got null RankedTensorType");
assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
}
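
For readers without an MLIR checkout at hand, the pattern in the first hunk can be sketched in isolation. This is only an illustration under stated assumptions: `FakeType`, `FakeValue`, and `getCheckedType` are hypothetical stand-ins invented here, not MLIR classes or functions. `FakeValue` mimics a nullable handle whose `getType()` dereferences without checking, and the wrapper asserts non-null first so that a bad argument produces a named assertion failure instead of a segfault.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for a type object; not the real mlir::Type.
struct FakeType {
  std::string name;
};

// Hypothetical stand-in for a nullable handle such as a value or op result.
// A default-constructed handle is "null" and converts to false.
class FakeValue {
public:
  FakeValue() = default;
  explicit FakeValue(const FakeType *impl) : impl(impl) {}

  // True only when the handle points at something.
  explicit operator bool() const { return impl != nullptr; }

  // Dereferences without a null check, like the accessor the NOTE in the
  // patch warns about.
  const FakeType &getType() const { return *impl; }

private:
  const FakeType *impl = nullptr;
};

// Checked wrapper in the style of the patch: take a forwarding reference,
// assert the handle is non-null, then forward it on to getType().
template <typename T>
const FakeType &getCheckedType(T &&t) {
  assert(static_cast<bool>(t) && "getCheckedType got null argument");
  return std::forward<T>(t).getType();
}
```

Calling `getCheckedType` on a default-constructed `FakeValue` in a build with assertions enabled should stop at the assertion with the message above, whereas the unchecked `getType()` would dereference a null pointer. In builds that define `NDEBUG` the check compiles away, so it is a debugging aid rather than a runtime guard.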