[Mlir-commits] [mlir] [MemRef] Implement value bounds interface for CollapseShapeOp (PR #164955)
Jorn Tuyls
llvmlistbot at llvm.org
Fri Oct 24 04:03:55 PDT 2025
https://github.com/jtuyls created https://github.com/llvm/llvm-project/pull/164955
None
>From 19a950111fefe8ba36eacb9bdc28adfa86e24230 Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Fri, 24 Oct 2025 10:57:39 +0000
Subject: [PATCH] [MemRef] Implement value bounds interface for CollapseShapeOp
---
.../MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 23 +++++++++++++++++++
.../value-bounds-op-interface-impl.mlir | 18 +++++++++++++++
2 files changed, 41 insertions(+)
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf891dd596..ca3c366ccec5e 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -98,6 +98,27 @@ struct RankOpInterface
}
};
+/// Each result dim of memref.collapse_shape equals the product of the source
+/// dims in the matching reassociation group.
+struct CollapseShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                   memref::CollapseShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto collapseOp = cast<memref::CollapseShapeOp>(op);
+    assert(value == collapseOp.getResult() && "invalid value");
+    // getReassociationIndices() returns a temporary: copy the group out of it
+    // (a reference into it would dangle), then multiply its source dim exprs.
+    ReassociationIndices reassocIndices =
+        collapseOp.getReassociationIndices()[dim];
+    Value src = collapseOp.getSrc();
+    AffineExpr productExpr = cstr.getExpr(src, reassocIndices[0]);
+    for (size_t i = 1, e = reassocIndices.size(); i < e; ++i)
+      productExpr = productExpr * cstr.getExpr(src, reassocIndices[i]);
+    cstr.bound(value)[dim] == productExpr;
+  }
+};
+
struct SubViewOpInterface
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
+ *ctx);
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
*ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index ac1f22b68b1e1..700535a3c21ff 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -77,6 +77,24 @@ func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
// -----
+// CHECK: #[[$TIMES2:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @memref_collapse(
+// CHECK-SAME:    %[[arg:.*]]: index
+// CHECK-DAG:     %[[cst2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[cst12:.*]] = arith.constant 12 : index
+// CHECK:         %[[src_dim:.*]] = memref.dim %{{.*}}, %[[cst2]] : memref<3x4x?x2xf32>
+// CHECK:         %[[prod:.*]] = affine.apply #[[$TIMES2]]()[%[[src_dim]]]
+// CHECK:         return %[[cst12]], %[[prod]]
+func.func @memref_collapse(%sz0: index) -> (index, index) {
+  %alloc = memref.alloc(%sz0) : memref<3x4x?x2xf32>
+  %collapsed = memref.collapse_shape %alloc [[0, 1], [2, 3]] : memref<3x4x?x2xf32> into memref<12x?xf32>
+  %static = "test.reify_bound"(%collapsed) {dim = 0} : (memref<12x?xf32>) -> (index) // 3 * 4 = 12
+  %dynamic = "test.reify_bound"(%collapsed) {dim = 1} : (memref<12x?xf32>) -> (index) // dim(2) * 2
+  return %static, %dynamic : index, index
+}
+
+// -----
+
// CHECK-LABEL: func @memref_get_global(
// CHECK: %[[c4:.*]] = arith.constant 4 : index
// CHECK: return %[[c4]]
More information about the Mlir-commits
mailing list