[Mlir-commits] [mlir] c8e211c - [mlir][tensor] Fix crash in expand_shape fold with dynamic result type (#183785)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Feb 28 02:56:26 PST 2026
Author: Mehdi Amini
Date: 2026-02-28T11:56:21+01:00
New Revision: c8e211c2a8b23b7bed4fb6b76bd441caa44973de
URL: https://github.com/llvm/llvm-project/commit/c8e211c2a8b23b7bed4fb6b76bd441caa44973de
DIFF: https://github.com/llvm/llvm-project/commit/c8e211c2a8b23b7bed4fb6b76bd441caa44973de.diff
LOG: [mlir][tensor] Fix crash in expand_shape fold with dynamic result type (#183785)
`foldReshapeOp` (in `ReshapeOpsUtils.h`) and `FoldReshapeWithConstant`
(in `TensorOps.cpp`) both tried to create a new `DenseElementsAttr`
constant when folding a reshape op whose operand is a constant. Neither
checked that the result type was statically shaped before doing so, but
`DenseElementsAttr::reshape()` and
`DenseElementsAttr::getFromRawBuffer()` both assert `hasStaticShape()`.
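Guard both fold paths with a `hasStaticShape()` check so that constant
folding is skipped when the result type contains a dynamic dimension:
`FoldReshapeWithConstant` returns `failure()`, and `foldReshapeOp` falls
through to its remaining fold attempts instead of building a new constant.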
Fixes #177845
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 525663e5cd6c5..2e8c0e269995e 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -90,9 +90,14 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
if (reshapeOp.getSrcType() == reshapeOp.getType())
return reshapeOp.getSrc();
- // Reshape of a constant can be replaced with a new constant.
- if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
- return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ // Reshape of a constant can be replaced with a new constant, but only when
+ // the result type has a static shape. DenseElementsAttr::reshape requires
+ // a static shape to preserve the element count invariant.
+ if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+ auto resultType = cast<ShapedType>(reshapeOp.getResult().getType());
+ if (resultType.hasStaticShape())
+ return elements.reshape(resultType);
+ }
// Fold if the producer reshape source has the same shape with at most 1
// dynamic dimension.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 743bdabdd8542..7d77d8cb1cc00 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2126,6 +2126,10 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
return failure();
if (!attr || !attr.isSplat())
return failure();
+ // DenseElementsAttr requires a static shape; skip folding for dynamic
+ // result types.
+ if (!reshapeOp.getResultType().hasStaticShape())
+ return failure();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
reshapeOp.getResultType(), attr.getRawData());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b7e43ca84cec..b251313b6580b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1687,6 +1687,22 @@ func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
// -----
+// Regression test for https://github.com/llvm/llvm-project/issues/177845:
+// tensor.expand_shape of a constant to a dynamic shape must not crash.
+// FoldReshapeWithConstant must not call DenseElementsAttr::getFromRawBuffer
+// when the result type is dynamic (getFromRawBuffer requires static shape).
+
+// CHECK-LABEL: @expand_shape_splat_constant_dynamic_result
+// CHECK: arith.constant
+// CHECK: tensor.expand_shape
+func.func @expand_shape_splat_constant_dynamic_result(%n: index) -> tensor<?xi32> {
+ %cst = arith.constant dense<1> : tensor<i32>
+ %result = tensor.expand_shape %cst [] output_shape [%n] : tensor<i32> into tensor<?xi32>
+ return %result : tensor<?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_rank
func.func @fold_rank() -> (index) {
%const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
More information about the Mlir-commits mailing list