[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)

Hyunsung Lee llvmlistbot at llvm.org
Sat Mar 29 16:40:23 PDT 2025


================
@@ -4951,7 +4993,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.modifyOpInPlace(packOp, [&] {
       packOp.getSourceMutable().assign(source);
       packOp.getDestMutable().assign(dest);
-      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+      packOp.getResult().setType(cast<ShapedType>(dest.getType()));
----------------
ita9naiwa wrote:

Both the tensor version and the memref version now work.

I think it would be a good idea to add these as test cases; what do you think?

```mlir
module {
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
               into %x : memref<2x3xf32> -> memref<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
             into %x : memref<2x3xf32> -> memref<2x3xf32>
  return %packed : memref<2x3xf32>
}
}
```
is canonicalized into
```mlir
module {
  func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>) -> memref<2x3xf32> {
    %unpack = linalg.unpack %arg0 inner_dims_pos = [] inner_tiles = [] into %arg0 : memref<2x3xf32> -> memref<2x3xf32>
    return %arg0 : memref<2x3xf32>
  }
}
```


```mlir
module {
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
             into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
             into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  return %packed : tensor<2x3xf32>
}
}
```
reduces, when run through `mlir-opt --canonicalize --cse cano-tensor.mlir`, to
```mlir
module {
  func.func @fold_pack_unpack_tensor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
    return %arg0 : tensor<2x3xf32>
  }
}
```
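
As a rough sketch, the two cases above could be turned into a lit test along these lines. The `RUN` line and the `CHECK` patterns are assumptions derived from the outputs shown above, not taken from the PR, and would need adjusting to match the conventions of the existing Linalg canonicalization tests.

```mlir
// RUN: mlir-opt %s --canonicalize --cse | FileCheck %s

// Tensor case: the unpack/pack pair is expected to fold away entirely.
// CHECK-LABEL: func.func @fold_pack_unpack_tensor
// CHECK-SAME:    (%[[ARG0:.+]]: tensor<2x3xf32>)
// CHECK-NOT:     linalg.unpack
// CHECK-NOT:     linalg.pack
// CHECK:         return %[[ARG0]] : tensor<2x3xf32>
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
              into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
            into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  return %packed : tensor<2x3xf32>
}

// Memref case: per the output above, the pack is removed but the unpack remains.
// CHECK-LABEL: func.func @fold_pack_unpack_memref
// CHECK-SAME:    (%[[ARG0:.+]]: memref<2x3xf32>)
// CHECK:         linalg.unpack %[[ARG0]]
// CHECK-NOT:     linalg.pack
// CHECK:         return %[[ARG0]] : memref<2x3xf32>
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
              into %x : memref<2x3xf32> -> memref<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
            into %x : memref<2x3xf32> -> memref<2x3xf32>
  return %packed : memref<2x3xf32>
}
```

The `CHECK-NOT` lines pin down which ops are expected to disappear, so a regression in either folding path would show up as a test failure.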

https://github.com/llvm/llvm-project/pull/129036

