[all-commits] [llvm/llvm-project] baabcb: [mlir][mesh] Shardingcontrol (#102598)
Frank Schlimbach via All-commits
all-commits at lists.llvm.org
Mon Aug 12 04:21:19 PDT 2024
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: baabcb28983edf8f20e39b89e2b1745412073b44
https://github.com/llvm/llvm-project/commit/baabcb28983edf8f20e39b89e2b1745412073b44
Author: Frank Schlimbach <frank.schlimbach at intel.com>
Date: 2024-08-12 (Mon, 12 Aug 2024)
Changed paths:
M mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
M mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
M mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
M mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
A mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h
M mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
M mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
M mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
M mlir/include/mlir/InitAllDialects.h
M mlir/include/mlir/Interfaces/InferTypeOpInterface.h
M mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
M mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
M mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
M mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
A mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
M mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
M mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
M mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
M mlir/test/Dialect/Linalg/mesh-spmdization.mlir
M mlir/test/Dialect/Mesh/canonicalization.mlir
M mlir/test/Dialect/Mesh/invalid.mlir
M mlir/test/Dialect/Mesh/ops.mlir
M mlir/test/Dialect/Mesh/resharding-spmdization.mlir
M mlir/test/Dialect/Mesh/sharding-propagation.mlir
M mlir/test/Dialect/Mesh/simplifications.mlir
M mlir/test/Dialect/Mesh/spmdization.mlir
A mlir/test/Dialect/Tensor/mesh-spmdization.mlir
M mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
Log Message:
-----------
[mlir][mesh] Shardingcontrol (#102598)
This is a fixed copy of #98145 (necessary after it got reverted).
@sogartar @yaochengji
This PR adds the following to #98145:
- `UpdateHaloOp` accepts a `memref` (instead of a tensor) and not
returning a result to clarify its inplace-semantics
- `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per
tensor/memref-axis (similar to `mesh.sharding`)
- The implementation of `Shardinginterface` for tensor operation
(`tensor.empty` for now) moved from the tensor library to the mesh
interface library. `spmdize` uses features from `mesh` dialect.
@rengolin agreed that `tensor` should not depend on `mesh` so this
functionality cannot live in a `tensor`s lib. The unfulfilled dependency
caused the issues leading to reverting #98145. Such cases are generally
possible and might lead to re-considering the current structure (like
for tosa ops).
- rebased onto latest main
--------------------------
Replacing `#mesh.sharding` attribute with operation `mesh.sharding`
- extended semantics now allow providing optional `halo_sizes` and
`sharded_dims_sizes`
- internally a sharding is represented as a non-IR class
`mesh::MeshSharding`
What previously was
```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```
is now
```mlir
%sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```
and allows additional annotations to control the shard sizes:
```mlir
mesh.mesh @mesh0 (shape = 4)
%sharding0 = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```
- `mesh.shard` op accepts additional optional attribute `force`, useful
for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting on `sharded_dims_sizes` and
`halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as a spmdized target for
shardings with `halo_sizes`
---------
Co-authored-by: frank.schlimbach <fschlimb at smtp.igk.intel.com>
Co-authored-by: Jie Fu <jiefu at tencent.com>
To unsubscribe from these emails, change your notification settings at https://github.com/llvm/llvm-project/settings/notifications
More information about the All-commits
mailing list