[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)
Boian Petkantchin
llvmlistbot at llvm.org
Mon Feb 5 09:01:04 PST 2024
================
@@ -0,0 +1,131 @@
+// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
----------------
sogartar wrote:
There is already a test that reshards from a partial axis to full replication at
https://github.com/llvm/llvm-project/blob/e2bb91b25c8740625fecd127c1d908a2fabd0102/mlir/test/Dialect/Mesh/resharding-spmdization.mlir#L145
The one you are proposing currently produces a suboptimal result.
```mlir
mesh.mesh @mesh_1d(shape = 2)
func.func @partial_axis_to_split_axis(
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}
```
This results in:
```mlir
mesh.mesh @mesh_1d(shape = 2)
func.func @partial_axis_to_split_axis(%arg0: tensor<10x14xf32>) -> tensor<5x14xf32> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %1 = mesh.process_multi_index on @mesh_1d axes = [0] : index
  %2 = mesh.mesh_shape @mesh_1d axes = [0] : index
  %3 = arith.remui %c10, %2 : index
  %4 = arith.cmpi eq, %3, %c0 : index
  cf.assert %4, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
  %5 = arith.divui %c10, %2 : index
  %6 = arith.muli %5, %1 : index
  %extracted_slice = tensor.extract_slice %0[%6, 0] [5, 14] [1, 1] : tensor<10x14xf32> to tensor<5x14xf32>
  return %extracted_slice : tensor<5x14xf32>
}
```
It first all-reduces to reach full replication, then slices to split the axis.
There are two approaches:
1. Match this pattern and rewrite it into a reduce-scatter. This would help if other cases also lead to this result. Unfortunately, it would be brittle with respect to future changes to the full-to-split slicing.
2. Detect the resharding pattern at the level of sharding annotations and produce the optimal reduce-scatter in the first place.
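The index arithmetic in the generated IR computes where each process's slice starts along the split axis. Here is a plain-Python mirror of those `arith` ops. The function name `split_slice` is hypothetical and not part of MLIR. It assumes a 1-D mesh and exact divisibility, which is what the `cf.assert` enforces.

```python
from typing import Tuple


def split_slice(dim_size: int, mesh_axis_size: int, process_index: int) -> Tuple[int, int]:
    """Hypothetical helper: (offset, size) of this process's slice on the split axis."""
    # arith.remui + arith.cmpi + cf.assert: only exact divisibility is supported.
    if dim_size % mesh_axis_size != 0:
        raise ValueError(
            "Sharding a tensor with axis size that is not exactly divisible "
            "by the mesh axis size is not supported."
        )
    # arith.divui: per-process shard size along the split axis.
    shard_size = dim_size // mesh_axis_size
    # arith.muli: start offset of this process's shard.
    offset = shard_size * process_index
    return offset, shard_size


# Axis 0 of tensor<10x14xf32> split over mesh_1d(shape = 2).
for idx in range(2):
    print(idx, split_slice(10, 2, idx))
```

Process 0 gets offset 0 and process 1 gets offset 5, each with 5 rows. That matches the `[5, 14]` sizes of the `tensor.extract_slice` above.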
I would like to add this optimization in another PR.
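To see why a reduce-scatter can replace the all-reduce plus slice, here is a small plain-Python simulation. It assumes a `sum` reduction on a 1-D mesh of two processes and a row count divisible by the mesh size. All function names are hypothetical illustrations, not mesh dialect ops. Both strategies leave every process with the same summed row block, but the reduce-scatter only ever reduces the rows a process keeps.

```python
from typing import List

Matrix = List[List[float]]


def add(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def all_reduce_then_slice(partials: List[Matrix]) -> List[Matrix]:
    """Current lowering: reduce the whole tensor everywhere, then keep one row block."""
    n = len(partials)
    total = partials[0]
    for p in partials[1:]:
        total = add(total, p)  # all_reduce: full sum, fully replicated
    shard = len(total) // n
    return [total[i * shard:(i + 1) * shard] for i in range(n)]  # extract_slice


def reduce_scatter(partials: List[Matrix]) -> List[Matrix]:
    """Alternative lowering: each process reduces only the row block it keeps."""
    n = len(partials)
    shard = len(partials[0]) // n
    out = []
    for i in range(n):
        rows = slice(i * shard, (i + 1) * shard)
        block = partials[0][rows]
        for p in partials[1:]:
            block = add(block, p[rows])
        out.append(block)
    return out


def make_partials(n: int, rows: int, cols: int) -> List[Matrix]:
    """Deterministic partial-sum inputs, one per process."""
    return [[[float(p * 1000 + r * cols + c) for c in range(cols)]
             for r in range(rows)] for p in range(n)]


partials = make_partials(2, 10, 14)  # tensor<10x14xf32>, partial = sum[0]
assert all_reduce_then_slice(partials) == reduce_scatter(partials)
local = reduce_scatter(partials)[0]
print(len(local), "x", len(local[0]))  # prints "5 x 14"
```

In common ring-based collective implementations, an all-reduce is itself built as a reduce-scatter followed by an all-gather. Emitting the reduce-scatter directly therefore skips the all-gather half, whose result the slice throws away.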
https://github.com/llvm/llvm-project/pull/80518