[Mlir-commits] [mlir] [mlir][mesh] Add resharding spmdization on a 1D device mesh (PR #76179)
llvmlistbot at llvm.org
Thu Dec 21 12:31:18 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Boian Petkantchin (sogartar)
Changes:
The current implementation supports only sharding of tensor axes whose size is divisible by the size of the mesh axes they are sharded on (for example, a tensor axis of size 6 can be sharded across a mesh axis of size 2 or 3, but not across one of size 4).
---
Patch is 83.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76179.diff
15 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-1)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+3)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+47-1)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc (+621)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+35)
- (modified) mlir/include/mlir/Support/MathExtras.h (+11)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+160-53)
- (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+3)
- (added) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+639)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+96)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+49)
- (added) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+154)
- (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+1)
- (added) mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (+124)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
}
+
+def Mesh_MeshAxis : I<16>;
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
- tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>
];
+ let extraClassDeclaration = [{
+ bool operator==(::mlir::Attribute rhs) const;
+ bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
namespace mlir {
namespace mesh {
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
bool isReductionLoop(IteratorType iType);
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Get the shape of the cluster.";
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+
+ let results = (outs
+ Variadic<Index>:$result
+ );
+
+ let assemblyFormat = [{
+ $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ ];
+}
+
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Get the index of the current device along the specified mesh axes.";
+ let description = [{
+ It is used in the SPMD form of the IR.
+ The values in `axes` must be non-negative and less than the total number of mesh axes.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
);
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
new file mode 100644
index 00000000000000..181f07177e0af9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[1, 0]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 | |
++----+----+----+ |
+| 12 | 21 | 23 | |
++----+----+----+ ↓
+
+Swap contents on devices that have the same linear index in the two shardings.
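+
+As a quick check of the linear-index rule (a sketch; 0-based device
+coordinates assumed): under sharding [[0, 1]] device (m, n) has linear
+index m * 3 + n, and under [[1, 0]] it has linear index n * 2 + m.
+Device (0, 1) has linear index 1 in the first sharding and
+1 * 2 + 0 = 2 in the second, so it trades its element 12 (linear
+index 1) for element 13 (linear index 2), matching the diagrams.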
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[1]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+All-gather along mesh axis 0.
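+
+A sketch of this step in op form, in the notation of the all_to_all
+example further below (the mesh name @mesh_2x3 and the i8 element type
+are illustrative assumptions; per the diagrams, each device starts with
+a 1x1 shard and ends with a 2x1 shard):
+
+%1 = all_gather %0 on @mesh_2x3 mesh_axes = [0] gather_axis = 0 : tensor<1x1xi8> -> tensor<2x1xi8>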
+
+--------------------------------------------------------------
+
+Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+
+unsharded 2x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+
+sharded on a 2x3 mesh
+
+sharding = [[], [0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 14 | 15 | 16 | |
+| 24 | 25 | 26 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[], [0]]
+
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 | |
+| 21 22 23 | 21 22 23 | |
++----------+----------+ |
+| 14 15 16 | 14 15 16 | |
+| 24 25 26 | 24 25 26 | |
++----------+----------+ ↓
+
+All-gather along mesh axis 1.
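+
+A sketch in op form (mesh name @mesh_2x3 and i8 element type assumed;
+each 2x1 shard becomes a 2x3 shard by concatenating along tensor
+axis 1):
+
+%1 = all_gather %0 on @mesh_2x3 mesh_axes = [1] gather_axis = 1 : tensor<2x1xi8> -> tensor<2x3xi8>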
+
+--------------------------------------------------------------
+
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh contents:
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 | | |
+| 21 22 | 23 24 | | |
++-------+-------+ | |
+| 15 16 | 17 18 | | |
+| 25 26 | 27 28 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
+| 41 42 | 43 44 | |
++-------+-------+ |
+| 35 36 | 37 38 | |
+| 45 46 | 47 48 | |
++-------+-------+ ↓
+
+Transform into
+sharding = [[0], [2]]
+
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ | |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ ↓ |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ ↓
+
+This can't be done with just an all-gather along mesh axis 1.
+It can be handled by a sequence of resharding transformations (see the sketch below):
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
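+
+A sketch of the second step in op form (mesh name @mesh_2x2x2 and i8
+element type assumed): once the sharding is [[0], [2, 1]], each 2x2
+shard becomes a 2x4 shard via an all-gather along the dropped mesh
+axis:
+
+%1 = all_gather %0 on @mesh_2x2x2 mesh_axes = [1] gather_axis = 1 : tensor<2x2xi8> -> tensor<2x4xi8>
+
+The first step, [[0], [1, 2]] -> [[0], [2, 1]], is the same kind of
+device-pair content swap as in the first example above.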
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x3 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 | |
+| 21 22 | 23 24 | 25 26 | |
+| 31 32 | 33 34 | 35 36 | |
++-------+-------+-------+ |
+| 41 42 | 43 44 | 45 46 | |
+| 51 52 | 53 54 | 55 56 | |
+| 61 62 | 63 64 | 65 66 | |
++-------+-------+-------+ ↓
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 | |
+| 21 22 23 | 41 42 43 | 61 62 63 | |
++----------+----------+----------+ |
+| 14 15 16 | 34 35 36 | 54 55 56 | |
+| 24 25 26 | 44 45 46 | 64 65 66 | |
++----------+----------+----------+ ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+shard on 2x6 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 | |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 | |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 | |
++----+----+----+----+----+----+ |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 | |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 | |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 | |
++----+----+----+----+----+----+ ↓
+
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
++----------+----------+ |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
++==========+==========+ |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
++----------+----------+ |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
+
+M x N mesh.
+K x L tensor t.
+d(m, n) is the tensor shard on device (m, n).
+
+sharding = [[0], [1]]
+The tensor shard on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+
+substitute
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+
+For the inverse map we get
+t[i, j] -> d(
+ i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+ i % (K ceildiv M), j % (L ceildiv N)
+]
+
+Check:
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+ 13 floordiv (16 ceildiv 3),
+ 17 floordiv (23 ceildiv 4)
+)[
+ 13 % (16 ceildiv 3),
+ 17 % (23 ceildiv 4)
+]
+= d(
+ 13 floordiv 6,
+ 17 floordiv 6
+)[
+ 13 % 6,
+ 17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+ 2 * (16 ceildiv 3) + 1,
+ 2 * (23 ceildiv 4) + 5
+]
+= t[
+ 2 * 6 + 1,
+ 2 * 6 + 5
+]
+= t[13, 17]
+
+
+sharding = [[1], [0]]
+The tensor shard on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+
+substitute
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+
+For the inverse map we get
+t[i, j] -> d(
+ j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+ i % (K ceildiv N), j % (L ceildiv M)
+]
+
+Check:
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
+t[9, 19] = d(
+ 19 floordiv (14 ceildiv 5),
+ 9 floordiv (27 ceildiv 2)
+)[
+ 9 % (27 ceildiv 2),
+ 19 % (14 ceildiv 5)
+]
+= d(
+ 19 floordiv 3,
+ 9 floordiv 14
+)[
+ 9 % 14,
+ 19 % 3
+]
+= d(6, 0)[9, 1]
+= t[
+ 0 * (27 ceildiv 2) + 9,
+ 6 * (14 ceildiv 5) + 1
+]
+= t[
+ 0 * 14 + 9,
+ 6 * 3 + 1
+]
+= t[9, 19]
+
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) is the tensor shard on device (m, n) for sharding [[0], [1]].
+d2(m, n) is the tensor shard on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+ (n * (L ceildiv N) + l) floordiv (L ceildiv M),
+ (m * (K ceildiv M) + k) floordiv (K ceildiv N)
+)[
+ (m * (K ceildiv M) + k) % (K ceildiv N),
+ (n * (L ceildiv N) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
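+
+Check with K = L = 6, M = 2, N = 3 (the 6x6 tensor on a 2x3 mesh from
+the example further above) and m = 0, n = 1, k = 1, l = 1:
+d1(0, 1)[1, 1]
+= t[0 * 3 + 1, 1 * 2 + 1]
+= t[1, 3]
+= d2(3 floordiv 3, 1 floordiv 2)[1 % 2, 3 % 3]
+= d2(1, 0)[1, 0]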
+
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed (m, n, p, q), what is the range of (k, l, u, v)?
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+Device placement on a 2x3 mesh
+11 12 13 <- devices
+21 22 23
+
+sharding [[0], [1]]
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+transform to
+sharding [[1], [0]]
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 | |
++----+----+ |
+| 12 | 22 | |
++----+----+ |
+| 13 | 23 | |
++----+----+ ↓
+
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 11 | 12 | 13 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 21 | 22 | 23 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | |
++ 11 + 21 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 12 + 22 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 13 + 23 +
+| | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | | | |
++ 11 11 + 12 11 + 12 21 + 13 21 +
+| | | | |
++-----------------+--------+--------+-----------------+
+| 11 12 | 12 12 | 12 22 | 13 22 |
++-----------------+--------+--------+-----------------+
+| 21 12 | 22 12 | 22 22 | 23 22 |
++-----------------+--------+--------+-----------------+
+| | | | |
++ 21 13 + 22 13 + 22 23 + 23 23 +
+| | | | |
++-----------------+--------+--------+-----------------+
+
+If S and T are the source and target shard sizes along some tensor axis,
+the cut pattern repeats with a period of (S*T)/gcd(S, T) = lcm(S, T).
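+For example, with S = 3 and T = 2 the cut pattern repeats every
+(3 * 2) / gcd(3, 2) = 6 elements along that tensor axis.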
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 1D mesh of size 3.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 1D mesh of size 3
+
+sharding = [[0], []]
+
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 | |
+| 21 22 23 24 25 26 | |
++-------------------+ |
+| 31 32 33 34 35 36 | |
+| 41 42 43 44 45 46 | |
++-------------------+ |
+| 51 52 53 54 55 56 | |
+| 61 62 63 64 65 66 | |
++-------------------+ ↓
+
+transform to
+sharding = [[], [0]]
+
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
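+
+Reading the attributes of the all_to_all above: each device splits its
+2x6 shard into 3 equal pieces along tensor axis 1 (split_axis = 1),
+exchanges pieces with the other devices on mesh axis 0, and
+concatenates the received pieces along tensor axis 0 (concat_axis = 0),
+producing the 6x2 result shard.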
+
+--------------------------------------------------------------
+
+Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
+
+unsharded 4x4 tensor
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 | | |
+| 21 | 22 | | |
++----+----+ | |
+| 13 | 14 | | |
+| 23 | 24 | | |
++----+----+ ↓ |
++----+----+ |
+| 31 | 32 | |
+| 41...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/76179