[Mlir-commits] [mlir] [mlir][mesh] Add resharding spmdization on a 1D device mesh (PR #76179)

llvmlistbot at llvm.org
Thu Dec 21 12:31:18 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Boian Petkantchin (sogartar)

Changes:

The current implementation supports only resharding of tensor axes whose size is divisible by the mesh axis size.

---

Patch is 83.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76179.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-1) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+3) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+47-1) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc (+621) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+35) 
- (modified) mlir/include/mlir/Support/MathExtras.h (+11) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+160-53) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+3) 
- (added) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+639) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+96) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+49) 
- (added) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+154) 
- (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (+124) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
+
+def Mesh_MeshAxis : I<16>;
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
 //===----------------------------------------------------------------------===//
 // Mesh Enums.
 //===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     }]>
   ];
 
+  let extraClassDeclaration = [{
+    bool operator==(::mlir::Attribute rhs) const;
+    bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+  }];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
 namespace mlir {
 namespace mesh {
 
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
 bool isReductionLoop(IteratorType iType);
 
 bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_TD
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the shape of the cluster.";
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+
+  let results = (outs
+    Variadic<Index>:$result
+  );
+
+  let assemblyFormat = [{
+    $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
   let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the index of current device along specified mesh axis.";
+  let description = [{
+    It is used in the SPMD form of the IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
       [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
   );
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
new file mode 100644
index 00000000000000..181f07177e0af9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[1, 0]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 |             |
++----+----+----+             |
+| 12 | 21 | 23 |             |
++----+----+----+             ↓
+
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[1]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+All-gather along mesh axis 0.
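+
+For illustration, this could be written as the following sketch (the i8
+element type and the shard shapes are assumptions for illustration, not
+taken from the patch):
+
+%1 = all_gather %0 on @mesh mesh_axes = [0] gather_axis = 0 : tensor<1x3xi8> -> tensor<2x3xi8>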
+
+--------------------------------------------------------------
+
+Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+
+unsharded 2x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+
+sharded on a 2x3 mesh
+
+sharding = [[], [0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 14 | 15 | 16 |             |
+| 24 | 25 | 26 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[], [0]]
+
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 |             |
+| 21 22 23 | 21 22 23 |             |
++----------+----------+             |
+| 14 15 16 | 14 15 16 |             |
+| 24 25 26 | 24 25 26 |             |
++----------+----------+             ↓
+
+All-gather along mesh axis 1.
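+
+With the 2x1 shards per device from the layout above, a sketch (i8 element
+type assumed for illustration) could be:
+
+%1 = all_gather %0 on @mesh mesh_axes = [1] gather_axis = 1 : tensor<2x1xi8> -> tensor<2x3xi8>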
+
+--------------------------------------------------------------
+
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh contents:
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 |             |             |
+| 21 22 | 23 24 |             |             |
++-------+-------+             |             |
+| 15 16 | 17 18 |             |             |
+| 25 26 | 27 28 |             |             |
++-------+-------+             ↓             |
++-------+-------+                           |
+| 31 32 | 33 34 |                           |
+| 41 42 | 43 44 |                           |
++-------+-------+                           |
+| 35 36 | 37 38 |                           |
+| 45 46 | 47 48 |                           |
++-------+-------+                           ↓
+
+Transform into
+sharding = [[0], [2]]
+
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             |             |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             ↓             |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           ↓
+
+Can't be done with just an all-gather along mesh axis 1.
+It can be handled by two resharding transformations (second step sketched below):
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
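+
+The second step, [[0], [2, 1]] -> [[0], [2]], is an all-gather along mesh
+axis 1. A sketch (i8 element type assumed; the 2x2 shard shape follows from
+the 4x8 tensor sharded as [[0], [2, 1]] on the 2x2x2 mesh):
+
+%1 = all_gather %0 on @mesh mesh_axes = [1] gather_axis = 1 : tensor<2x2xi8> -> tensor<2x4xi8>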
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x3 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 |             |
+| 21 22 | 23 24 | 25 26 |             |
+| 31 32 | 33 34 | 35 36 |             |
++-------+-------+-------+             |
+| 41 42 | 43 44 | 45 46 |             |
+| 51 52 | 53 54 | 55 56 |             |
+| 61 62 | 63 64 | 65 66 |             |
++-------+-------+-------+             ↓
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 |             |
+| 21 22 23 | 41 42 43 | 61 62 63 |             |
++----------+----------+----------+             |
+| 14 15 16 | 34 35 36 | 54 55 56 |             |
+| 24 25 26 | 44 45 46 | 64 65 66 |             |
++----------+----------+----------+             ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x6 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
++----+----+----+----+----+----+             |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
++----+----+----+----+----+----+             ↓
+
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
++----------+----------+             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
++==========+==========+             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
++----------+----------+             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on an MxN mesh.
+
+M x N mesh.
+K x L tensor t.
+d(m, n) denotes the tensor shard on device (m, n).
+
+sharding = [[0], [1]]
+Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+
+substitute
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+
+For the inverse map we get
+t[i, j] -> d(
+    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+    i % (K ceildiv M), j % (L ceildiv N)
+]
+
+Check:
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+    13 floordiv (16 ceildiv 3),
+    17 floordiv (23 ceildiv 4)
+)[
+    13 % (16 ceildiv 3),
+    17 % (23 ceildiv 4)
+]
+= d(
+    13 floordiv 6,
+    17 floordiv 6
+)[
+    13 % 6,
+    17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+    2 * (16 ceildiv 3) + 1,
+    2 * (23 ceildiv 4) + 5
+]
+= t[
+    2 * 6 + 1,
+    2 * 6 + 5
+]
+= t[13, 17]
+
+
+sharding = [[1], [0]]
+Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+
+substitute
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+
+For the inverse map we get
+t[i, j] -> d(
+    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+    i % (K ceildiv N), j % (L ceildiv M)
+]
+
+Check:
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
+t[9, 19] = d(
+    19 floordiv (14 ceildiv 5),
+    9 floordiv (27 ceildiv 2)
+)[
+    9 % (27 ceildiv 2),
+    19 % (14 ceildiv 5)
+]
+= d(
+    19 floordiv 3,
+    9 floordiv 14
+)[
+    9 % 14,
+    19 % 3
+]
+= d(6, 0)[9, 1]
+= t[
+    0 * (27 ceildiv 2) + 9,
+    6 * (14 ceildiv 5) + 1
+]
+= t[
+    0 * 14 + 9,
+    6 * 3 + 1
+]
+= t[9, 19]
+
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
+d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
+    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
+)[
+    (m * (K ceildiv M) + k) % (K ceildiv N),
+    (n * (L ceildiv N) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed (m, n, p, q), what is the range of (k, l, u, v)?
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+Device placement on a 2x3 mesh
+11 12 13  <- devices
+21 22 23
+
+sharding [[0], [1]]
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 |               |
++----+----+----+               |
+| 21 | 22 | 23 |               |
++----+----+----+               ↓
+
+transform to
+sharding [[1], [0]]
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 |               |
++----+----+               |
+| 12 | 22 |               |
++----+----+               |
+| 13 | 23 |               |
++----+----+               ↓
+
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       11        |        12       |        13       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       21        |        22       |        23       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          11              +               21         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          12              +               22         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          13              +               23         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     11  11      + 12  11 + 12  21 +     13  21      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+|     11  12      | 12  12 | 12  22 |     13  22      |
++-----------------+--------+--------+-----------------+
+|     21  12      | 22  12 | 22  22 |     23  22      |
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     21  13      + 22  13 + 22  23 +     23  23      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+
+If S and T are the source and target shard sizes along some tensor axis, the
+cut pattern repeats with period lcm(S, T) = (S*T)/gcd(S, T); e.g. S = 3, T = 2 gives a period of 6.
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a mesh of size 3.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a mesh of size 3
+
+sharding = [[0], []]
+
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 |             |
+| 21 22 23 24 25 26 |             |
++-------------------+             |
+| 31 32 33 34 35 36 |             |
+| 41 42 43 44 45 46 |             |
++-------------------+             |
+| 51 52 53 54 55 56 |             |
+| 61 62 63 64 65 66 |             |
++-------------------+             ↓
+
+transform to
+sharding = [[], [0]]
+
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+
+--------------------------------------------------------------
+
+Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
+
+unsharded 4x4 tensor
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 |             |             |
+| 21 | 22 |             |             |
++----+----+             |             |
+| 13 | 14 |             |             |
+| 23 | 24 |             |             |
++----+----+             ↓             |
++----+----+                           |
+| 31 | 32 |                           |
+| 41...
[truncated]

``````````
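The patch also adds two query ops used during spmdization, `mesh.cluster_shape` and `mesh.process_index`. Below is a hypothetical usage sketch that follows the assembly formats in the diff; the `@mesh0` cluster, its `dim_sizes`, and the function are made up for illustration:

```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 8])

func.func @example() -> (index, index, index) {
  // Sizes of the mesh along axes 0 and 1 (here 4 and 8).
  %shape:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
  // This device's coordinate along mesh axis 1, in [0, 8).
  %idx = mesh.process_index on @mesh0 axes = [1] : index
  return %shape#0, %shape#1, %idx : index, index, index
}
```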



https://github.com/llvm/llvm-project/pull/76179

