[Mlir-commits] [mlir] [mlir][tensor] Guard constant reshape folding (PR #179077)
llvmlistbot at llvm.org
Sat Jan 31 16:26:12 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Samarth Narang (snarang181)
<details>
<summary>Changes</summary>
Fixes https://github.com/llvm/llvm-project/issues/178205
---
Full diff: https://github.com/llvm/llvm-project/pull/179077.diff
1 File Affected:
- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+7-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 64c125024d906..6f72dd34aa1bd 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -91,8 +91,13 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
     return reshapeOp.getSrc();
 
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
-    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+    auto dstTy = cast<ShapedType>(reshapeOp.getResult().getType());
+    auto srcTy = cast<ShapedType>(elements.getType());
+    if (srcTy.getNumElements() != dstTy.getNumElements())
+      return nullptr; // do not fold on invalid reshape
+    return elements.reshape(dstTy);
+  }
 
   // Fold if the producer reshape source has the same shape with at most 1
   // dynamic dimension.
``````````
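The guard added above can be illustrated outside MLIR with a small, self-contained C++ sketch. It does not use the real MLIR API: `ToyConstant`, `numElements`, `makeConstant`, `foldReshapeOfConstant` and `kDynamicDim` are hypothetical stand-ins for `DenseElementsAttr`, `ShapedType::getNumElements()` and the folder in `ReshapeOpsUtils.h`. As in the patch, the fold is skipped (here by returning `std::nullopt` rather than `nullptr`) when the source and destination element counts differ. As an extra assumption of this sketch only, it also declines to fold when any destination dimension is dynamic, since no element count can be computed for such a shape.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical marker for a dynamic dimension (assumption for this sketch;
// not MLIR's actual sentinel value).
constexpr int64_t kDynamicDim = -1;

// Toy stand-in for a dense constant: a shape plus a flat, row-major buffer.
struct ToyConstant {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

// Element count of a fully static shape, or std::nullopt if any dim is dynamic.
std::optional<int64_t> numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape) {
    if (d == kDynamicDim)
      return std::nullopt;
    n *= d;
  }
  return n;
}

// Builds a constant, checking that the data matches its static shape.
ToyConstant makeConstant(std::vector<int64_t> shape, std::vector<float> data) {
  assert(numElements(shape) == static_cast<int64_t>(data.size()) &&
         "constant data must match its static shape");
  return ToyConstant{std::move(shape), std::move(data)};
}

// Mirrors the guard: fold only when source and destination element counts
// agree; otherwise report "no fold" instead of building an invalid constant.
std::optional<ToyConstant>
foldReshapeOfConstant(const ToyConstant &src,
                      const std::vector<int64_t> &dstShape) {
  std::optional<int64_t> srcCount = numElements(src.shape);
  std::optional<int64_t> dstCount = numElements(dstShape);
  if (!srcCount || !dstCount || *srcCount != *dstCount)
    return std::nullopt; // analogous to `return nullptr;` in the folder
  // The flat data is reused unchanged; only the shape is reinterpreted.
  return ToyConstant{dstShape, src.data};
}
```

For example, with a 2x3 constant, folding to `{3, 2}` or `{6}` succeeds, while `{4, 2}` is rejected because 6 != 8. Returning an "empty" result instead of asserting lets the caller simply leave the op unfolded, which is the behavior the patch's `return nullptr;` aims for.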
</details>
https://github.com/llvm/llvm-project/pull/179077