[Mlir-commits] [mlir] [llvm] [mlir][mesh] Add spmdization pass (PR #80518)
Boian Petkantchin
llvmlistbot at llvm.org
Mon Feb 5 14:05:08 PST 2024
================
@@ -29,4 +29,79 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+ let summary = "Partition a function into SPMD form.";
+ let description = [{
+ This pass fits in right after a pass that annotates the function with
+ shardings like the `ShardingPropagation` pass.
+ It operates on a fully annotated IR.
+
+ A fully annotated IR requires that all ranked tensor operands, results and
+ block arguments are annotated with the `mesh.shard` operation.
+
+ All direct descendant operations in the function must implement the
+ `ShardingInterface` interface or all their ranked tensor operands and
+ results must have full replication sharding.
+
+ The input IR must have sharding annotations such that each operation
+ that implements `ShardingInterface` can handle them during spmdization with
+ its `spmdize` method.
+ This can be achieved with the `ShardingPropagation` pass.
+
+ If the function has multiple terminating blocks,
+ it is the responsibility of the one who annotates the function with
+ shardings to make sure that all returns are consistent, that is,
+ have the same sharding.
+
+ Example:
+ ```mlir
+ mesh.mesh @mesh_1d(shape = 2)
+
+ func.func @f(
+ %arg0: tensor<2xi8>
+ // CHECK-SAME: -> tensor<2xi8> {
+ ) -> tensor<2xi8> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ return %4 : tensor<2xi8>
+ }
+ ```
+ Spmdizing the above would result in:
+ * Resharding the fully replicated input by splitting it along the only
+ tensor axis.
+ * Performing the element-wise `abs` operation on each device.
+ * Resharding back to full replication with an all-gather.
+
+ ```mlir
+ mesh.mesh @mesh_1d(shape = 2)
+ func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
----------------
sogartar wrote:
I see what you mean. I changed the example.
https://github.com/llvm/llvm-project/pull/80518