[Mlir-commits] [mlir] Let memref.collapse_shape implement ReifyRankedShapedTypeOpInterface. (PR #107752)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Sep 8 06:48:50 PDT 2024
https://github.com/ddubov100 created https://github.com/llvm/llvm-project/pull/107752
This PR lets memref.collapse_shape implement ReifyRankedShapedTypeOpInterface.
To be on the safe side, it supports only the following case:
- every dynamic source dimension is the only member of its reassociation group.
>From 45a4160d88abaa71f85ebc93d50bee5353b9dd48 Mon Sep 17 00:00:00 2001
From: dubov diana <ddubov at mobileye.com>
Date: Sun, 8 Sep 2024 15:19:54 +0300
Subject: [PATCH] Let memref.collapse_shape implement ReifyRankedShapedTypeOpInterface.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 6 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 36 ++++++++++
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 65 +++++++++++++++++++
3 files changed, 105 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..604611f2e8b5ea 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1670,7 +1670,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}]>,
// Builder that infers the result layout map. The result shape must be
- // specified. Otherwise, the op may be ambiguous. The output shape for
+ // specified. Otherwise, the op may be ambiguous. The output shape for
// the op will be inferred using the inferOutputShape() method.
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation)>,
@@ -1699,7 +1699,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>
+{
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 779ffbfc23f4d7..48df018ee73d05 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2497,6 +2497,42 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+/// Returns true if any source dimension listed in `group` is dynamic in
+/// `sourceShape`. `group` holds indices into `sourceShape` (one reassociation
+/// group of a reshape). Taking an ArrayRef avoids copying the group's
+/// SmallVector on every call.
+static bool isDynamicInGroup(ArrayRef<int64_t> group,
+                             ArrayRef<int64_t> sourceShape) {
+  return llvm::any_of(group, [sourceShape](int64_t dim) {
+    return ShapedType::isDynamic(sourceShape[dim]);
+  });
+}
+
+// Reifies the result shape of this collapse_shape in terms of its source.
+//
+// Only one configuration is supported: every dynamic source dimension must
+// be the sole member of its reassociation group. A dynamic result dimension
+// then maps to exactly one source dimension and is materialized as a
+// `memref.dim` on the source. Any other configuration (for instance a
+// dynamic dimension collapsed together with other dimensions, which would
+// need a product of sizes) returns failure() before any IR is created.
+LogicalResult CollapseShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+  SmallVector<ReassociationIndices, 4> reassociation =
+      getReassociationIndices();
+  Value source = getSrc();
+  ArrayRef<int64_t> sourceShape =
+      cast<MemRefType>(source.getType()).getShape();
+  // Only sources with at least one dynamic dimension are handled.
+  if (!ShapedType::isDynamicShape(sourceShape))
+    return failure();
+  // Reject unsupported groups up front so that a failure leaves no
+  // partially-built IR behind.
+  for (const ReassociationIndices &group : reassociation) {
+    if (group.size() > 1 && isDynamicInGroup(group, sourceShape))
+      return failure();
+  }
+
+  ArrayRef<int64_t> resultShape =
+      cast<ShapedType>(getResultType()).getShape();
+  Location loc = getLoc();
+  SmallVector<Value> dynamicValues;
+  for (size_t resultDim = 0, e = resultShape.size(); resultDim < e;
+       ++resultDim) {
+    if (!ShapedType::isDynamic(resultShape[resultDim]))
+      continue;
+    // Result dim `resultDim` is produced by reassociation group `resultDim`.
+    // Given the check above, that group is expected to hold a single source
+    // dim (asserted below). The source dim index is generally NOT
+    // `resultDim`: with [[0, 1], [2]], result dim 1 is source dim 2.
+    const ReassociationIndices &group = reassociation[resultDim];
+    assert(group.size() == 1 && "dynamic result dim from non-unit group");
+    int64_t sourceDim = group.front();
+    dynamicValues.push_back(builder.create<DimOp>(
+        loc, source, builder.create<arith::ConstantIndexOp>(loc, sourceDim)));
+  }
+  reifiedResultShapes = {getMixedValues(resultShape, dynamicValues, builder)};
+  return success();
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 85a4853972457c..912261515144d0 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -12,6 +12,71 @@ func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
// -----
+// CHECK-LABEL: func.func @dyn_dim_of_memref_collapse_shape(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?x4x8x32xsi8>) -> index {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x4x8x32xsi8>
+// CHECK: return %[[VAL_2]] : index
+// CHECK: }
+
+func.func @dyn_dim_of_memref_collapse_shape(%arg0: memref<?x4x8x32xsi8>) -> index {
+  // The dynamic leading dim is alone in reassociation group [0], so the dim
+  // query on the collapsed view is expected to resolve to a dim of %arg0.
+  %zero = arith.constant 0 : index
+  %n = memref.dim %arg0, %zero : memref<?x4x8x32xsi8>
+  %buf = memref.alloc(%n) {alignment = 32 : i64} : memref<?x32x4x8xsi8>
+  %flat = memref.collapse_shape %buf [[0], [1], [2, 3]]
+      : memref<?x32x4x8xsi8> into memref<?x32x32xsi8>
+  %res = memref.dim %flat, %zero : memref<?x32x32xsi8>
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL: func.func @resolve_when_collapse_after_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?x4x8x32xsi8>) -> index {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x4x8x32xsi8>
+// CHECK: return %[[VAL_2]] : index
+// CHECK: }
+
+func.func @resolve_when_collapse_after_collapse(%arg0: memref<?x4x8x32xsi8>) -> index {
+  // Two chained collapses. The dynamic leading dim stays in a singleton group
+  // in both, so the query is expected to resolve through both ops to %arg0.
+  %zero = arith.constant 0 : index
+  %n = memref.dim %arg0, %zero : memref<?x4x8x32xsi8>
+  %buf = memref.alloc(%n) {alignment = 32 : i64} : memref<?x32x4x8xsi8>
+  %partial = memref.collapse_shape %buf [[0], [1], [2, 3]]
+      : memref<?x32x4x8xsi8> into memref<?x32x32xsi8>
+  %flat = memref.collapse_shape %partial [[0], [1, 2]]
+      : memref<?x32x32xsi8> into memref<?x1024xsi8>
+  %res = memref.dim %flat, %zero : memref<?x1024xsi8>
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unfoldable_memref_collapse_shape(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x?x8x32xsi8>) -> index {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<1x?x8x32xsi8>
+// CHECK: %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) {alignment = 32 : i64} : memref<1x32x?x8xsi8>
+// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_4]] {{\[\[}}0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
+// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_5]], %[[VAL_2]] : memref<1x32x?xsi8>
+// CHECK: return %[[VAL_6]] : index
+// CHECK: }
+func.func @unfoldable_memref_collapse_shape(%arg0: memref<1x?x8x32xsi8>) -> index {
+  // The dynamic dim is collapsed together with a static one in group [2, 3].
+  // reifyResultShapes rejects that case, so the collapse and dim remain.
+  %one = arith.constant 1 : index
+  %two = arith.constant 2 : index
+  %n = memref.dim %arg0, %one : memref<1x?x8x32xsi8>
+  %buf = memref.alloc(%n) {alignment = 32 : i64} : memref<1x32x?x8xsi8>
+  %flat = memref.collapse_shape %buf [[0], [1], [2, 3]]
+      : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
+  %res = memref.dim %flat, %two : memref<1x32x?xsi8>
+  return %res : index
+}
+
+// -----
+
// CHECK-LABEL: func @dim_out_of_bounds_2(
// CHECK-NEXT: arith.constant
// CHECK-NEXT: arith.constant
More information about the Mlir-commits
mailing list