[llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)

Boian Petkantchin via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 09:01:04 PST 2024


================
@@ -0,0 +1,131 @@
+// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
----------------
sogartar wrote:

There is already a test that reshards from partial axis -> full replication at 
https://github.com/llvm/llvm-project/blob/e2bb91b25c8740625fecd127c1d908a2fabd0102/mlir/test/Dialect/Mesh/resharding-spmdization.mlir#L145

The one you are proposing currently produces a suboptimal result.
```mlir
mesh.mesh @mesh_1d(shape = 2)

func.func @partial_axis_to_split_axis(
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}
```
results in
```mlir
mesh.mesh @mesh_1d(shape = 2)
func.func @partial_axis_to_split_axis(%arg0: tensor<10x14xf32>) -> tensor<5x14xf32> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %1 = mesh.process_multi_index on @mesh_1d axes = [0] : index
  %2 = mesh.mesh_shape @mesh_1d axes = [0] : index
  %3 = arith.remui %c10, %2 : index
  %4 = arith.cmpi eq, %3, %c0 : index
  cf.assert %4, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
  %5 = arith.divui %c10, %2 : index
  %6 = arith.muli %5, %1 : index
  %extracted_slice = tensor.extract_slice %0[%6, 0] [5, 14] [1, 1] : tensor<10x14xf32> to tensor<5x14xf32>
  return %extracted_slice : tensor<5x14xf32>
}
```

It first all-reduces to get to full replication, then slices to split the axis.
There are two approaches:
1. Match this all-reduce + slice pattern and rewrite it to a reduce-scatter. This would also help in any other case that ends up producing the same pattern. Unfortunately, it would be brittle with respect to future changes in how the full-to-split slicing is emitted.
2. Detect the resharding pattern at the level of sharding annotations and produce the optimal reduce-scatter in the first place.
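
To illustrate why the two forms are interchangeable, here is a minimal sketch in plain Python (not MLIR, and not part of the pass). It simulates a 1D mesh of 2 devices, each holding a partial-sum `10x14` tensor, and checks that all-reduce followed by slicing gives every device the same shard as a direct reduce-scatter along axis 0. All function names are hypothetical and exist only for this illustration.

```python
# Hypothetical simulation: tensors are lists of rows, one tensor per device.

def all_reduce_sum(partials):
    """Sum the per-device tensors; every device receives the full result."""
    rows, cols = len(partials[0]), len(partials[0][0])
    total = [[sum(p[r][c] for p in partials) for c in range(cols)]
             for r in range(rows)]
    return [[row[:] for row in total] for _ in partials]

def slice_rows(tensor, device, num_devices):
    """Take this device's block of rows, mirroring the divui/muli offset math."""
    rows = len(tensor)
    # Same restriction as the cf.assert in the generated IR.
    assert rows % num_devices == 0, "axis size must be divisible by mesh axis size"
    chunk = rows // num_devices
    offset = chunk * device
    return tensor[offset:offset + chunk]

def all_reduce_then_slice(partials):
    """What the pass emits today: full replication, then split axis 0."""
    n = len(partials)
    replicated = all_reduce_sum(partials)
    return [slice_rows(replicated[d], d, n) for d in range(n)]

def reduce_scatter_sum(partials):
    """Reduce only the rows each device keeps; no full replication."""
    n = len(partials)
    result = []
    for d in range(n):
        shards = [slice_rows(p, d, n) for p in partials]
        rows, cols = len(shards[0]), len(shards[0][0])
        result.append([[sum(s[r][c] for s in shards) for c in range(cols)]
                       for r in range(rows)])
    return result

def make_partials(num_devices=2, rows=10, cols=14):
    """Arbitrary integer test data, distinct per device and per element."""
    return [[[d * 1000 + r * cols + c for c in range(cols)]
             for r in range(rows)]
            for d in range(num_devices)]

partials = make_partials()
via_all_reduce = all_reduce_then_slice(partials)
via_reduce_scatter = reduce_scatter_sum(partials)
assert via_all_reduce == via_reduce_scatter
```

Both paths leave each device with a `5x14` shard. The difference is that the second one never materializes the fully reduced `10x14` tensor on every device, which is the saving a real reduce-scatter collective is meant to provide.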

I would like to add this optimization in another PR.

https://github.com/llvm/llvm-project/pull/80518