[Mlir-commits] [mlir] Let memref.collapse_shape implement ReifyRankedShapedTypeOpInterface. (PR #107752)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 8 06:49:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (ddubov100)

<details>
<summary>Changes</summary>

This PR makes memref.collapse_shape implement ReifyRankedShapedTypeOpInterface.
To be on the safe side, it only supports the following case:
 - every dynamic dimension belongs to a single-element reassociation group.

---
Full diff: https://github.com/llvm/llvm-project/pull/107752.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4-2) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+36) 
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+65) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..604611f2e8b5ea 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1670,7 +1670,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     }]>,
 
     // Builder that infers the result layout map. The result shape must be
-    // specified. Otherwise, the op may be ambiguous. The output shape for 
+    // specified. Otherwise, the op may be ambiguous. The output shape for
     // the op will be inferred using the inferOutputShape() method.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
                "ArrayRef<ReassociationIndices>":$reassociation)>,
@@ -1699,7 +1699,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>
+{
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 779ffbfc23f4d7..48df018ee73d05 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2497,6 +2497,42 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
+/// Returns true if any source dimension listed in `group` is dynamic in
+/// `sourceShape`. Every index in `group` must be a valid index into
+/// `sourceShape` (it is used to subscript the shape without a bounds check).
+/// The group is taken as an ArrayRef so callers passing a
+/// ReassociationIndices do not copy it.
+static bool isDynamicInGroup(ArrayRef<int64_t> group,
+                             ArrayRef<int64_t> sourceShape) {
+  return llvm::any_of(group, [&](int64_t dim) {
+    return ShapedType::isDynamic(sourceShape[dim]);
+  });
+}
+
+// Reifies the result shape of the collapse in terms of the source memref.
+//
+// Static result dimensions are returned as index attributes. A dynamic result
+// dimension is supported only when its reassociation group contains a single
+// source dimension; it is then materialized as a `memref.dim` of that source
+// dimension. A dynamic source dimension inside a multi-element group would
+// require multiplying the group's extents, which is not implemented, so the
+// method fails in that case. All checks run before any IR is created, so a
+// failure leaves the IR untouched.
+LogicalResult CollapseShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+  SmallVector<ReassociationIndices, 4> reassociationArray =
+      getReassociationIndices();
+  Value source = getSrc();
+  auto sourceShape = cast<MemRefType>(source.getType()).getShape();
+  auto resultShape = cast<ShapedType>(getResultType()).getShape();
+
+  // Reject unsupported groups up front so that no ops are created on failure.
+  for (auto [resultDim, group] : llvm::enumerate(reassociationArray)) {
+    if (group.size() == 1)
+      continue;
+    if (isDynamicInGroup(group, sourceShape) ||
+        ShapedType::isDynamic(resultShape[resultDim]))
+      return failure();
+  }
+
+  Location loc = getLoc();
+  SmallVector<Value> dynamicValues;
+  for (auto [resultDim, group] : llvm::enumerate(reassociationArray)) {
+    if (!ShapedType::isDynamic(resultShape[resultDim]))
+      continue;
+    // The validation above guarantees this group has exactly one element, so
+    // the result extent equals the extent of source dimension `group.front()`.
+    // That index generally differs from `resultDim` (e.g. [[0, 1], [2]]).
+    Value sourceDim =
+        builder.create<arith::ConstantIndexOp>(loc, group.front());
+    dynamicValues.push_back(builder.create<DimOp>(loc, source, sourceDim));
+  }
+  reifiedResultShapes = {getMixedValues(resultShape, dynamicValues, builder)};
+
+  return success();
+}
+
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 85a4853972457c..912261515144d0 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -12,6 +12,71 @@ func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
 
 // -----
 
