[Mlir-commits] [mlir] [MLIR][Tensor] Fix incorrect operand consumption in expand_shape canonicalization (PR #180705)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Tue Feb 10 00:51:46 PST 2026
https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/180705
>From 12245ade77aab2e8ef0759836993bcccf83e8536 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 10 Feb 2026 00:41:03 -0800
Subject: [PATCH] Added fix for incorrect operand consumption in expand_shape
canonicalization
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 42 ++++++++++++----------
mlir/test/Dialect/Tensor/canonicalize.mlir | 26 ++++++++++++++
2 files changed, 49 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d837947e0dc3b..2d532be7fd026 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2207,25 +2207,29 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
for (uint64_t outDim : innerReassoc) {
- if (ShapedType::isStatic(newOutputShape[outDim]))
- continue;
-
- // If the cast's src type is dynamic, don't infer any of the
- // corresponding expanded dimensions. `tensor.expand_shape` requires at
- // least one of the expanded dimensions to be dynamic if the input is
- // dynamic.
- Value val = *outputIt;
- ++outputIt;
- if (ShapedType::isDynamic(castSrcShape[inputDim])) {
- dynamicOutputShape.push_back(val);
- continue;
- }
-
- APInt cst;
- if (matchPattern(val, m_ConstantInt(&cst))) {
- newOutputShape[outDim] = cst.getSExtValue();
- } else {
- dynamicOutputShape.push_back(val);
+ // If the static output shape has a dynamic dim, we must consume an operand
+ // from the input list, even if the result type is static.
+ if (expandOp.getStaticOutputShape()[outDim] == ShapedType::kDynamic) {
+ Value val = *outputIt;
+ ++outputIt;
+ if (ShapedType::isStatic(newOutputShape[outDim]))
+ continue;
+
+ // If the cast's src type is dynamic, don't infer any of the
+ // corresponding expanded dimensions. `tensor.expand_shape` requires at
+ // least one of the expanded dimensions to be dynamic if the input is
+ // dynamic.
+ if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+ dynamicOutputShape.push_back(val);
+ continue;
+ }
+
+ APInt cst;
+ if (matchPattern(val, m_ConstantInt(&cst))) {
+ newOutputShape[outDim] = cst.getSExtValue();
+ } else {
+ dynamicOutputShape.push_back(val);
+ }
}
}
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7a2d53c0c5850..5b5d1ae6c77ef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2554,6 +2554,32 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
// -----
+// CHECK-LABEL: func @fold_expand_of_cast_mixed_shape
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x8xf32>
+func.func @fold_expand_of_cast_mixed_shape(%arg0: tensor<4x8xf32>) -> (index, index, index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c1, %c4, %c8] : tensor<?x?xf32> into tensor<1x?x?xf32>
+ // Result dim 0 is static (1) yet still has an output_shape operand (%c1) that must be consumed.
+ %idx0 = arith.constant 0 : index
+ %idx1 = arith.constant 1 : index
+ %idx2 = arith.constant 2 : index
+
+ %dim0 = tensor.dim %1, %idx0 : tensor<1x?x?xf32>
+ %dim1 = tensor.dim %1, %idx1 : tensor<1x?x?xf32>
+ %dim2 = tensor.dim %1, %idx2 : tensor<1x?x?xf32>
+ // Constant placement after folding is a canonicalizer detail, so match the constants order-independently.
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: return %[[C1]], %[[C4]], %[[C8]]
+ return %dim0, %dim1, %dim2 : index, index, index
+}
+
+// -----
+
func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
-> tensor<?x?x?xf32> {
%c1 = arith.constant 1 : index
More information about the Mlir-commits
mailing list