[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)

Han-Chung Wang llvmlistbot at llvm.org
Mon Mar 24 10:38:37 PDT 2025


================
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
----------------
hanhanW wrote:

In the traits above, we need to remove the `NoMemoryEffect` trait and implement the memory-effects interface the way other Linalg ops do, i.e., we need `DeclareOpInterfaceMethods<MemoryEffectsOpInterface>` for the pack/unpack ops. Here is the [implementation](https://github.com/iree-org/iree/blob/ce2585aa5db62151157e9af4b0449109e9335f4c/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp#L58-L78) in IREE; we should have something similar in `LinalgOps.cpp`.

Side note: other Linalg ops have a similar implementation based on the [`getGenericEffectsImpl` method](https://github.com/llvm/llvm-project/blob/8f3f93cd78cfbf1dea349be2eef98802da8ad929/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp#L1212-L1236).
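To make the rule concrete, here is a minimal sketch, assuming a toy standalone model rather than the real MLIR API: `EffectKind`, `Operand`, `EffectInstance` and `computeRelayoutEffects` are hypothetical stand-ins for `OpOperand`, `MemoryEffects::Read`/`Write` and a `getEffects` implementation. It only captures the general idea behind the linked code, roughly: tensor operands contribute no effects, while a memref source is read and a memref destination is written.

```cpp
#include <string>
#include <vector>

// Hypothetical, self-contained model of the memory-effect rule; these are
// NOT the MLIR types (real code works on OpOperand, MemRefType and
// MemoryEffects::Read / MemoryEffects::Write).
enum class EffectKind { Read, Write };

struct Operand {
  std::string name;
  bool isMemRef;  // true for memref<...>, false for tensor<...>
};

struct EffectInstance {
  EffectKind kind;
  std::string operand;
};

// Tensor operands have value semantics and report no effects, so the tensor
// form behaves as before. A memref source is read and a memref destination
// is written.
std::vector<EffectInstance> computeRelayoutEffects(const Operand &source,
                                                   const Operand &dest) {
  std::vector<EffectInstance> effects;
  if (source.isMemRef)
    effects.push_back({EffectKind::Read, source.name});
  if (dest.isMemRef)
    effects.push_back({EffectKind::Write, dest.name});
  return effects;
}
```

For the memref repro this yields a read of `%source` and a write of `%dest`; for the tensor form it yields nothing. Whether the destination should also be reported as read is a design choice left to the real implementation.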

Without that interface (i.e., with `NoMemoryEffect`), the memref version has no results and no declared side effects, so I think it would always be DCE-ed. E.g., if you run `mlir-opt --cse repro.mlir` on the input below, the pack op is not preserved, which is wrong.


```mlir
func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
  %dest = memref.alloc() : memref<8x16x8x32xf32>
  linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
  return %dest : memref<8x16x8x32xf32>
}
```

becomes

```mlir
module {
  func.func @pack_memref(%arg0: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
    %alloc = memref.alloc() : memref<8x16x8x32xf32>
    return %alloc : memref<8x16x8x32xf32>
  }
}
```
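Why the pack disappears can be sketched with another toy model, assuming a simplified version of the trivially-dead check that MLIR performs via `isOpTriviallyDead` (which CSE applies before simplifying an op): an op whose results are unused can be erased when every effect it reports is a read. `OpModel` and `isTriviallyDead` are hypothetical names, and the real check handles more cases.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical, simplified model of MLIR's trivially-dead check; the real
// logic also handles cases such as allocation effects on the op's own
// results and ops with recursive memory effects.
enum class EffectKind { Read, Write };

struct OpModel {
  bool hasResultUses;               // memref pack/unpack has no results at all
  std::vector<EffectKind> effects;  // empty == NoMemoryEffect
};

// An op may be erased when nothing uses its results and every reported
// effect is a pure read (an empty effect list trivially satisfies this).
bool isTriviallyDead(const OpModel &op) {
  if (op.hasResultUses)
    return false;
  return std::all_of(op.effects.begin(), op.effects.end(),
                     [](EffectKind e) { return e == EffectKind::Read; });
}
```

The memref `linalg.pack` in the repro has no results, so with `NoMemoryEffect` it is in the "no effects" case and gets erased; once it reports a write to `%dest`, the check fails and the op survives.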


https://github.com/llvm/llvm-project/pull/129036

