[Mlir-commits] [mlir] [mlir][Arith] Fix crash when folding operations with dynamic-shaped tensors (PR #178428)

llvmlistbot at llvm.org
Wed Jan 28 06:12:33 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-arith

Author: puneeth_aditya_5656 (mugiwaraluffy56)

<details>
<summary>Changes</summary>

## Summary
- Add a static-shape check in `getBoolAttribute` to prevent a crash when folding comparison operations on dynamically shaped tensor types
- Add a static-shape check in `SelectOp::fold` before creating a `DenseElementsAttr` for the result
- Add a test case to verify the fix
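The guard added to `getBoolAttribute` can be illustrated outside MLIR. The sketch below is a hypothetical, standalone C++ analogue, not MLIR's API: `ToyShapedType`, `kDynamic`, and `makeBoolSplat` are invented names, and a dense attribute is modeled as a `std::vector<bool>` holding every element. It shows the shape of the fix: per the patch's own comment, `DenseElementsAttr` requires a static shape, so for a type such as `tensor<?xi1>` the fold must return "no result" (a null `Attribute` in MLIR, `std::nullopt` here) instead of building the splat.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a shaped type: each entry is a dimension size,
// and kDynamic marks an unknown ('?') dimension, as in tensor<?xi1>.
constexpr int64_t kDynamic = -1;

struct ToyShapedType {
  std::vector<int64_t> dims;

  // True only if every dimension is known.
  bool hasStaticShape() const {
    for (int64_t d : dims)
      if (d == kDynamic)
        return false;
    return true;
  }

  // Product of the dimensions; a rank-0 type has one element.
  // Only meaningful when hasStaticShape() is true.
  int64_t getNumElements() const {
    int64_t n = 1;
    for (int64_t d : dims)
      n *= d;
    return n;
  }
};

// Toy analogue of getBoolAttribute for a shaped result type: build a dense
// splat of `value`, or return std::nullopt ("do not fold") when the shape is
// dynamic, mirroring the early `return {};` added by the patch.
std::optional<std::vector<bool>> makeBoolSplat(const ToyShapedType &type,
                                               bool value) {
  if (!type.hasStaticShape())
    return std::nullopt;
  return std::vector<bool>(static_cast<std::size_t>(type.getNumElements()),
                           value);
}
```

For example, `makeBoolSplat(ToyShapedType{{2, 3}}, true)` yields six `true` elements, while `makeBoolSplat(ToyShapedType{{kDynamic}}, true)` yields `std::nullopt`. Returning an empty result rather than asserting lets the caller simply skip the fold, which is what the new test checks for `arith.cmpi`.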

Fixes #178415.

## Test plan
- [x] Added a test case in `mlir/test/Dialect/Arith/canonicalize.mlir` verifying that folding `arith.cmpi` on dynamically shaped tensors does not crash
- [ ] CI tests pass


---
Full diff: https://github.com/llvm/llvm-project/pull/178428.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+7) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 565a537616971..d6a7df72750a0 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -151,6 +151,9 @@ static Attribute getBoolAttribute(Type type, bool value) {
   ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
   if (!shapedType)
     return boolAttr;
+  // DenseElementsAttr requires a static shape.
+  if (!shapedType.hasStaticShape())
+    return {};
   return DenseElementsAttr::get(shapedType, boolAttr);
 }
 
@@ -2531,6 +2534,10 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
             dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
       if (auto rhs =
               dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+        // DenseElementsAttr::get requires a static shape.
+        if (!lhs.getType().hasStaticShape())
+          return nullptr;
+
         SmallVector<Attribute> results;
         results.reserve(static_cast<size_t>(cond.getNumElements()));
         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 3ad1530248809..18e0d2d2ea3c4 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3403,3 +3403,14 @@ func.func @unreachable() {
   cf.br ^unreachable
 }
 
+// -----
+
+// Verify that cmpi with dynamic-shaped tensors does not crash during folding.
+// The fold cannot create a DenseElementsAttr for dynamic shapes.
+// CHECK-LABEL: @cmpi_dynamic_shape_no_fold
+//       CHECK:   arith.cmpi eq
+func.func @cmpi_dynamic_shape_no_fold(%arg0: tensor<?xi32>) -> tensor<?xi1> {
+  %0 = arith.cmpi eq, %arg0, %arg0 : tensor<?xi32>
+  return %0 : tensor<?xi1>
+}
+

``````````
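The `SelectOp::fold` hunk guards an element-wise fold: when the condition, true value, and false value are all dense constants, each result element is taken from the true or false side according to the matching condition element, and the picks are packed into a new `DenseElementsAttr`. The following is a hypothetical, simplified C++ model of that logic; `ToyDense`, `hasStaticShape`, and `foldSelect` are illustrative names, not MLIR code, and a real `DenseElementsAttr` could never carry a dynamic shape the way this toy can. As in the patch, the static-shape check on the true value runs before any result is built.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model: a dense constant is its dimensions plus its flattened
// elements. kDynamic marks an unknown ('?') dimension.
constexpr int64_t kDynamic = -1;

template <typename T>
struct ToyDense {
  std::vector<int64_t> dims;
  std::vector<T> values;
};

bool hasStaticShape(const std::vector<int64_t> &dims) {
  for (int64_t d : dims)
    if (d == kDynamic)
      return false;
  return true;
}

// Toy analogue of the element-wise arith.select fold.
template <typename T>
std::optional<ToyDense<T>> foldSelect(const ToyDense<bool> &cond,
                                      const ToyDense<T> &lhs,
                                      const ToyDense<T> &rhs) {
  // Mirror the patch: bail out before building a result whose type
  // (taken from the true value) does not have a static shape.
  if (!hasStaticShape(lhs.dims))
    return std::nullopt;
  // Defensive check for the toy; in MLIR the verifier guarantees that the
  // operands have matching shapes.
  if (lhs.values.size() != cond.values.size() ||
      rhs.values.size() != cond.values.size())
    return std::nullopt;

  ToyDense<T> result{lhs.dims, {}};
  result.values.reserve(cond.values.size());
  for (std::size_t i = 0; i < cond.values.size(); ++i)
    result.values.push_back(cond.values[i] ? lhs.values[i] : rhs.values[i]);
  return result;
}
```

With `cond = {true, false, true}`, `lhs = {1, 2, 3}`, and `rhs = {10, 20, 30}` (all of shape `{3}`), the fold yields `{1, 20, 3}`; if `lhs` is given shape `{kDynamic}`, it returns `std::nullopt` and the operation is left unfolded.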

</details>


https://github.com/llvm/llvm-project/pull/178428

