[Mlir-commits] [mlir] [mlir][tensor] Guard constant reshape folding (PR #179077)

Samarth Narang llvmlistbot at llvm.org
Sat Jan 31 16:25:42 PST 2026


https://github.com/snarang181 created https://github.com/llvm/llvm-project/pull/179077

Fixes https://github.com/llvm/llvm-project/issues/178205

>From 922eb3a69b1d819e34f45440173284fc2fb08969 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sat, 31 Jan 2026 19:25:15 -0500
Subject: [PATCH] [mlir][tensor] Guard constant reshape folding

---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 64c125024d906..6f72dd34aa1bd 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -91,8 +91,13 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
     return reshapeOp.getSrc();
 
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
-    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+    auto dstTy = cast<ShapedType>(reshapeOp.getResult().getType());
+    auto srcTy = cast<ShapedType>(elements.getType());
+    if (srcTy.getNumElements() != dstTy.getNumElements())
+      return nullptr; // do not fold on invalid reshape
+    return elements.reshape(dstTy);
+  }
 
   // Fold if the producer reshape source has the same shape with at most 1
   // dynamic dimension.


