[Mlir-commits] [mlir] Let memref.collapse_shape implement ReifyRankedShapedTypeOpInterface. (PR #138452)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 6 00:15:40 PDT 2025
https://github.com/davidlerner96 updated https://github.com/llvm/llvm-project/pull/138452
>From ce1f7b7ec462d381f1755ddcae5284a2729ea044 Mon Sep 17 00:00:00 2001
From: David Lerner <DavidLerner96 at gmail.com>
Date: Sun, 4 May 2025 14:33:11 +0300
Subject: [PATCH] Let memref.collapse_shape implement
ReifyRankedShapedTypeOpInterface.
This PR implements ReifyRankedShapedTypeOpInterface for memref.collapse_shape. With it,
a memref.dim of a collapse_shape result can be resolved to memref.dim ops on the source
memref, so the dimension query no longer depends on the collapse_shape op itself. Groups of
any size are supported: each dynamic result dimension is reified as the product of the
source dimensions in its reassociation group, improving shape inference and lowering fidelity.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 3 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 31 +++++++
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 91 +++++++++++++++++++
3 files changed, 124 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..e401e3e8a53ae 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1761,7 +1761,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f10a31c15626..177b4a69d256f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2482,6 +2482,37 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+/// Reifies the result shape of `memref.collapse_shape`. For every reassociation
+/// group that contains a dynamic source dim, the collapsed size is built as the
+/// product (`arith.muli`) of `memref.dim` of each source dim in that group.
+/// Groups whose source dims are all static emit no IR. Always returns success.
+LogicalResult CollapseShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+  SmallVector<ReassociationIndices, 4> reassociationArray =
+      getReassociationIndices();
+  Value source = getSrc();
+  Location loc = getLoc();
+  SmallVector<Value> dynamicValues;
+  auto resultShape = cast<ShapedType>(getResultType()).getShape();
+  auto sourceShape = cast<MemRefType>(source.getType()).getShape();
+  for (const ReassociationIndices &group : reassociationArray) {
+    if (!llvm::any_of(group, [&](int64_t dim) {
+          return ShapedType::isDynamic(sourceShape[dim]);
+        }))
+      continue;
+    Value resultVal = builder.create<memref::DimOp>(loc, source, group[0]);
+    for (int64_t dim : llvm::drop_begin(group)) {
+      Value nextVal = builder.create<memref::DimOp>(loc, source, dim);
+      resultVal = builder.create<arith::MulIOp>(loc, resultVal, nextVal);
+    }
+    // Groups are visited in order, so values are appended in result-dim order.
+    dynamicValues.push_back(resultVal);
+  }
+  // Pair the static sizes of the result type with the values computed above.
+  reifiedResultShapes = {getMixedValues(resultShape, dynamicValues, builder)};
+  return success();
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..f40b0ad849fa0 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,94 @@ func.func @iter_to_init_arg_loop_like(
}
return %result : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @collapse_dynamic_with_unit_dims(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x1xsi8>) -> index {
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x1xsi8>
+// CHECK: return %[[dim]] : index
+// CHECK: }
+func.func @collapse_dynamic_with_unit_dims (%arg0: memref<1x32x?x1xsi8>)
+ -> index {
+ %c2 = arith.constant 2 : index
+ %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x1xsi8> into memref<1x32x?xsi8>
+ %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
+ return %dim_3: index
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_right(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x32x8x?xsi8>) -> index {
+// CHECK: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x8x?xsi8>
+// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index
+// CHECK: return %[[res]] : index
+// CHECK: }
+func.func @fold_dynamic_and_const_with_dynamic_on_right(%arg0: memref<1x32x8x?xsi8>)
+ -> index {
+ %c2 = arith.constant 2 : index
+ %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x8x?xsi8> into memref<1x32x?xsi8>
+ %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
+ return %dim_3: index
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_left(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x8xsi8>) -> index {
+// CHECK: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x8xsi8>
+// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index
+// CHECK: return %[[res]] : index
+// CHECK: }
+func.func @fold_dynamic_and_const_with_dynamic_on_left(%arg0: memref<1x32x?x8xsi8>)
+ -> index {
+ %c2 = arith.constant 2 : index
+ %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
+ %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
+ return %dim_3: index
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_more_than_two_elements_group(
+// CHECK-SAME: %[[arg0:.*]]: memref<2x32x?x8xsi8>) -> index {
+// CHECK: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[c64:.*]] = arith.constant 64 : index
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<2x32x?x8xsi8>
+// CHECK: %[[res0:.*]] = arith.muli %[[dim]], %[[c64]] : index
+// CHECK: %[[res1:.*]] = arith.muli %[[res0]], %[[c8]] : index
+// CHECK: return %[[res1]] : index
+// CHECK: }
+func.func @fold_more_than_two_elements_group(%arg0: memref<2x32x?x8xsi8>)
+ -> index {
+ %c0 = arith.constant 0 : index
+ %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<2x32x?x8xsi8> into memref<?xsi8>
+ %dim_3 = memref.dim %collapse_shape, %c0 : memref<?xsi8>
+ return %dim_3: index
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_group_with_two_dynamic(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x?xsi8>) -> index {
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[dim2:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x?xsi8>
+// CHECK: %[[dim3:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x?x?xsi8>
+// CHECK: %[[res:.*]] = arith.muli %[[dim2]], %[[dim3]] : index
+// CHECK: return %[[res]] : index
+// CHECK: }
+func.func @fold_group_with_two_dynamic(%arg0: memref<1x32x?x?xsi8>)
+ -> index {
+ %c2 = arith.constant 2 : index
+ %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x?xsi8> into memref<1x32x?xsi8>
+ %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
+ return %dim_3: index
+}
More information about the Mlir-commits
mailing list