[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Mar 24 10:38:37 PDT 2025
================
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
----------------
hanhanW wrote:
In the traits above, we need to remove the `NoMemoryEffect` trait and implement the memory-effects interface as other Linalg ops do, i.e., pack/unpack need `DeclareOpInterfaceMethods<MemoryEffectsOpInterface>`. Here is the [implementation](https://github.com/iree-org/iree/blob/ce2585aa5db62151157e9af4b0449109e9335f4c/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp#L58-L78) in IREE. We should have something similar in `LinalgOps.cpp`.
Side note: other Linalg ops have a similar implementation based on the [`getGenericEffectsImpl` method](https://github.com/llvm/llvm-project/blob/8f3f93cd78cfbf1dea349be2eef98802da8ad929/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp#L1212-L1236).
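Without the interface (i.e., if the ops keep `NoMemoryEffect`), I think the memref version will always be DCE-ed: in that form the op has no results, so an op that claims to have no side effects is trivially dead. E.g., if you run `mlir-opt --cse repro.mlir` on the input below, the pack op is not preserved, which is wrong.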
```mlir
func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
  %dest = memref.alloc() : memref<8x16x8x32xf32>
  linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
  return %dest : memref<8x16x8x32xf32>
}
```
becomes
```mlir
module {
  func.func @pack_memref(%arg0: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
    %alloc = memref.alloc() : memref<8x16x8x32xf32>
    return %alloc : memref<8x16x8x32xf32>
  }
}
```
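To make the suggested rule concrete, here is a minimal, self-contained C++ model. It is not MLIR's API: `Operand`, `Effect`, `getRelayoutEffects`, and `isTriviallyDead` are hypothetical stand-ins, and reporting both a read and a write on the destination is a conservative assumption in the spirit of the linked implementations. It shows why declaring effects only for memref operands keeps the memref form of pack/unpack alive while the tensor form stays side-effect free.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR concepts; NOT the MLIR C++ API.
enum class TypeKind { Tensor, MemRef };
enum class EffectKind { Read, Write };

struct Operand {
  std::string name;
  TypeKind type;
};

struct Effect {
  EffectKind kind;
  std::string operand;
};

// Hypothetical effects for a relayout op (pack/unpack): only buffer (memref)
// operands carry memory effects. The source is read; the destination is
// conservatively read and written. Tensor operands contribute nothing, so
// the tensor form remains pure.
std::vector<Effect> getRelayoutEffects(const Operand &source,
                                       const Operand &dest) {
  std::vector<Effect> effects;
  if (source.type == TypeKind::MemRef)
    effects.push_back({EffectKind::Read, source.name});
  if (dest.type == TypeKind::MemRef) {
    effects.push_back({EffectKind::Read, dest.name});
    effects.push_back({EffectKind::Write, dest.name});
  }
  return effects;
}

// Simplified dead-op check: an op whose results are all unused and that
// writes no memory can be erased.
bool isTriviallyDead(unsigned numUsedResults,
                     const std::vector<Effect> &effects) {
  bool writesMemory =
      std::any_of(effects.begin(), effects.end(), [](const Effect &e) {
        return e.kind == EffectKind::Write;
      });
  return numUsedResults == 0 && !writesMemory;
}
```

For `@pack_memref` above, both operands are memrefs and the op has no results, but the effect list contains a write to `%dest`, so the op is not trivially dead. With a blanket `NoMemoryEffect` the list would be empty and the op would be erased, which matches the repro output.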
https://github.com/llvm/llvm-project/pull/129036