[Mlir-commits] [mlir] Let `memref.{expand, collapse}_shape` implement `ReifyRankedShapedTypeOpInterface` (PR #89111)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 10:54:35 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

<details>
<summary>Changes</summary>

In https://github.com/llvm/llvm-project/pull/88423 I needed to fold `memref.dim` of a `memref.expand_shape` result, and it was (IMO rightly) suggested that the proper fix was to implement `ReifyRankedShapedTypeOpInterface`. This PR does that for both `memref.{expand,collapse}_shape` for consistency: it would be surprising if these two ops did not closely mirror one another.

It would be good to continue the `ReifyRankedShapedTypeOpInterface` migration and complete it for the remaining `memref` and `tensor` ops, in particular `tensor.{expand,collapse}_shape`, for consistency with this PR (which would then allow dropping some existing custom fold patterns). However, I have heard of an ongoing project to generalize `expand_shape` by relaxing the requirement that at most one dimension in each reassociation group be dynamic. It would be wise to let that project complete first, as a `ReifyRankedShapedTypeOpInterface` implementation would otherwise entrench the current semantics.

FYI @<!-- -->qedawkins @<!-- -->MaheshRavishankar @<!-- -->Shukla-Gaurav 

---
Full diff: https://github.com/llvm/llvm-project/pull/89111.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+5-2) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+90) 
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+32) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8f6bff5809ca2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1546,8 +1546,11 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 //===----------------------------------------------------------------------===//
 
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
-    MemRef_Op<mnemonic, !listconcat(traits,
-      [Pure, ViewLikeOpInterface])>,
+    MemRef_Op<mnemonic, !listconcat(traits, [
+      Pure,
+      ViewLikeOpInterface,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>
+    ])>,
     Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..e2ce7d93d227cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expand_shape");
 }
 
+/// Reifies the result shape from the source memref. A dynamic result extent
+/// in reassociation group `g` (at most one per group) equals
+/// `memref.dim(src, g)` divided by the product of the group's static extents.
+LogicalResult ExpandShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  Location loc = getLoc();
+  ArrayRef<int64_t> expandedShape = getResultType().getShape();
+  // Hoisted out of the loops; getReassociationIndices() returns by value.
+  auto reassociation = getReassociationIndices();
+  // Map result dims to their group and precompute each group's static
+  // product. A dynamic extent beside a zero-sized sibling is not derivable
+  // (and would divide by zero), so bail out before creating any IR.
+  SmallVector<int64_t> groupOf(expandedShape.size());
+  SmallVector<int64_t> staticProducts(reassociation.size(), 1);
+  for (int64_t g = 0, e = reassociation.size(); g < e; ++g) {
+    bool hasDynamic = false;
+    for (int64_t dim : reassociation[g]) {
+      groupOf[dim] = g;
+      if (!ShapedType::isDynamic(expandedShape[dim])) {
+        staticProducts[g] *= expandedShape[dim];
+        continue;
+      }
+      assert(!hasDynamic && "At most one dimension of a reassociation group "
+                            "may be dynamic in the result type.");
+      hasDynamic = true;
+    }
+    if (hasDynamic && staticProducts[g] == 0)
+      return failure();
+  }
+  SmallVector<OpFoldResult> resultDims;
+  for (size_t dim = 0, e = expandedShape.size(); dim < e; ++dim) {
+    if (!ShapedType::isDynamic(expandedShape[dim])) {
+      resultDims.push_back(
+          getAsIndexOpFoldResult(builder.getContext(), expandedShape[dim]));
+      continue;
+    }
+    int64_t g = groupOf[dim];
+    Value srcDimSize = builder.create<memref::DimOp>(loc, getSrc(), g);
+    Value divisor =
+        builder.create<arith::ConstantIndexOp>(loc, staticProducts[g]);
+    resultDims.push_back(
+        builder.create<arith::DivSIOp>(loc, srcDimSize, divisor).getResult());
+  }
+  reifiedReturnShapes = {resultDims};
+  return success();
+}
+
+/// Reifies the result shape from the source memref. A dynamic result extent
+/// is the product of all source extents in its reassociation group; unlike
+/// expand_shape, a group may hold several dynamic dims (e.g. `?x?` -> `?`).
+LogicalResult CollapseShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  Location loc = getLoc();
+  ArrayRef<int64_t> collapsedShape = getResultType().getShape();
+  ArrayRef<int64_t> expandedShape = getSrcType().getShape();
+  // Hoisted out of the loop; getReassociationIndices() returns by value.
+  auto reassociation = getReassociationIndices();
+  SmallVector<OpFoldResult> resultDims;
+  for (size_t dim = 0, e = collapsedShape.size(); dim < e; ++dim) {
+    if (!ShapedType::isDynamic(collapsedShape[dim])) {
+      resultDims.push_back(
+          getAsIndexOpFoldResult(builder.getContext(), collapsedShape[dim]));
+      continue;
+    }
+    // Dynamic extents come from memref.dim; static ones fold into `factor`.
+    Value dynamicProduct;
+    int64_t staticProduct = 1;
+    for (int64_t expandedDim : reassociation[dim]) {
+      if (!ShapedType::isDynamic(expandedShape[expandedDim])) {
+        staticProduct *= expandedShape[expandedDim];
+        continue;
+      }
+      Value dimSize = builder.create<memref::DimOp>(loc, getSrc(), expandedDim);
+      if (!dynamicProduct)
+        dynamicProduct = dimSize;
+      else
+        dynamicProduct =
+            builder.create<arith::MulIOp>(loc, dynamicProduct, dimSize);
+    }
+    assert(dynamicProduct && "No dynamic dimension in the reassociation group "
+                             "to match the dynamic collapsed dimension.");
+    Value factor = builder.create<arith::ConstantIndexOp>(loc, staticProduct);
+    resultDims.push_back(
+        builder.create<arith::MulIOp>(loc, dynamicProduct, factor).getResult());
+  }
+  reifiedReturnShapes = {resultDims};
+  return success();
+}
+
 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
 /// result and operand. Layout maps are verified separately.
 ///
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..fb0f9106e61bbf 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,35 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
   %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
   return %0 : index
 }
+
+// -----
+
+// Test case: memref.dim(memref.expand_shape) folds to memref.dim of the source (divisor 1 folds away).
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.collapse_shape) folds to memref.dim of the source (factor 1 folds away).
+// CHECK-LABEL: func @dim_of_memref_collapse_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<1x?x2x4xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 1
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<1x?x2x4xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_collapse_shape(%arg0: memref<1x?x2x4xi32>)
+    -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2, 3]]: memref<1x?x2x4xi32> into memref<?x8xi32>
+  %1 = memref.dim %0, %c0 : memref<?x8xi32>
+  return %1 : index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/89111


More information about the Mlir-commits mailing list