[Mlir-commits] [mlir] [mlir][memref] Remove incorrect `memref.transpose` fold (PR #79809)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Jan 29 06:54:49 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/79809
>From d893252252d89664633f10c9960249d343bc4ea7 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 29 Jan 2024 11:02:39 +0000
Subject: [PATCH 1/2] [mlir][memref] Remove incorrect `memref.transpose` fold
This fold absorbed a producing `memref.cast` into `memref.transpose` without
updating the transpose's result type. Whenever the cast source had static
dimensions, the result was IR that failed verification.
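For example:
```mlir
#transpose_map = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
%cast = memref.cast %0 : memref<?x4xf32> to memref<?x?xf32>
%transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #transpose_map>
```
would fold to:
```mlir
// Fails verification: the operand is now the cast's source, but the result
// type was not updated to match it.
%transpose = memref.transpose %0 (d0, d1) -> (d1, d0) : memref<?x4xf32> to memref<?x?xf32, #transpose_map>
```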
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 --
mlir/test/Dialect/MemRef/canonicalize.mlir | 13 +++++++++++++
2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e0..8b5765b7f8dba2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3227,8 +3227,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
// result types are identical already.
if (getPermutation().isIdentity() && getType() == getIn().getType())
return getIn();
- if (succeeded(foldMemRefCast(*this)))
- return getResult();
// Fold two consecutive memref.transpose Ops into one by composing their
// permutation maps.
if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index eccfc485b2034e..61790bbc8a96ed 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1023,3 +1023,16 @@ func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3
// CHECK: return %[[arg0]]
return %1 : memref<1x2x3x4x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @cannot_fold_transpose_cast(
+// CHECK-SAME: %[[arg0:.*]]: memref<?x4xf32, strided<[?, ?], offset: ?>>
+func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %cast = memref.cast %arg0 : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: return %[[TRANSPOSE]]
+ return %transpose : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
>From 3f56308012d97eeb113955377e10ab9ffc2a7964 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 29 Jan 2024 14:53:41 +0000
Subject: [PATCH 2/2] Slightly simplify test
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 61790bbc8a96ed..993ef32edc9d44 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1026,13 +1026,15 @@ func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3
// -----
+#transpose_map = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
+
// CHECK-LABEL: func @cannot_fold_transpose_cast(
-// CHECK-SAME: %[[arg0:.*]]: memref<?x4xf32, strided<[?, ?], offset: ?>>
-func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
- // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- %cast = memref.cast %arg0 : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[arg0:.*]]: memref<?x4xf32>
+func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32>) -> memref<?x?xf32, #transpose_map> {
+ // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32> to memref<?x?xf32>
+ %cast = memref.cast %arg0 : memref<?x4xf32> to memref<?x?xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #{{.*}}>
+ %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #transpose_map>
// CHECK: return %[[TRANSPOSE]]
- return %transpose : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %transpose : memref<?x?xf32, #transpose_map>
}
More information about the Mlir-commits mailing list