[Mlir-commits] [mlir] b05a12e - Let `memref.expand_shape` implement `ReifyRankedShapedTypeOpInterface` (#90975)
llvmlistbot at llvm.org
Fri May 3 15:33:05 PDT 2024
Author: Benoit Jacob
Date: 2024-05-03T18:33:01-04:00
New Revision: b05a12e9d0b668effa50f566508daa279ec85c93
URL: https://github.com/llvm/llvm-project/commit/b05a12e9d0b668effa50f566508daa279ec85c93
DIFF: https://github.com/llvm/llvm-project/commit/b05a12e9d0b668effa50f566508daa279ec85c93.diff
LOG: Let `memref.expand_shape` implement `ReifyRankedShapedTypeOpInterface` (#90975)
This is a new take on #89111. Now that #90040 is merged, this has become
trivial to implement. The added test shows the kind of benefit that we
get from this: now dim-of-expand-shape naturally folds without us
needing to implement an ad-hoc folding rewrite.
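The point of the log message is that dim resolution becomes generic once an op can reify its result shape. The following is a minimal sketch in plain C++. It is a toy model, not MLIR code: `FoldResult`, `ReifiesShape`, `ToyExpandShape`, and `resolveDim` are hypothetical stand-ins for `OpFoldResult`, `ReifyRankedShapedTypeOpInterface`, `memref.expand_shape`, and the dim-resolution rewrite. It shows that the resolution step only asks the op for its reified shape, so it needs no expand_shape-specific rule.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: either a static size or the
// name of an SSA value that holds a dynamic size.
using FoldResult = std::variant<std::int64_t, std::string>;

// Hypothetical stand-in for ReifyRankedShapedTypeOpInterface: an op that can
// describe every result size in terms of values already present in the IR.
struct ReifiesShape {
  virtual ~ReifiesShape() = default;
  virtual std::vector<std::vector<FoldResult>> reifyResultShapes() const = 0;
};

// Toy expand_shape whose output_shape list already names every result size.
struct ToyExpandShape : ReifiesShape {
  std::vector<FoldResult> outputShape;
  explicit ToyExpandShape(std::vector<FoldResult> shape)
      : outputShape(std::move(shape)) {}
  std::vector<std::vector<FoldResult>> reifyResultShapes() const override {
    return {outputShape};  // One result, hence one list of sizes.
  }
};

// Generic "resolve dim" step: works for any op implementing the interface,
// so no per-op folding rewrite is needed.
FoldResult resolveDim(const ReifiesShape &op, unsigned resultIdx,
                      unsigned dim) {
  std::vector<std::vector<FoldResult>> shapes = op.reifyResultShapes();
  assert(resultIdx < shapes.size() && dim < shapes[resultIdx].size());
  return shapes[resultIdx][dim];
}
```

For a toy op built from `{1, "%s", 2, 4}`, `resolveDim(op, 0, 1)` yields the SSA name `"%s"` and `resolveDim(op, 0, 3)` yields the constant `4`. Nothing in `resolveDim` refers to expand_shape.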
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 14b8d95ea15b41..5738b6ca51c12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1578,7 +1578,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
}
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..393f73dc65cd8d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2079,6 +2079,13 @@ void ExpandShapeOp::getAsmResultNames(
setNameFn(getResult(), "expand_shape");
}
+LogicalResult ExpandShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+ reifiedResultShapes = {
+ getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
+ return success();
+}
+
/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
/// result and operand. Layout maps are verified separately.
///
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 40f88de01b8bd7..85a4853972457c 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -53,3 +53,21 @@ func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
%dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
return %dim : index
}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+ -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %s = memref.dim %arg0, %c0 : memref<?x8xi32>
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, %s, 2, 4]: memref<?x8xi32> into memref<1x?x2x4xi32>
+ %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+ return %1 : index
+}
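The new `reifyResultShapes` builds the result sizes with `getMixedValues(getStaticOutputShape(), getOutputShape(), builder)`. That call interleaves the static sizes with the dynamic `output_shape` operands, in order. Below is a minimal plain C++ sketch of that merge, assuming the static list marks dynamic entries with a sentinel. `kDynamic` stands in for `ShapedType::kDynamic`, and `mixValues` and `FoldResult` are hypothetical names, not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: a static size or an SSA name.
using FoldResult = std::variant<std::int64_t, std::string>;

// Sentinel marking a dynamic size (stand-in for ShapedType::kDynamic).
constexpr std::int64_t kDynamic = std::numeric_limits<std::int64_t>::min();

// Toy model of getMixedValues: walk the static sizes and replace each
// kDynamic entry with the next dynamic SSA value, in order.
std::vector<FoldResult> mixValues(const std::vector<std::int64_t> &staticSizes,
                                  const std::vector<std::string> &dynamicSizes) {
  std::vector<FoldResult> mixed;
  mixed.reserve(staticSizes.size());
  std::size_t nextDynamic = 0;
  for (std::int64_t size : staticSizes) {
    if (size == kDynamic) {
      assert(nextDynamic < dynamicSizes.size() && "too few dynamic sizes");
      mixed.push_back(dynamicSizes[nextDynamic++]);
    } else {
      mixed.push_back(size);
    }
  }
  assert(nextDynamic == dynamicSizes.size() && "too many dynamic sizes");
  return mixed;
}
```

In the test above, the static sizes are `[1, dynamic, 2, 4]` and the only dynamic operand is `%s`, so the reified shape is `[1, %s, 2, 4]`. `memref.dim %0, %c1` therefore resolves to `%s`, which is itself `memref.dim %arg0, %c0`. This is what the CHECK lines expect.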