[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Hyunsung Lee
llvmlistbot at llvm.org
Sat Mar 29 16:40:23 PDT 2025
================
@@ -4951,7 +4993,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ packOp.getResult().setType(cast<ShapedType>(dest.getType()));
----------------
ita9naiwa wrote:
Now both the tensor version and the memref version work well.
I think it's a good idea to add these as test cases; what do you think?
```mlir
module {
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : memref<2x3xf32> -> memref<2x3xf32>
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : memref<2x3xf32> -> memref<2x3xf32>
return %packed : memref<2x3xf32>
}
}
```
is canonicalized to:
```mlir
module {
func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>) -> memref<2x3xf32> {
%unpack = linalg.unpack %arg0 inner_dims_pos = [] inner_tiles = [] into %arg0 : memref<2x3xf32> -> memref<2x3xf32>
return %arg0 : memref<2x3xf32>
}
}
```
whereas the tensor version
```mlir
module {
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : tensor<2x3xf32> -> tensor<2x3xf32>
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : tensor<2x3xf32> -> tensor<2x3xf32>
return %packed : tensor<2x3xf32>
}
}
```
reduces to the following (with `mlir-opt --canonicalize --cse cano-tensor.mlir`):
```mlir
module {
func.func @fold_pack_unpack_tensor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
return %arg0 : tensor<2x3xf32>
}
}
```
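As a rough sketch, the two cases above could become a single lit/FileCheck test along these lines. The RUN line, the `// -----` split, and the CHECK patterns are my assumptions, not yet checked against the conventions of the existing Linalg canonicalization tests. The memref CHECK lines simply encode the output observed above, where the `linalg.unpack` is kept and only the `linalg.pack` is folded.

```mlir
// RUN: mlir-opt %s -split-input-file -canonicalize -cse | FileCheck %s

// Tensor case: the unpack/pack pair folds away completely.
// CHECK-LABEL: func.func @fold_pack_unpack_tensor(
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xf32>)
// CHECK-NOT:     linalg.unpack
// CHECK-NOT:     linalg.pack
// CHECK:         return %[[ARG0]] : tensor<2x3xf32>
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
      into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
      into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  return %packed : tensor<2x3xf32>
}

// -----

// Memref case: the pack is folded away, but the unpack is still present
// (this mirrors the canonicalized output observed for the memref input).
// CHECK-LABEL: func.func @fold_pack_unpack_memref(
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: memref<2x3xf32>)
// CHECK:         linalg.unpack %[[ARG0]] inner_dims_pos = [] inner_tiles = [] into %[[ARG0]]
// CHECK-NOT:     linalg.pack
// CHECK:         return %[[ARG0]] : memref<2x3xf32>
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
      into %x : memref<2x3xf32> -> memref<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
      into %x : memref<2x3xf32> -> memref<2x3xf32>
  return %packed : memref<2x3xf32>
}
```

Keeping both cases in one file with `-split-input-file` makes the tensor/memref asymmetry (full fold vs. pack-only fold) easy to see side by side.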
https://github.com/llvm/llvm-project/pull/129036