[Mlir-commits] [mlir] b05a12e - Let `memref.expand_shape` implement `ReifyRankedShapedTypeOpInterface` (#90975)

llvmlistbot at llvm.org
Fri May 3 15:33:05 PDT 2024


Author: Benoit Jacob
Date: 2024-05-03T18:33:01-04:00
New Revision: b05a12e9d0b668effa50f566508daa279ec85c93

URL: https://github.com/llvm/llvm-project/commit/b05a12e9d0b668effa50f566508daa279ec85c93
DIFF: https://github.com/llvm/llvm-project/commit/b05a12e9d0b668effa50f566508daa279ec85c93.diff

LOG: Let `memref.expand_shape` implement `ReifyRankedShapedTypeOpInterface` (#90975)

This is a new take on #89111. Now that #90040 is merged, this has become
trivial to implement: the result shape can be read directly from the op's
static and dynamic output shape. The added test shows the benefit:
`memref.dim` of a `memref.expand_shape` result now folds naturally, without
a dedicated, ad-hoc folding rewrite.
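To illustrate why implementing reification is enough for this fold, here is a
minimal, self-contained C++ model. It is not MLIR code. It assumes a reified
shape is a list of per-dimension extents, each either a static constant or the
name of an existing SSA value, loosely mirroring `OpFoldResult`. The names
`FoldResult`, `ReifiedShape` and `resolveDim` are hypothetical. A `dim` query on
a result whose producer reified its shape is answered straight from that list.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: a dimension extent is either
// a static constant or the name of an existing SSA value (e.g. "%s").
using FoldResult = std::variant<int64_t, std::string>;

// Hypothetical stand-in for the reified shape of one result: one extent per
// dimension, in order.
using ReifiedShape = std::vector<FoldResult>;

// Models a dim-resolving rewrite. If the producer of `memref` reified its
// result shape, `dim(memref, index)` is replaced by the reified extent.
// Otherwise (no reified shape, or an out-of-range index) the dim is left
// alone, which is signalled by std::nullopt.
std::optional<FoldResult> resolveDim(
    const std::map<std::string, ReifiedShape> &reifiedShapes,
    const std::string &memref, std::size_t index) {
  auto it = reifiedShapes.find(memref);
  if (it == reifiedShapes.end() || index >= it->second.size())
    return std::nullopt;
  return it->second[index];
}
```

With the shape from the added test, `[1, %s, 2, 4]` recorded for `%0`,
`resolveDim(reified, "%0", 1)` yields the SSA name `%s`. Because `%s` is itself
`memref.dim %arg0, %c0`, that matches what the test's CHECK lines expect.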

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 14b8d95ea15b41..5738b6ca51c12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1578,7 +1578,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 }
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..393f73dc65cd8d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2079,6 +2079,13 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expand_shape");
 }
 
+LogicalResult ExpandShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+  reifiedResultShapes = {
+      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
+  return success();
+}
+
 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
 /// result and operand. Layout maps are verified separately.
 ///

diff  --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 40f88de01b8bd7..85a4853972457c 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -53,3 +53,21 @@ func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
   %dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
   return %dim : index
 }
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %s = memref.dim %arg0, %c0 : memref<?x8xi32>
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, %s, 2, 4]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}
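`ExpandShapeOp::reifyResultShapes` above builds the result shape with
`getMixedValues`, merging the static output shape, where dynamic extents are
marked by a sentinel, with the `output_shape` SSA operands. The sketch below
models that merge in plain C++. It assumes the sentinel is
`std::numeric_limits<int64_t>::min()`, the value MLIR uses for
`ShapedType::kDynamic`, and it represents SSA values by name. `FoldResult`,
`kDynamic` and `mixedValues` are hypothetical stand-ins, not MLIR's API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: a static extent or the name
// of an SSA value holding a dynamic extent.
using FoldResult = std::variant<int64_t, std::string>;

// Sentinel for a dynamic extent, chosen to mirror ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Models getMixedValues: each kDynamic entry in `staticValues` is replaced, in
// order, by the next entry of `dynamicValues`; every other entry is kept as a
// static constant. The asserts check that the two lists agree in count.
std::vector<FoldResult> mixedValues(const std::vector<int64_t> &staticValues,
                                    const std::vector<std::string> &dynamicValues) {
  std::vector<FoldResult> result;
  result.reserve(staticValues.size());
  std::size_t nextDynamic = 0;
  for (int64_t value : staticValues) {
    if (value == kDynamic) {
      assert(nextDynamic < dynamicValues.size() && "too few dynamic values");
      result.emplace_back(dynamicValues[nextDynamic++]);
    } else {
      result.emplace_back(value);
    }
  }
  assert(nextDynamic == dynamicValues.size() && "too many dynamic values");
  return result;
}
```

For the test's `output_shape [1, %s, 2, 4]`, the static list is
`{1, kDynamic, 2, 4}` and the dynamic list is `{"%s"}`, so dimension 1 of the
reified shape is `%s`. In MLIR itself the static entries become integer
attributes created through the `builder` argument rather than plain integers.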


        

