[Mlir-commits] [mlir] 8756292 - [mlir] Fix integer overflow in ShapedType::getNumElements and `makeCanonicalStridedLayoutExpr` (#178395)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 29 02:13:50 PST 2026


Author: puneeth_aditya_5656
Date: 2026-01-29T11:13:45+01:00
New Revision: 8756292abb0072800190af329733f3d9217034fb

URL: https://github.com/llvm/llvm-project/commit/8756292abb0072800190af329733f3d9217034fb
DIFF: https://github.com/llvm/llvm-project/commit/8756292abb0072800190af329733f3d9217034fb.diff

LOG: [mlir] Fix integer overflow in ShapedType::getNumElements and `makeCanonicalStridedLayoutExpr` (#178395)

Add a new `tryGetNumElements()` API to `ShapedTypeInterface` that returns
`std::optional<int64_t>`: it returns `std::nullopt` on overflow instead of
invoking undefined behavior, using `llvm::checkedMul` for overflow detection.
`getNumElements()` now uses this new API to assert on overflow in builds with
assertions enabled (release builds keep the unchecked multiplication).

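Also fix `AffineExpr` canonicalization in `makeCanonicalStridedLayoutExpr` so
that it no longer crashes on overflow: the size product is computed with
`llvm::checkedMul`, and an overflowing stride is treated as dynamic.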

Fixes #178362
Fixes #177816

---------

Co-authored-by: Claude Opus 4.5 <noreply at anthropic.com>

Added: 
    mlir/test/Dialect/MemRef/high-rank-overflow.mlir

Modified: 
    mlir/include/mlir/IR/BuiltinTypeInterfaces.td
    mlir/lib/IR/BuiltinTypeInterfaces.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/unittests/IR/ShapedTypeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c879e5efd77fe..9ef08b7020b99 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -237,6 +237,10 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
 
+    /// Return the number of elements present in the given shape, or
+    /// std::nullopt if the computation would overflow.
+    static std::optional<int64_t> tryGetNumElements(ArrayRef<int64_t> shape);
+
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
@@ -276,6 +280,13 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
       return ::mlir::ShapedType::getNumElements($_type.getShape());
     }
 
+    /// Return the number of elements, or std::nullopt if the computation would overflow.
+    /// Precondition: `hasStaticShape()`, otherwise abort.
+    std::optional<int64_t> tryGetNumElements() const {
+      assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
+      return ::mlir::ShapedType::tryGetNumElements($_type.getShape());
+    }
+
     /// Returns true if this dimension has a dynamic size (for ranked types);
     /// aborts for unranked types.
     bool isDynamicDim(unsigned idx) const {

diff  --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 031752bffeab8..2f063be3e7cd0 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/Support/CheckedArithmetic.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -34,11 +35,26 @@ unsigned FloatType::getFPMantissaWidth() {
 // ShapedType
 //===----------------------------------------------------------------------===//
 
-int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+std::optional<int64_t> ShapedType::tryGetNumElements(ArrayRef<int64_t> shape) {
   int64_t num = 1;
   for (int64_t dim : shape) {
-    num *= dim;
-    assert(num >= 0 && "integer overflow in element count computation");
+    auto result = llvm::checkedMul(num, dim);
+    if (!result)
+      return std::nullopt;
+    num = *result;
   }
   return num;
 }
+
+int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+#ifndef NDEBUG
+  std::optional<int64_t> num = tryGetNumElements(shape);
+  assert(num.has_value() && "integer overflow in element count computation");
+  return *num;
+#else
+  int64_t num = 1;
+  for (int64_t dim : shape)
+    num *= dim;
+  return num;
+#endif
+}

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index ce47c60c9b932..1e198043c590a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CheckedArithmetic.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -875,8 +876,13 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                             : getAffineConstantExpr(runningSize, context);
     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
     if (size > 0) {
-      runningSize *= size;
-      assert(runningSize > 0 && "integer overflow in size computation");
+      auto result = llvm::checkedMul(runningSize, size);
+      if (!result) {
+        // Overflow occurred, treat as dynamic
+        dynamicPoisonBit = true;
+      } else {
+        runningSize = *result;
+      }
     } else {
       dynamicPoisonBit = true;
     }

diff  --git a/mlir/test/Dialect/MemRef/high-rank-overflow.mlir b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
new file mode 100644
index 0000000000000..2a6ec113c7261
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/high-rank-overflow.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --convert-to-llvm --split-input-file --verify-diagnostics | FileCheck %s
+
+// Test that extremely high-rank memrefs with overflow in stride calculation
+// are handled gracefully instead of crashing (issue #177816).
+
+// CHECK-LABEL: func @high_rank_memref_overflow
+func.func @high_rank_memref_overflow() {
+  // This creates a memref with 64 dimensions of size 2, resulting in 2^64 elements
+  // which overflows int64_t. The stride calculation should handle this gracefully.
+  %0 = memref.alloc() : memref<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @high_rank_memref_max_dim
+func.func @high_rank_memref_max_dim() {
+  // Test with fewer dimensions but larger sizes that also cause overflow
+  %0 = memref.alloc() : memref<9223372036854775807x2xi32>
+  return
+}

diff  --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index bc4066ed210e8..a013193b8dd9e 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 #include <cstdint>
+#include <limits>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -298,4 +299,26 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
       cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
 }
 
+TEST(ShapedTypeTest, GetNumElements) {
+  // Test normal case.
+  EXPECT_EQ(ShapedType::getNumElements({2, 3, 4}), 24);
+  EXPECT_EQ(ShapedType::getNumElements({1}), 1);
+  EXPECT_EQ(ShapedType::getNumElements({}), 1);
+
+  // Test tryGetNumElements returns value for normal shapes.
+  EXPECT_TRUE(ShapedType::tryGetNumElements({2, 3, 4}).has_value());
+  EXPECT_EQ(*ShapedType::tryGetNumElements({2, 3, 4}), 24);
+
+  // Test tryGetNumElements returns nullopt for overflow.
+  // INT64_MAX = 9223372036854775807, so multiplying large dimensions overflows.
+  SmallVector<int64_t> overflowShape;
+  for (int i = 0; i < 64; ++i)
+    overflowShape.push_back(2); // 2^64 would overflow int64_t
+  EXPECT_FALSE(ShapedType::tryGetNumElements(overflowShape).has_value());
+
+  // Another overflow case with fewer but larger dimensions.
+  int64_t maxVal = std::numeric_limits<int64_t>::max();
+  EXPECT_FALSE(ShapedType::tryGetNumElements({maxVal, 2}).has_value());
+}
+
 } // namespace


        


More information about the Mlir-commits mailing list