[Mlir-commits] [mlir] [mlir][tensor] Guard constant reshape folding (PR #179077)
Samarth Narang
llvmlistbot at llvm.org
Sat Jan 31 16:25:42 PST 2026
https://github.com/snarang181 created https://github.com/llvm/llvm-project/pull/179077
Fixes https://github.com/llvm/llvm-project/issues/178205
From 922eb3a69b1d819e34f45440173284fc2fb08969 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sat, 31 Jan 2026 19:25:15 -0500
Subject: [PATCH] [mlir][tensor] Guard constant reshape folding
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 64c125024d906..6f72dd34aa1bd 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -91,8 +91,13 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
return reshapeOp.getSrc();
// Reshape of a constant can be replaced with a new constant.
- if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
- return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+ auto dstTy = cast<ShapedType>(reshapeOp.getResult().getType());
+ auto srcTy = cast<ShapedType>(elements.getType());
+ if (srcTy.getNumElements() != dstTy.getNumElements())
+ return nullptr; // do not fold on invalid reshape
+ return elements.reshape(dstTy);
+ }
// Fold if the producer reshape source has the same shape with at most 1
// dynamic dimension.
More information about the Mlir-commits mailing list