[Mlir-commits] [mlir] 6ad9565 - [MemRef] Implement value bounds interface for CollapseShapeOp (#164955)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 28 09:24:39 PDT 2025


Author: Jorn Tuyls
Date: 2025-10-28T09:24:35-07:00
New Revision: 6ad95651159bd5d2671788de19f1ef7748bcb3c9

URL: https://github.com/llvm/llvm-project/commit/6ad95651159bd5d2671788de19f1ef7748bcb3c9
DIFF: https://github.com/llvm/llvm-project/commit/6ad95651159bd5d2671788de19f1ef7748bcb3c9.diff

LOG: [MemRef] Implement value bounds interface for CollapseShapeOp (#164955)

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 6fa8ce4efff3b..3aa801b48a2e9 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -98,6 +98,27 @@ struct RankOpInterface
   }
 };
 
+struct CollapseShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                   memref::CollapseShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto collapseOp = cast<memref::CollapseShapeOp>(op);
+    assert(value == collapseOp.getResult() && "invalid value");
+
+    // Result size = product of source dim sizes in reassociation group `dim`.
+    const ReassociationIndices &reassocIndices =
+        collapseOp.getReassociationIndices()[dim];
+    AffineExpr productExpr =
+        cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+    for (size_t i = 1; i < reassocIndices.size(); ++i) {
+      productExpr =
+          productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+    }
+    cstr.bound(value)[dim] == productExpr; // Records the bound in `cstr`.
+  }
+};
+
 struct SubViewOpInterface
     : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
                                                    SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
         memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
     memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
     memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
+        *ctx);
     memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
         *ctx);
     memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);

diff  --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index f9b81dfc7d468..d0aec68d54988 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -77,6 +77,24 @@ func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
 
 // -----
 
+//       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @memref_collapse(
+//  CHECK-SAME:     %[[sz0:.*]]: index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c12:.*]] = arith.constant 12 : index
+//       CHECK:   %[[dim:.*]] = memref.dim %{{.*}}, %[[c2]] : memref<3x4x?x2xf32>
+//       CHECK:   %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
+//       CHECK:   return %[[c12]], %[[mul]]
+func.func @memref_collapse(%sz0: index) -> (index, index) {
+  %0 = memref.alloc(%sz0) : memref<3x4x?x2xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2, 3]] : memref<3x4x?x2xf32> into memref<12x?xf32>
+  %2 = "test.reify_bound"(%1) {dim = 0} : (memref<12x?xf32>) -> (index)
+  %3 = "test.reify_bound"(%1) {dim = 1} : (memref<12x?xf32>) -> (index)
+  return %2, %3 : index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_get_global(
 //       CHECK:   %[[c4:.*]] = arith.constant 4 : index
 //       CHECK:   return %[[c4]]


        


More information about the Mlir-commits mailing list