[Mlir-commits] [mlir] Let `memref.expand_shape` implement `ReifyRankedShapedTypeOpInterface` (PR #90975)

Benoit Jacob llvmlistbot at llvm.org
Fri May 3 08:31:11 PDT 2024


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/90975

This is a new take on #89111. Now that #90040 is merged, this has become trivial to implement. The added test shows the benefit: `memref.dim` of a `memref.expand_shape` result now folds naturally, without an ad-hoc folding rewrite.
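To illustrate what the new `reifyResultShapes` returns, here is a minimal standalone sketch of how a static shape with dynamic placeholders and a list of dynamic values merge into one mixed shape. It is a model in plain C++, not the MLIR implementation: `kDynamicSize`, `MixedDim`, and `mixShape` are hypothetical names invented for this sketch, standing in for MLIR's dynamic-size sentinel, `OpFoldResult`, and `getMixedValues` respectively, and SSA values are represented by their names as strings.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for MLIR's dynamic-size sentinel.
constexpr int64_t kDynamicSize = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for OpFoldResult: either a constant extent or the
// name of an SSA value that holds the extent at runtime.
using MixedDim = std::variant<int64_t, std::string>;

// Walks the static shape in order. Each dynamic entry consumes the next
// dynamic value; each static entry is kept as a constant.
// Throws std::out_of_range if there are fewer dynamic values than
// dynamic entries.
std::vector<MixedDim> mixShape(const std::vector<int64_t> &staticShape,
                               const std::vector<std::string> &dynamicValues) {
  std::vector<MixedDim> mixed;
  mixed.reserve(staticShape.size());
  std::size_t nextDynamic = 0;
  for (int64_t extent : staticShape) {
    if (extent == kDynamicSize)
      mixed.emplace_back(dynamicValues.at(nextDynamic++));
    else
      mixed.emplace_back(extent);
  }
  return mixed;
}
```

For the test added in this patch, the output shape is `[1, %s, 2, 4]`, so calling `mixShape({1, kDynamicSize, 2, 4}, {"%s"})` in this model yields the constant 1, the value `%s`, and the constants 2 and 4. Dimension 1 of the result is therefore `%s` itself, which is why `memref.dim %0, %c1` can be replaced by the `memref.dim` on `%arg0`.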

From d39f8d51c5da63417542d09f2345acf269e7b4b6 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Thu, 11 Apr 2024 13:41:41 -0400
Subject: [PATCH] reify

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td        |  3 ++-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp       |  7 +++++++
 mlir/test/Dialect/MemRef/resolve-dim-ops.mlir  | 18 ++++++++++++++++++
 3 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 14b8d95ea15b41..5738b6ca51c12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1578,7 +1578,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 }
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..393f73dc65cd8d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2079,6 +2079,13 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expand_shape");
 }
 
+LogicalResult ExpandShapeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
+  reifiedResultShapes = {
+      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
+  return success();
+}
+
 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
 /// result and operand. Layout maps are verified separately.
 ///
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 40f88de01b8bd7..85a4853972457c 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -53,3 +53,21 @@ func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
   %dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
   return %dim : index
 }
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %s = memref.dim %arg0, %c0 : memref<?x8xi32>
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, %s, 2, 4]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}


