[Mlir-commits] [mlir] c8e211c - [mlir][tensor] Fix crash in expand_shape fold with dynamic result type (#183785)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 28 02:56:26 PST 2026


Author: Mehdi Amini
Date: 2026-02-28T11:56:21+01:00
New Revision: c8e211c2a8b23b7bed4fb6b76bd441caa44973de

URL: https://github.com/llvm/llvm-project/commit/c8e211c2a8b23b7bed4fb6b76bd441caa44973de
DIFF: https://github.com/llvm/llvm-project/commit/c8e211c2a8b23b7bed4fb6b76bd441caa44973de.diff

LOG: [mlir][tensor] Fix crash in expand_shape fold with dynamic result type (#183785)

`foldReshapeOp` (in `ReshapeOpsUtils.h`) and `FoldReshapeWithConstant`
(in `TensorOps.cpp`) both tried to create a new `DenseElementsAttr`
constant when folding a reshape op whose operand is a constant. Neither
checked that the result type was statically shaped before doing so, but
`DenseElementsAttr::reshape()` and
`DenseElementsAttr::getFromRawBuffer()` both assert `hasStaticShape()`.

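Guard both fold paths with a `hasStaticShape()` check so that they skip the
constant fold when the result type contains a dynamic dimension:
`FoldReshapeWithConstant` returns `failure()`, and `foldReshapeOp` falls
through to its remaining fold rules instead of reshaping the constant.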

Fixes #177845

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 525663e5cd6c5..2e8c0e269995e 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -90,9 +90,14 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
-    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+  // Reshape of a constant can be replaced with a new constant, but only when
+  // the result type has a static shape. DenseElementsAttr::reshape requires
+  // a static shape to preserve the element count invariant.
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+    auto resultType = cast<ShapedType>(reshapeOp.getResult().getType());
+    if (resultType.hasStaticShape())
+      return elements.reshape(resultType);
+  }
 
   // Fold if the producer reshape source has the same shape with at most 1
   // dynamic dimension.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 743bdabdd8542..7d77d8cb1cc00 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2126,6 +2126,10 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
       return failure();
     if (!attr || !attr.isSplat())
       return failure();
+    // DenseElementsAttr requires a static shape; skip folding for dynamic
+    // result types.
+    if (!reshapeOp.getResultType().hasStaticShape())
+      return failure();
     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
         reshapeOp.getResultType(), attr.getRawData());
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b7e43ca84cec..b251313b6580b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1687,6 +1687,22 @@ func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
 
 // -----
 
+// Regression test for https://github.com/llvm/llvm-project/issues/177845:
+// tensor.expand_shape of a constant to a dynamic shape must not crash.
+// FoldReshapeWithConstant must not call DenseElementsAttr::getFromRawBuffer
+// when the result type is dynamic (getFromRawBuffer requires static shape).
+
+// CHECK-LABEL: @expand_shape_splat_constant_dynamic_result
+//       CHECK:   arith.constant
+//       CHECK:   tensor.expand_shape
+func.func @expand_shape_splat_constant_dynamic_result(%n: index) -> tensor<?xi32> {
+  %cst = arith.constant dense<1> : tensor<i32>
+  %result = tensor.expand_shape %cst [] output_shape [%n] : tensor<i32> into tensor<?xi32>
+  return %result : tensor<?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_rank
 func.func @fold_rank() -> (index) {
   %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>


        


More information about the Mlir-commits mailing list