[mlir] [llvm] [mlir][mesh] Add spmdization pass (PR #80518)

Chengji Yao via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 11:09:41 PST 2024


================
@@ -29,4 +29,79 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+  let summary = "Partition a function into SPMD form.";
+  let description = [{
+    This pass is intended to run right after a pass that annotates the
+    function with shardings, such as the `ShardingPropagation` pass.
+    It operates on fully annotated IR.
+
+    Fully annotated IR requires that all ranked tensor operands, results, and
+    block arguments be annotated with the `mesh.shard` operation.
+
+    Every direct descendant operation in the function must either implement
+    `ShardingInterface` or have full replication sharding on all of its
+    ranked tensor operands and results.
+
+    The input IR must have sharding annotations such that each operation
+    that implements `ShardingInterface` can be handled during spmdization
+    by its `spmdize` method.
+    This can be achieved with the `ShardingPropagation` pass.
+
+    If the function has multiple terminating blocks,
+    it is the responsibility of whoever annotates the function with
+    shardings to make sure that all returns are consistent, that is,
+    that they have the same sharding.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+
+    func.func @f(
+      %arg0: tensor<2xi8>
+    ) -> tensor<2xi8> {
+      %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xi8>
+      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xi8>
+      return %4 : tensor<2xi8>
+    }
+    ```
+    Spmdizing the above would result in:
+    * Resharding the fully replicated input by splitting it along the only
+      tensor axis.
+    * Performing the element-wise `abs` operation on each device.
+    * Resharding back to full replication with an all-gather.
+
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+    func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
----------------
yaochengji wrote:

Oh, I meant the function type could change.

E.g. if the sharding of a func argument of type `tensor<2xi8>` is set to `<@mesh_1d, [[0]]>`, then after spmdization this argument could have type `tensor<1xi8>`.
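The shape arithmetic behind this example can be sketched as follows. This is a minimal illustration, not the mesh dialect's implementation: `local_shard_shape` is a hypothetical helper, `split_axes` mirrors the `[[0]]` part of the sharding attribute, and each split tensor dimension is assumed to divide evenly by the product of the sizes of the mesh axes it is split along (how the real pass treats uneven splits is not covered here). Splitting the size-2 tensor axis over the size-2 mesh axis leaves one element per device, hence `tensor<1xi8>`.

```python
from math import prod
from typing import Sequence


def local_shard_shape(
    global_shape: Sequence[int],
    mesh_shape: Sequence[int],
    split_axes: Sequence[Sequence[int]],
) -> list[int]:
    """Hypothetical helper: per-device shape of a tensor sharded over a mesh.

    split_axes[i] lists the mesh axes that tensor dimension i is split along,
    mirroring the `[[0]]` in `<@mesh_1d, [[0]]>`. Tensor dimensions beyond
    len(split_axes) are treated as replicated. Assumes even divisibility.
    """
    local = list(global_shape)
    for dim, mesh_axes in enumerate(split_axes):
        # An empty list of mesh axes means the dimension is not split (1 shard).
        num_shards = prod(mesh_shape[a] for a in mesh_axes)
        if local[dim] % num_shards != 0:
            raise ValueError(
                f"dim {dim} of size {local[dim]} is not divisible by {num_shards} shards"
            )
        local[dim] //= num_shards
    return local


# tensor<2xi8> with sharding <@mesh_1d, [[0]]> on mesh_1d(shape = 2)
print(local_shard_shape([2], [2], [[0]]))  # prints [1], i.e. tensor<1xi8>
# Full replication <@mesh_1d, [[]]> keeps the global shape.
print(local_shard_shape([2], [2], [[]]))  # prints [2], i.e. tensor<2xi8>
```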

https://github.com/llvm/llvm-project/pull/80518