+// Dynamic dim 0 of the collapsed result comes from the single-element group
+// [0], so the memref.dim on it is expected to resolve through the
+// collapse_shape and the alloc back to a memref.dim of %arg0.
+// CHECK-LABEL:   func.func @dyn_dim_of_memref_collapse_shape(
+// CHECK-SAME:                                            %[[VAL_0:.*]]: memref<?x4x8x32xsi8>) -> index {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x4x8x32xsi8>
+// CHECK:           return %[[VAL_2]] : index
+// CHECK:         }
+
+func.func @dyn_dim_of_memref_collapse_shape(%arg0: memref<?x4x8x32xsi8>)
+    -> index
+{
+    %c0 = arith.constant 0 : index
+    %dim_16 = memref.dim %arg0, %c0 : memref<?x4x8x32xsi8>
+    %alloc_17 = memref.alloc(%dim_16) {alignment = 32 : i64} : memref<?x32x4x8xsi8>
+    %collapse_shape = memref.collapse_shape %alloc_17 [[0], [1], [2, 3]] : memref<?x32x4x8xsi8> into memref<?x32x32xsi8>
+    %dim_18 = memref.dim %collapse_shape, %c0 : memref<?x32x32xsi8>
+    return %dim_18: index
+}
+
+// -----
+
+// Two chained collapses: result dim 0 stays in a single-element group ([0])
+// in both ops, so the memref.dim is expected to resolve through both
+// collapse_shape ops and the alloc back to a memref.dim of %arg0.
+// CHECK-LABEL:   func.func @resolve_when_collapse_after_collapse(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: memref<?x4x8x32xsi8>) -> index {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x4x8x32xsi8>
+// CHECK:           return %[[VAL_2]] : index
+// CHECK:         }
+
+func.func @resolve_when_collapse_after_collapse(%arg0: memref<?x4x8x32xsi8>)
+    -> index
+{
+    %c0 = arith.constant 0 : index
+    %dim_16 = memref.dim %arg0, %c0 : memref<?x4x8x32xsi8>
+    %alloc_17 = memref.alloc(%dim_16) {alignment = 32 : i64} : memref<?x32x4x8xsi8>
+    %collapse_shape = memref.collapse_shape %alloc_17 [[0], [1], [2, 3]] : memref<?x32x4x8xsi8> into memref<?x32x32xsi8>
+    %collapse_shape_1 = memref.collapse_shape %collapse_shape [[0], [1, 2]] : memref<?x32x32xsi8> into memref<?x1024xsi8>
+    %dim_18 = memref.dim %collapse_shape_1, %c0 : memref<?x1024xsi8>
+    return %dim_18: index
+}
+
+// -----
+
+// Negative case: the dynamic source dim 2 sits in the multi-element group
+// [2, 3], which CollapseShapeOp::reifyResultShapes rejects, so the memref.dim
+// on the collapsed result is expected to remain in the output.
+// CHECK-LABEL:   func.func @unfoldable_memref_collapse_shape(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: memref<1x?x8x32xsi8>) -> index {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<1x?x8x32xsi8>
+// CHECK:           %[[VAL_4:.*]] = memref.alloc(%[[VAL_3]]) {alignment = 32 : i64} : memref<1x32x?x8xsi8>
+// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_4]] {{\[\[}}0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_5]], %[[VAL_2]] : memref<1x32x?xsi8>
+// CHECK:           return %[[VAL_6]] : index
+// CHECK:         }
+func.func @unfoldable_memref_collapse_shape(%arg0: memref<1x?x8x32xsi8>)
+    -> index
+{
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %dim_1 = memref.dim %arg0, %c1 : memref<1x?x8x32xsi8>
+    %alloc_0 = memref.alloc(%dim_1) {alignment = 32 : i64} : memref<1x32x?x8xsi8>
+    %collapse_shape = memref.collapse_shape %alloc_0 [[0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
+    %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
+    return %dim_3: index
+}
+
+// -----
+
 // CHECK-LABEL: func @dim_out_of_bounds_2(
 //  CHECK-NEXT:   arith.constant
 //  CHECK-NEXT:   arith.constant

``````````

</details>


https://github.com/llvm/llvm-project/pull/107752


More information about the Mlir-commits mailing list