[Mlir-commits] [mlir] [MLIR][Tensor] Fix incorrect operand consumption in expand_shape canonicalization (PR #180705)

Keshav Vinayak Jha llvmlistbot at llvm.org
Tue Feb 10 00:51:46 PST 2026


https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/180705

>From 12245ade77aab2e8ef0759836993bcccf83e8536 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 10 Feb 2026 00:41:03 -0800
Subject: [PATCH] Added fix for incorrect operand consumption in expand_shape
 canonicalization

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 42 ++++++++++++----------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 26 ++++++++++++++
 2 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d837947e0dc3b..2d532be7fd026 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2207,25 +2207,29 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
 
     for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
       for (uint64_t outDim : innerReassoc) {
-        if (ShapedType::isStatic(newOutputShape[outDim]))
-          continue;
-
-        // If the cast's src type is dynamic, don't infer any of the
-        // corresponding expanded dimensions. `tensor.expand_shape` requires at
-        // least one of the expanded dimensions to be dynamic if the input is
-        // dynamic.
-        Value val = *outputIt;
-        ++outputIt;
-        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
-          dynamicOutputShape.push_back(val);
-          continue;
-        }
-
-        APInt cst;
-        if (matchPattern(val, m_ConstantInt(&cst))) {
-          newOutputShape[outDim] = cst.getSExtValue();
-        } else {
-          dynamicOutputShape.push_back(val);
+        // A dynamic entry in the static output shape has an `output_shape`
+        // operand that must be consumed, even if the result type dim is static.
+        if (expandOp.getStaticOutputShape()[outDim] == ShapedType::kDynamic) {
+          Value val = *outputIt;
+          ++outputIt;
+          if (ShapedType::isStatic(newOutputShape[outDim]))
+            continue;
+
+          // If the cast's src type is dynamic, don't infer any of the
+          // corresponding expanded dimensions. `tensor.expand_shape`
+          // requires at least one of the expanded dimensions to be dynamic
+          // if the input is dynamic.
+          if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+            dynamicOutputShape.push_back(val);
+            continue;
+          }
+
+          APInt cst;
+          if (matchPattern(val, m_ConstantInt(&cst))) {
+            newOutputShape[outDim] = cst.getSExtValue();
+          } else {
+            dynamicOutputShape.push_back(val);
+          }
         }
       }
     }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7a2d53c0c5850..5b5d1ae6c77ef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2554,6 +2554,32 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
 
 // -----
 
+// CHECK-LABEL: func @fold_expand_of_cast_mixed_shape
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x8xf32>
+func.func @fold_expand_of_cast_mixed_shape(%arg0: tensor<4x8xf32>) -> (index, index, index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c1, %c4, %c8] : tensor<?x?xf32> into tensor<1x?x?xf32>
+
+  %idx0 = arith.constant 0 : index
+  %idx1 = arith.constant 1 : index
+  %idx2 = arith.constant 2 : index
+
+  %dim0 = tensor.dim %1, %idx0 : tensor<1x?x?xf32>
+  %dim1 = tensor.dim %1, %idx1 : tensor<1x?x?xf32>
+  %dim2 = tensor.dim %1, %idx2 : tensor<1x?x?xf32>
+
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: return %[[C1]], %[[C4]], %[[C8]]
+  return %dim0, %dim1, %dim2 : index, index, index
+}
+
+// -----
+
 func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
     -> tensor<?x?x?xf32> {
   %c1 = arith.constant 1 : index



More information about the Mlir-commits mailing list