[Mlir-commits] [mlir] [MemRef] Implement value bounds interface for ExpandShapeOp (PR #164438)

Jorn Tuyls llvmlistbot at llvm.org
Tue Oct 21 08:09:04 PDT 2025


https://github.com/jtuyls created https://github.com/llvm/llvm-project/pull/164438

None

>From 1c1dc18d2f435e53b56cf675125c8953c7055d3a Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Tue, 21 Oct 2025 14:55:56 +0000
Subject: [PATCH] [MemRef] Implement value bounds interface for ExpandShapeOp

---
 .../MemRef/IR/ValueBoundsOpInterfaceImpl.cpp       | 13 +++++++++++++
 .../MemRef/value-bounds-op-interface-impl.mlir     | 14 ++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de35e430..a15bf891dd596 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,21 @@ struct DimOpInterface
   }
 };
 
+/// Bounds each result dimension of `memref.expand_shape` by the matching entry
+/// of its output shape.
+struct ExpandShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                   memref::ExpandShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto expandOp = cast<memref::ExpandShapeOp>(op);
+    assert(value == expandOp.getResult() && "invalid value");
+    // `getOutputShape()` holds only the dynamic sizes, so it cannot be indexed
+    // by result dimension; the mixed list has one entry per result dimension.
+    cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
+  }
+};
+
 struct GetGlobalOpInterface
     : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
                                                    GetGlobalOp> {
@@ -123,6 +138,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
         memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
     memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
     memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+        *ctx);
     memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
     memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
     memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index 8bd7ae8df9049..ac1f22b68b1e1 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -63,6 +63,39 @@ func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
 
 // -----
 
+// CHECK-LABEL: func @memref_expand(
+//  CHECK-SAME:     %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9]+]]: index
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[sz]], %[[c4]]
+func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
+  %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref<?xf32> into memref<?x4xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<?x4xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// Static sizes precede dynamic ones, so result dim N must be bounded by the
+// N-th output_shape entry rather than by the N-th dynamic size.
+// CHECK-LABEL: func @memref_expand_mixed(
+//  CHECK-SAME:     %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME:     %[[sz0:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:     %[[sz1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:     %[[sz2:[a-zA-Z0-9]+]]: index
+//       CHECK:   return %[[sz0]], %[[sz1]], %[[sz2]]
+func.func @memref_expand_mixed(%m: memref<?xf32>, %sz0: index, %sz1: index,
+                               %sz2: index) -> (index, index, index) {
+  %0 = memref.expand_shape %m [[0, 1, 2, 3]] output_shape [%sz0, 4, %sz1, %sz2]: memref<?xf32> into memref<?x4x?x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4x?x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 2} : (memref<?x4x?x?xf32>) -> (index)
+  %3 = "test.reify_bound"(%0) {dim = 3} : (memref<?x4x?x?xf32>) -> (index)
+  return %1, %2, %3 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_get_global(
 //       CHECK:   %[[c4:.*]] = arith.constant 4 : index
 //       CHECK:   return %[[c4]]



More information about the Mlir-commits mailing list