[Mlir-commits] [mlir] [mlir][memref] Remove incorrect `memref.transpose` fold (PR #79809)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 29 03:26:26 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-memref

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This fold replaced `memref.cast` operands of `memref.transpose` with the cast's source but did not update the transpose's result type. When the cast source had static dimensions, the resulting IR failed to verify.

For example:

```mlir
%cast = memref.cast %0 : memref<?x4xf32> to memref<?x?xf32>
%transpose = memref.transpose %cast : memref<?x?xf32> to memref<?x?xf32>
```

would fold to:

```mlir
// Fails verification:
%transpose = memref.transpose %0 : memref<?x4xf32> to memref<?x?xf32>
```

---
Full diff: https://github.com/llvm/llvm-project/pull/79809.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (-2) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+13) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e01..8b5765b7f8dba2a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3227,8 +3227,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   // result types are identical already.
   if (getPermutation().isIdentity() && getType() == getIn().getType())
     return getIn();
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
   // Fold two consecutive memref.transpose Ops into one by composing their
   // permutation maps.
   if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index eccfc485b2034e4..61790bbc8a96ed6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1023,3 +1023,16 @@ func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3
   // CHECK: return %[[arg0]]
   return %1 : memref<1x2x3x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @cannot_fold_transpose_cast(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?x4xf32, strided<[?, ?], offset: ?>>
+func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+    // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    %cast = memref.cast %arg0 : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK: return %[[TRANSPOSE]]
+    return %transpose : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/79809

