[Mlir-commits] [mlir] Let `memref.{expand, collapse}_shape` implement `ReifyRankedShapedTypeOpInterface` (PR #89111)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 17 10:54:35 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benoit Jacob (bjacob)
<details>
<summary>Changes</summary>
In https://github.com/llvm/llvm-project/pull/88423 I came across a need for folding `memref.dim` into `memref.expand_shape` and it was (IMO rightly) suggested that the proper way to fix that was to implement `ReifyRankedShapedTypeOpInterface`. This PR does that for both `memref.{expand,collapse}_shape` to be consistent, as it would be surprising if these two ops weren't closely mirroring one another.
It would be good to carry on the `ReifyRankedShapedTypeOpInterface` migration, completing it for `memref` and `tensor` ops, and in particular for `tensor.{expand,collapse}_shape` so that they stay consistent with this PR (one could then drop some existing custom Fold patterns). However, I have heard of an ongoing project to generalize `expand_shape` by relaxing the requirement that at most one dimension in each reassociation group be dynamic. It would be wise to let that project complete first, as the `ReifyRankedShapedTypeOpInterface` implementation will otherwise entrench the current semantics.
FYI @<!-- -->qedawkins @<!-- -->MaheshRavishankar @<!-- -->Shukla-Gaurav
---
Full diff: https://github.com/llvm/llvm-project/pull/89111.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+5-2)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+90)
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+32)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8f6bff5809ca2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1546,8 +1546,11 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
//===----------------------------------------------------------------------===//
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
- MemRef_Op<mnemonic, !listconcat(traits,
- [Pure, ViewLikeOpInterface])>,
+ MemRef_Op<mnemonic, !listconcat(traits, [
+ Pure,
+ ViewLikeOpInterface,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>
+ ])>,
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..e2ce7d93d227cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -2079,6 +2080,95 @@ void ExpandShapeOp::getAsmResultNames(
setNameFn(getResult(), "expand_shape");
}
+LogicalResult ExpandShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ SmallVector<OpFoldResult> resultDims;
+ ArrayRef<int64_t> expandedShape = this->getResultType().getShape();
+ for (size_t expanded_dim = 0; expanded_dim < expandedShape.size();
+ ++expanded_dim) {
+ if (ShapedType::isDynamic(expandedShape[expanded_dim])) {
+ // Dynamic dimension case. Map expanded_dim to the corresponding
+ // collapsed dim. All other expanded dimensions corresponding to
+ // that collapsed dim must be static-size. Compute their product
+ // to divide the source dimension size by.
+ auto reassoc = this->getReassociationIndices();
+ for (size_t collapsed_dim = 0; collapsed_dim < reassoc.size();
+ ++collapsed_dim) {
+ ReassociationIndices associated_dims = reassoc[collapsed_dim];
+ bool found_expanded_dim = false;
+ int64_t other_associated_dims_product_size = 1;
+ for (size_t associated_dim : associated_dims) {
+ if (associated_dim == expanded_dim) {
+ found_expanded_dim = true;
+ } else {
+ assert(!ShapedType::isDynamic(expandedShape[associated_dim]) &&
+ "At most one dimension of a reassociation group may be "
+ "dynamic in the result type.");
+ other_associated_dims_product_size *= expandedShape[associated_dim];
+ }
+ }
+ if (!found_expanded_dim) {
+ continue;
+ }
+ Value srcDimSize =
+ builder.create<memref::DimOp>(getLoc(), getSrc(), collapsed_dim);
+ Value resultDimSize = builder.create<arith::DivSIOp>(
+ getLoc(), srcDimSize,
+ builder.create<arith::ConstantIndexOp>(
+ getLoc(), other_associated_dims_product_size));
+ resultDims.push_back(resultDimSize);
+ }
+ } else {
+ resultDims.push_back(getAsIndexOpFoldResult(builder.getContext(),
+ expandedShape[expanded_dim]));
+ }
+ }
+ reifiedReturnShapes = {resultDims};
+ return success();
+}
+
+LogicalResult CollapseShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ SmallVector<OpFoldResult> resultDims;
+ ArrayRef<int64_t> collapsedShape = this->getResultType().getShape();
+ ArrayRef<int64_t> expandedShape = this->getSrcType().getShape();
+ for (size_t collapsed_dim = 0; collapsed_dim < collapsedShape.size();
+ ++collapsed_dim) {
+ if (ShapedType::isDynamic(collapsedShape[collapsed_dim])) {
+ // Dynamic dimension case. All other expanded dimensions corresponding
+ // to that collapsed_dim must be static-size. Compute their product
+ // to multiply the dynamic source dimension size by.
+ auto reassoc = this->getReassociationIndices();
+ ReassociationIndices associated_dims = reassoc[collapsed_dim];
+ std::optional<size_t> expanded_dim;
+ int64_t other_associated_dims_product_size = 1;
+ for (size_t associated_dim : associated_dims) {
+ if (ShapedType::isDynamic(expandedShape[associated_dim])) {
+ assert(!expanded_dim && "At most one dimension of a reassociation "
+ "group may be dynamic in the result type.");
+ expanded_dim = associated_dim;
+ } else {
+ other_associated_dims_product_size *= expandedShape[associated_dim];
+ }
+ }
+ assert(expanded_dim && "No dynamic dimension in the reassociation group "
+ "to match the dynamic collapsed dimension.");
+ Value srcDimSize =
+ builder.create<memref::DimOp>(getLoc(), getSrc(), *expanded_dim);
+ Value resultDimSize = builder.create<arith::MulIOp>(
+ getLoc(), srcDimSize,
+ builder.create<arith::ConstantIndexOp>(
+ getLoc(), other_associated_dims_product_size));
+ resultDims.push_back(resultDimSize);
+ } else {
+ resultDims.push_back(getAsIndexOpFoldResult(
+ builder.getContext(), collapsedShape[collapsed_dim]));
+ }
+ }
+ reifiedReturnShapes = {resultDims};
+ return success();
+}
+
/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
/// result and operand. Layout maps are verified separately.
///
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..fb0f9106e61bbf 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,35 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
%0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
return %0 : index
}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+ -> index {
+ %c1 = arith.constant 1 : index
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+ %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.collapse_shape)
+// CHECK-LABEL: func @dim_of_memref_collapse_shape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<1x?x2x4xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 1
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<1x?x2x4xi32>
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_collapse_shape(%arg0: memref<1x?x2x4xi32>)
+ -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2, 3]]: memref<1x?x2x4xi32> into memref<?x8xi32>
+ %1 = memref.dim %0, %c0 : memref<?x8xi32>
+ return %1 : index
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/89111
More information about the Mlir-commits
mailing list