[Mlir-commits] [mlir] 3da83fb - [mlir][sparse] Improving error detection/messages for `get{RankedTensor,MemRef}Type` wrappers

wren romano llvmlistbot at llvm.org
Wed May 3 18:47:52 PDT 2023


Author: wren romano
Date: 2023-05-03T18:47:44-07:00
New Revision: 3da83fbafef1689de1fc45c2c3fa3d258edda09d

URL: https://github.com/llvm/llvm-project/commit/3da83fbafef1689de1fc45c2c3fa3d258edda09d
DIFF: https://github.com/llvm/llvm-project/commit/3da83fbafef1689de1fc45c2c3fa3d258edda09d.diff

LOG: [mlir][sparse] Improving error detection/messages for `get{RankedTensor,MemRef}Type` wrappers

This helps catch some otherwise hard-to-track-down segfaults. N.B., even though `getSparseTensorType` is not touched in this patch, it also gains the new error checking (via `getRankedTensorType`).

Depends On D149805

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149806
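
For readers without an MLIR checkout, the idea behind the wrappers can be sketched in isolation. The snippet below is only an illustration under stated assumptions: `TypeInfo`, `Handle`, and `getTypeName` are hypothetical stand-ins, not MLIR APIs. `Handle` mimics a value-like handle whose `getType()` dereferences its pointer without a null check, and the wrapper asserts non-null before calling it.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical payload; stands in for whatever a handle's type describes.
struct TypeInfo {
  std::string name;
};

// Hypothetical handle: a thin wrapper around a possibly-null pointer.
// This is NOT mlir::Value, only a stand-in with a similar shape.
class Handle {
public:
  Handle() = default;
  explicit Handle(const TypeInfo *impl) : impl(impl) {}

  // True when the handle actually refers to something.
  explicit operator bool() const { return impl != nullptr; }

  // Dereferences without checking for null, so calling this on a
  // default-constructed Handle is undefined behavior.
  const TypeInfo &getType() const { return *impl; }

private:
  const TypeInfo *impl = nullptr;
};

// Wrapper in the spirit of the patch: assert that the argument is
// non-null first, so a debug build reports a readable message instead
// of crashing inside getType().
template <typename T>
inline std::string getTypeName(T &&t) {
  // The check reads `t` as an lvalue; only the final use is forwarded.
  assert(static_cast<bool>(t) && "getTypeName got null argument");
  return std::forward<T>(t).getType().name;
}
```

With assertions enabled, `getTypeName(Handle())` would stop at the assertion with the message shown, whereas a non-null handle returns the name as usual. Because `assert` compiles away under `NDEBUG`, this kind of check only helps in builds that keep assertions on. The sketch forwards the argument once, at its final use, which is a stylistic choice of this example and differs slightly from the patch.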

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 6aa4f341bba96..481c2e629198c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -85,16 +85,25 @@ using StaticSize = int64_t;
 namespace mlir {
 namespace sparse_tensor {
 
+// NOTE: `Value::getType` doesn't check for null before trying to
+// dereference things.  Therefore we check, because an assertion-failure
+// is easier to debug than a segfault.  Presumably other `T::getType`
+// methods are similarly susceptible.
+
 /// Convenience method to abbreviate casting `getType()`.
 template <typename T>
-inline RankedTensorType getRankedTensorType(T t) {
-  return t.getType().template cast<RankedTensorType>();
+inline RankedTensorType getRankedTensorType(T &&t) {
+  assert(static_cast<bool>(std::forward<T>(t)) &&
+         "getRankedTensorType got null argument");
+  return std::forward<T>(t).getType().template cast<RankedTensorType>();
 }
 
 /// Convenience method to abbreviate casting `getType()`.
 template <typename T>
-inline MemRefType getMemRefType(T t) {
-  return t.getType().template cast<MemRefType>();
+inline MemRefType getMemRefType(T &&t) {
+  assert(static_cast<bool>(std::forward<T>(t)) &&
+         "getMemRefType got null argument");
+  return std::forward<T>(t).getType().template cast<MemRefType>();
 }
 
 /// Convenience method to get a sparse encoding attribute from a type.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index bdcd3632b5d38..4c4f1f25edfd5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -51,6 +51,7 @@ class SparseTensorType {
       : rtp(rtp), enc(getSparseTensorEncoding(rtp)),
         lvlRank(enc ? enc.getLvlRank() : getDimRank()),
         dim2lvl(enc.hasIdDimOrdering() ? AffineMap() : enc.getDimOrdering()) {
+    assert(rtp && "got null RankedTensorType");
     assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
   }
 


        

