[Mlir-commits] [mlir] 8756292 - [mlir] Fix integer overflow in ShapedType::getNumElements and `makeCanonicalStridedLayoutExpr` (#178395)
llvmlistbot at llvm.org
Thu Jan 29 02:13:50 PST 2026
Author: puneeth_aditya_5656
Date: 2026-01-29T11:13:45+01:00
New Revision: 8756292abb0072800190af329733f3d9217034fb
URL: https://github.com/llvm/llvm-project/commit/8756292abb0072800190af329733f3d9217034fb
DIFF: https://github.com/llvm/llvm-project/commit/8756292abb0072800190af329733f3d9217034fb.diff
LOG: [mlir] Fix integer overflow in ShapedType::getNumElements and `makeCanonicalStridedLayoutExpr` (#178395)
Add a new `tryGetNumElements()` API to `ShapedTypeInterface` that returns
`std::optional<int64_t>`: it detects overflow with `llvm::checkedMul` and
returns `std::nullopt` instead of invoking undefined behavior.
`getNumElements()` now uses this new API to assert on overflow in builds with
assertions enabled.
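Also fix `makeCanonicalStridedLayoutExpr` so that it no longer crashes when the
running size overflows: the multiplication now uses `llvm::checkedMul`, and an
overflow is treated like a dynamic size.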
Fixes #178362
Fixes #177816
---------
Co-authored-by: Claude Opus 4.5 <noreply at anthropic.com>
Added:
mlir/test/Dialect/MemRef/high-rank-overflow.mlir
Modified:
mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/lib/IR/BuiltinTypeInterfaces.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/unittests/IR/ShapedTypeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c879e5efd77fe..9ef08b7020b99 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -237,6 +237,10 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
+ /// Return the number of elements present in the given shape, or
+ /// std::nullopt if the computation would overflow.
+ static std::optional<int64_t> tryGetNumElements(ArrayRef<int64_t> shape);
+
/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
@@ -276,6 +280,13 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
return ::mlir::ShapedType::getNumElements($_type.getShape());
}
+ /// Return the number of elements, or std::nullopt if the computation would overflow.
+ /// Precondition: `hasStaticShape()`, otherwise abort.
+ std::optional<int64_t> tryGetNumElements() const {
+ assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
+ return ::mlir::ShapedType::tryGetNumElements($_type.getShape());
+ }
+
/// Returns true if this dimension has a dynamic size (for ranked types);
/// aborts for unranked types.
bool isDynamicDim(unsigned idx) const {
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 031752bffeab8..2f063be3e7cd0 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/CheckedArithmetic.h"
using namespace mlir;
using namespace mlir::detail;
@@ -34,11 +35,26 @@ unsigned FloatType::getFPMantissaWidth() {
// ShapedType
//===----------------------------------------------------------------------===//
-int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+std::optional<int64_t> ShapedType::tryGetNumElements(ArrayRef<int64_t> shape) {
int64_t num = 1;
for (int64_t dim : shape) {
- num *= dim;
- assert(num >= 0 && "integer overflow in element count computation");
+ auto result = llvm::checkedMul(num, dim);
+ if (!result)
+ return std::nullopt;
+ num = *result;
}
return num;
}
+
+int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+#ifndef NDEBUG
+ std::optional<int64_t> num = tryGetNumElements(shape);
+ assert(num.has_value() && "integer overflow in element count computation");
+ return *num;
+#else
+ int64_t num = 1;
+ for (int64_t dim : shape)
+ num *= dim;
+ return num;
+#endif
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index ce47c60c9b932..1e198043c590a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CheckedArithmetic.h"
using namespace mlir;
using namespace mlir::detail;
@@ -875,8 +876,13 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
: getAffineConstantExpr(runningSize, context);
expr = expr ? expr + dimExpr * stride : dimExpr * stride;
if (size > 0) {
- runningSize *= size;
- assert(runningSize > 0 && "integer overflow in size computation");
+ auto result = llvm::checkedMul(runningSize, size);
+ if (!result) {
+ // Overflow occurred, treat as dynamic
+ dynamicPoisonBit = true;
+ } else {
+ runningSize = *result;
+ }
} else {
dynamicPoisonBit = true;
}
diff --git a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
new file mode 100644
index 0000000000000..2a6ec113c7261
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --convert-to-llvm --split-input-file --verify-diagnostics | FileCheck %s
+
+// Test that extremely high-rank memrefs with overflow in stride calculation
+// are handled gracefully instead of crashing (issue #177816).
+
+// CHECK-LABEL: func @high_rank_memref_overflow
+func.func @high_rank_memref_overflow() {
+ // This creates a memref with 64 dimensions of size 2, resulting in 2^64 elements
+ // which overflows int64_t. The stride calculation should handle this gracefully.
+ %0 = memref.alloc() : memref<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @high_rank_memref_max_dim
+func.func @high_rank_memref_max_dim() {
+ // Test with fewer dimensions but larger sizes that also cause overflow
+ %0 = memref.alloc() : memref<9223372036854775807x2xi32>
+ return
+}
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index bc4066ed210e8..a013193b8dd9e 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>
+#include <limits>
using namespace mlir;
using namespace mlir::detail;
@@ -298,4 +299,26 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
}
+TEST(ShapedTypeTest, GetNumElements) {
+ // Test normal case.
+ EXPECT_EQ(ShapedType::getNumElements({2, 3, 4}), 24);
+ EXPECT_EQ(ShapedType::getNumElements({1}), 1);
+ EXPECT_EQ(ShapedType::getNumElements({}), 1);
+
+ // Test tryGetNumElements returns value for normal shapes.
+ EXPECT_TRUE(ShapedType::tryGetNumElements({2, 3, 4}).has_value());
+ EXPECT_EQ(*ShapedType::tryGetNumElements({2, 3, 4}), 24);
+
+ // Test tryGetNumElements returns nullopt for overflow.
+ // INT64_MAX = 9223372036854775807, so multiplying large dimensions overflows.
+ SmallVector<int64_t> overflowShape;
+ for (int i = 0; i < 64; ++i)
+ overflowShape.push_back(2); // 2^64 would overflow int64_t
+ EXPECT_FALSE(ShapedType::tryGetNumElements(overflowShape).has_value());
+
+ // Another overflow case with fewer but larger dimensions.
+ int64_t maxVal = std::numeric_limits<int64_t>::max();
+ EXPECT_FALSE(ShapedType::tryGetNumElements({maxVal, 2}).has_value());
+}
+
} // namespace
More information about the Mlir-commits mailing list