[Mlir-commits] [mlir] [MLIR][ODS] Fix AllElementCountsMatch crash on dynamic shaped types (PR #183948)

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 1 02:51:33 PST 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/183948

>From 36e2098c62381648a8724a6942edefacbf0cd3be Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Feb 2026 10:10:09 -0800
Subject: [PATCH] [MLIR][ODS] Fix AllElementCountsMatch crash on dynamic or
 non-shaped types

The AllElementCountsMatch trait crashed in two cases:
- Dynamic shaped types: ShapedType::getNumElements() asserts hasStaticShape()
- Non-ShapedType inputs: llvm::cast<ShapedType> asserts on types like i32

Fix the predicate to check isa<ShapedType> and hasStaticShape() before
casting and calling getNumElements(), so both cases produce a verification
diagnostic instead of crashing.

Add regression tests for both cases: a dynamic-shape case using the existing
test op, and a non-shaped case using a new test op with AnyType constraints.

Fixes #159740
---
 mlir/include/mlir/IR/OpBase.td        | 14 ++++++++++++--
 mlir/test/lib/Dialect/Test/TestOps.td |  7 +++++++
 mlir/test/mlir-tblgen/types.mlir      | 21 +++++++++++++++++++++
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7a667d701ab71..7f36e6c74c7f7 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -543,9 +543,19 @@ class AnyMatchOperatorTrait<list<string> names, string operator,
   list<string> values = names;
 }
 
+// Verifies that all named operands/results are ShapedTypes with the same
+// element count. Verification fails (rather than asserting) if any of them is
+// not a ShapedType or lacks a static shape (no static element count exists).
 class AllElementCountsMatch<list<string> names> :
-    AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
-                              "element count">;
+    PredOpTrait<"all of {" # !interleave(names, ", ") # "} have same element count",
+        And<[
+            // Reject values whose type is not a statically shaped ShapedType.
+            Neg<AnyMatchOperatorPred<names,
+                "!::mlir::isa<::mlir::ShapedType>($_self.getType()) || "
+                "!::llvm::cast<::mlir::ShapedType>($_self.getType()).hasStaticShape()">>,
+            // Now safe to compare: ElementCount would assert on dynamic shapes.
+            AllMatchSameOperatorPred<names, ElementCount<"_self">.result>
+        ]>>;
 
 class AllElementTypesMatch<list<string> names> :
     AllMatchSameOperatorTrait<names, ElementType<"_self">.result,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 62a374e08ec1c..fe02536a1df5b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -831,6 +831,13 @@ def OperandZeroAndResultHaveSameElementCount :
   let results = (outs AnyShaped:$res);
 }
 
+def OperandAndResultHaveSameElementCountAnyType :
+    TEST_Op<"operand_and_result_have_same_element_count_any_type",
+            [AllElementCountsMatch<["x", "res"]>]> {
+  let arguments = (ins AnyType:$x);   // AnyType (not AnyShaped) so non-shaped
+  let results = (outs AnyType:$res);  // values pass the type constraints.
+}
+
 def FourEqualsFive :
     TEST_Op<"four_equals_five", [AllMatch<["5", "4"], "4 equals 5">]>;
 
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index 7652a87037c92..c2acce0903bf4 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -432,6 +432,27 @@ func.func @same_element_count_success(%arg0: tensor<36xi32>, %arg1: tensor<1x2xf
 
 // -----
 
+// Regression test for https://github.com/llvm/llvm-project/issues/159740
+// AllElementCountsMatch must diagnose (not crash on) a dynamic result shape.
+func.func @same_element_count_dynamic(%arg0: tensor<2xi32>) {
+  // expected-error@+1 {{all of {x, res} have same element count}}
+  %0 = "test.operand0_and_result_have_same_element_count"(%arg0, %arg0) :
+    (tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// AllElementCountsMatch must diagnose (not crash on) operands/results that are
+// not ShapedTypes at all (e.g. a plain integer).
+func.func @same_element_count_non_shaped(%arg0: i32) -> i32 {
+  // expected-error@+1 {{all of {x, res} have same element count}}
+  %0 = "test.operand_and_result_have_same_element_count_any_type"(%arg0) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
 func.func @same_element_count_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
   // expected-error at +1 {{all of {x, res} have same element count}}
   "test.operand0_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<2xi32>)



More information about the Mlir-commits mailing list