[Mlir-commits] [mlir] [MemRef] Fix value bounds interface for ExpandShapeOp (PR #165333)

Jorn Tuyls llvmlistbot at llvm.org
Mon Oct 27 15:33:39 PDT 2025


https://github.com/jtuyls created https://github.com/llvm/llvm-project/pull/165333

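We should not consider only the dynamic dimensions, but all output dimensions, when adding the value bounds constraints. `getOutputShape()` returns only the dynamic output-shape operands, so indexing it with `dim` is wrong whenever a static dimension precedes a dynamic one; `getMixedOutputShape()` returns an entry for every result dimension. The previous test only passed because the dynamic dimension was in the first position.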

>From 1057fb686ff20221f673748cee65c8f4c2f50ace Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Mon, 27 Oct 2025 22:29:02 +0000
Subject: [PATCH] [MemRef] Fix value bounds interface for ExpandShapeOp

---
 mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 2 +-
 .../Dialect/MemRef/value-bounds-op-interface-impl.mlir    | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf891dd596..6fa8ce4efff3b 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ struct ExpandShapeOpInterface
                                        ValueBoundsConstraintSet &cstr) const {
     auto expandOp = cast<memref::ExpandShapeOp>(op);
     assert(value == expandOp.getResult() && "invalid value");
-    cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+    cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
   }
 };
 
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index ac1f22b68b1e1..f9b81dfc7d468 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -67,11 +67,11 @@ func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
 //  CHECK-SAME:     %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
 //  CHECK-SAME:     %[[sz:[a-zA-Z0-9]+]]: index
 //       CHECK:   %[[c4:.*]] = arith.constant 4 : index
-//       CHECK:   return %[[sz]], %[[c4]]
+//       CHECK:   return %[[c4]], %[[sz]]
 func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
-  %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref<?xf32> into memref<?x4xf32>
-  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4xf32>) -> (index)
-  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<?x4xf32>) -> (index)
+  %0 = memref.expand_shape %m [[0, 1]] output_shape [4, %sz]: memref<?xf32> into memref<4x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<4x?xf32>) -> (index)
   return %1, %2 : index, index
 }
 



More information about the Mlir-commits mailing list