[Mlir-commits] [mlir] [mlir][tensor] Guard constant reshape folding (PR #179077)

llvmlistbot at llvm.org
Sat Jan 31 16:26:12 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Samarth Narang (snarang181)

<details>
<summary>Changes</summary>

Fixes https://github.com/llvm/llvm-project/issues/178205

---
Full diff: https://github.com/llvm/llvm-project/pull/179077.diff


1 File Affected:

- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+7-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 64c125024d906..6f72dd34aa1bd 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -91,8 +91,13 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
     return reshapeOp.getSrc();
 
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
-    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+    auto dstTy = cast<ShapedType>(reshapeOp.getResult().getType());
+    auto srcTy = cast<ShapedType>(elements.getType());
+    if (srcTy.getNumElements() != dstTy.getNumElements())
+      return nullptr; // do not fold on invalid reshape
+    return elements.reshape(dstTy);
+  }
 
   // Fold if the producer reshape source has the same shape with at most 1
   // dynamic dimension.

``````````

</details>
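
The guard in the diff above can be illustrated outside of MLIR with a minimal, self-contained sketch. It assumes C++17, and the `DenseConstant` struct and the `numElements` and `tryFoldReshape` helpers are hypothetical stand-ins introduced only for illustration; they are not MLIR APIs. The sketch shows the same idea as the patch: compare the element counts of the source and destination shapes, and decline to fold when they differ.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a dense constant: a static shape plus flat data.
// This is NOT mlir::DenseElementsAttr, only an illustration.
struct DenseConstant {
  std::vector<int64_t> shape;
  std::vector<int> values;
};

// Product of all (static) dimensions; yields 1 for a rank-0 shape.
static int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Returns the constant with the new shape, or std::nullopt when the element
// counts differ. The nullopt case mirrors the patch's "do not fold" path.
static std::optional<DenseConstant>
tryFoldReshape(const DenseConstant &src, const std::vector<int64_t> &dstShape) {
  if (numElements(src.shape) != numElements(dstShape))
    return std::nullopt; // invalid reshape: leave the op unfolded
  return DenseConstant{dstShape, src.values};
}
```

With this sketch, reshaping a 2x3 constant to 3x2 succeeds because both shapes hold six elements, while reshaping it to 4x2 is declined. In the patch itself, the corresponding "declined" result is the `return nullptr;` from `foldReshapeOp`.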


https://github.com/llvm/llvm-project/pull/179077

