[Mlir-commits] [mlir] Let memref.collapse_shape implement ReifyRankedShapedTypeOpInterface. (PR #138452)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 6 00:15:40 PDT 2025


https://github.com/davidlerner96 updated https://github.com/llvm/llvm-project/pull/138452

>From ce1f7b7ec462d381f1755ddcae5284a2729ea044 Mon Sep 17 00:00:00 2001
From: David Lerner <DavidLerner96 at gmail.com>
Date: Sun, 4 May 2025 14:33:11 +0300
Subject: [PATCH] Let memref.collapse_shape implement
 ReifyRankedShapedTypeOpInterface.

This PR implements ReifyRankedShapedTypeOpInterface for memref.collapse_shape.
Each dynamic result dimension is reified as the product of the `memref.dim`
sizes of the source dimensions in its reassociation group; static result
dimensions are returned as static sizes. This lets a `memref.dim` on the
collapsed result be resolved directly in terms of the source memref, so it no
longer depends on the collapse_shape op (see the new tests in
resolve-dim-ops.mlir).
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  3 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 31 +++++++
 mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 91 +++++++++++++++++++
 3 files changed, 124 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..e401e3e8a53ae 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1761,7 +1761,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f10a31c15626..177b4a69d256f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2482,6 +2482,37 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
+// Reifies the result shape of the collapse. Static result sizes come straight
+// from the result type. Each dynamic result size is materialized as the
+// product of the `memref.dim` sizes of its reassociation group in the source,
+// e.g. group [2, 3] of memref<1x32x?x8xsi8> becomes dim(2) * dim(3).
+LogicalResult CollapseShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+  Value source = getSrc();
+  Location loc = getLoc();
+  ArrayRef<int64_t> resultShape = getResultType().getShape();
+  SmallVector<ReassociationIndices, 4> groups = getReassociationIndices();
+  SmallVector<Value> dynamicSizes;
+  // Key off the result type, not the source type: getMixedValues expects
+  // exactly one Value per dynamic entry of `resultShape`, and this pairs them
+  // by construction. Groups with a static result size emit no IR.
+  for (auto [resultDim, group] : llvm::enumerate(groups)) {
+    if (!ShapedType::isDynamic(resultShape[resultDim]))
+      continue;
+    // Static members of the group are queried via memref.dim too; such dims
+    // fold to constants, keeping the product construction uniform.
+    Value size = builder.create<memref::DimOp>(loc, source, group.front());
+    for (int64_t srcDim : llvm::drop_begin(group)) {
+      Value dimSize = builder.create<memref::DimOp>(loc, source, srcDim);
+      size = builder.create<arith::MulIOp>(loc, size, dimSize);
+    }
+    dynamicSizes.push_back(size);
+  }
+
+  reifiedResultShapes = {getMixedValues(resultShape, dynamicSizes, builder)};
+  return success();
+}
+
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..f40b0ad849fa0 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,94 @@ func.func @iter_to_init_arg_loop_like(
   }
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @collapse_dynamic_with_unit_dims(
+// CHECK-SAME:                                               %[[arg0:.*]]: memref<1x32x?x1xsi8>) -> index {
+// CHECK:           %[[c2:.*]] = arith.constant 2 : index
+// CHECK:           %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x1xsi8>
+// CHECK:           return %[[dim]] : index
+// CHECK:         }
+func.func @collapse_dynamic_with_unit_dims(%arg0: memref<1x32x?x1xsi8>)
+    -> index {
+  %c2 = arith.constant 2 : index
+  %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x1xsi8> into memref<1x32x?xsi8>
+  %dim = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
+  return %dim : index
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_dynamic_and_const_with_dynamic_on_right(
+// CHECK-SAME:                                                            %[[arg0:.*]]: memref<1x32x8x?xsi8>) -> index {
+// CHECK:           %[[c8:.*]] = arith.constant 8 : index
+// CHECK:           %[[c3:.*]] = arith.constant 3 : index
+// CHECK:           %[[dim:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x8x?xsi8>
+// CHECK:           %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index
+// CHECK:           return %[[res]] : index
+// CHECK:         }
+func.func @fold_dynamic_and_const_with_dynamic_on_right(%src: memref<1x32x8x?xsi8>)
+    -> index {
+  %idx = arith.constant 2 : index
+  %collapsed = memref.collapse_shape %src [[0], [1], [2, 3]] : memref<1x32x8x?xsi8> into memref<1x32x?xsi8>
+  %size = memref.dim %collapsed, %idx : memref<1x32x?xsi8>
+  return %size : index
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_dynamic_and_const_with_dynamic_on_left(
+// CHECK-SAME:                                                           %[[arg0:.*]]: memref<1x32x?x8xsi8>) -> index {
+// CHECK:           %[[c8:.*]] = arith.constant 8 : index
+// CHECK:           %[[c2:.*]] = arith.constant 2 : index
+// CHECK:           %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x8xsi8>
+// CHECK:           %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index
+// CHECK:           return %[[res]] : index
+// CHECK:         }
+func.func @fold_dynamic_and_const_with_dynamic_on_left(%src: memref<1x32x?x8xsi8>)
+    -> index {
+  %idx = arith.constant 2 : index
+  %collapsed = memref.collapse_shape %src [[0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
+  %size = memref.dim %collapsed, %idx : memref<1x32x?xsi8>
+  return %size : index
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_more_than_two_elements_group(
+// CHECK-SAME:                              %[[arg0:.*]]: memref<2x32x?x8xsi8>) -> index {
+// CHECK:           %[[c8:.*]] = arith.constant 8 : index
+// CHECK:           %[[c64:.*]] = arith.constant 64 : index
+// CHECK:           %[[c2:.*]] = arith.constant 2 : index
+// CHECK:           %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<2x32x?x8xsi8>
+// CHECK:           %[[res0:.*]] = arith.muli %[[dim]], %[[c64]] : index
+// CHECK:           %[[res1:.*]] = arith.muli %[[res0]], %[[c8]] : index
+// CHECK:           return %[[res1]] : index
+// CHECK:         }
+func.func @fold_more_than_two_elements_group(%arg0: memref<2x32x?x8xsi8>)
+    -> index {
+  %c0 = arith.constant 0 : index
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<2x32x?x8xsi8> into memref<?xsi8>
+  %dim = memref.dim %collapse_shape, %c0 : memref<?xsi8>
+  return %dim : index
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_group_with_two_dynamic(
+// CHECK-SAME:                                           %[[arg0:.*]]: memref<1x32x?x?xsi8>) -> index {
+// CHECK:           %[[c3:.*]] = arith.constant 3 : index
+// CHECK:           %[[c2:.*]] = arith.constant 2 : index
+// CHECK:           %[[dim2:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x?xsi8>
+// CHECK:           %[[dim3:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x?x?xsi8>
+// CHECK:           %[[res:.*]] = arith.muli %[[dim2]], %[[dim3]] : index
+// CHECK:           return %[[res]] : index
+// CHECK:         }
+func.func @fold_group_with_two_dynamic(%src: memref<1x32x?x?xsi8>)
+    -> index {
+  %idx = arith.constant 2 : index
+  %collapsed = memref.collapse_shape %src [[0], [1], [2, 3]] : memref<1x32x?x?xsi8> into memref<1x32x?xsi8>
+  %size = memref.dim %collapsed, %idx : memref<1x32x?xsi8>
+  return %size : index
+}



More information about the Mlir-commits mailing list