[Mlir-commits] [mlir] [mlir][mesh] Add resharding spmdization on a 1D device mesh (PR #76179)
Boian Petkantchin
llvmlistbot at llvm.org
Tue Jan 2 09:29:23 PST 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/76179
>From 28111bed8dbc8bf057070baa56edfb530728dcb1 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 6 Dec 2023 07:58:53 -0800
Subject: [PATCH 1/5] [mlir][mesh] Add resharding spmdization on a 1D device
mesh
The current implementation supports only sharding of tensor axes that have size
divisible by the mesh axis size.
---
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 11 +-
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 3 +
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 48 +-
.../Mesh/Transforms/ReshardingSpmdizationDoc | 621 +++++++++++++++++
.../Dialect/Mesh/Transforms/Spmdization.h | 35 +
mlir/include/mlir/Support/MathExtras.h | 11 +
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 213 ++++--
.../Dialect/Mesh/Transforms/CMakeLists.txt | 3 +
.../Dialect/Mesh/Transforms/Spmdization.cpp | 639 ++++++++++++++++++
mlir/test/Dialect/Mesh/invalid.mlir | 96 +++
mlir/test/Dialect/Mesh/ops.mlir | 49 ++
.../Dialect/Mesh/resharding-spmdization.mlir | 154 +++++
mlir/test/lib/Dialect/Mesh/CMakeLists.txt | 1 +
.../Mesh/TestReshardingSpmdization.cpp | 124 ++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
15 files changed, 1955 insertions(+), 55 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
create mode 100644 mlir/test/Dialect/Mesh/resharding-spmdization.mlir
create mode 100644 mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
}
+
+// A single mesh axis index, stored as a 16-bit integer.
+def Mesh_MeshAxis : I<16>;
+// A list of mesh axis indices, stored as a dense array of `int16_t` (`i16`)
+// elements; used for the `axes` / `mesh_axes` attributes of mesh ops.
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
- tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>
];
+ // Custom equality: two shardings are equal when they refer to the same
+ // cluster and partial axes, have the same partial type (checked only when
+ // partial axes are present), and have split axes that agree after ignoring
+ // trailing empty entries.
+ let extraClassDeclaration = [{
+ bool operator==(::mlir::Attribute rhs) const;
+ bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
namespace mlir {
namespace mesh {
+// Index of a single mesh axis; matches the `i16` storage of `Mesh_MeshAxis`.
+using MeshAxis = int16_t;
+// Attribute holding a list of mesh axis indices (see `Mesh_MeshAxesAttr`).
+using MeshAxesAttr = DenseI16ArrayAttr;
+
bool isReductionLoop(IteratorType iType);
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the shape of the cluster.";
+  let description = [{
+    Returns the size of each requested mesh axis of the cluster `mesh`, one
+    `index` result per axis in `axes`.
+    If `axes` is empty, one result is returned for every mesh axis.
+    The `axes` must be unique, non-negative and less than the total number of
+    mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+
+  let results = (outs
+    Variadic<Index>:$result
+  );
+
+  let assemblyFormat = [{
+    $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the index of the current device along the specified mesh axes.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    Returns one `index` result per axis in `axes`; if `axes` is empty, one
+    result is returned for every mesh axis.
+    The `axes` must be unique, non-negative and less than the total number of
+    mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+ // Use the dialect-wide mesh axes attribute (i16 elements) for consistency.
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
);
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
new file mode 100644
index 00000000000000..181f07177e0af9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[1, 0]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 | |
++----+----+----+ |
+| 12 | 21 | 23 | |
++----+----+----+ ↓
+
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[1]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+All-gather along mesh axis 0.
+
+--------------------------------------------------------------
+
+Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+
+unsharded 2x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+
+sharded on a 2x3 mesh
+
+sharding = [[], [0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 14 | 15 | 16 | |
+| 24 | 25 | 26 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[], [0]]
+
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 | |
+| 21 22 23 | 21 22 23 | |
++----------+----------+ |
+| 14 15 16 | 14 15 16 | |
+| 24 25 26 | 24 25 26 | |
++----------+----------+ ↓
+
+all-gather along mesh axis 1
+
+--------------------------------------------------------------
+
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh contents:
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 | | |
+| 21 22 | 23 24 | | |
++-------+-------+ | |
+| 15 16 | 17 18 | | |
+| 25 26 | 27 28 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
+| 41 42 | 43 44 | |
++-------+-------+ |
+| 35 36 | 37 38 | |
+| 45 46 | 47 48 | |
++-------+-------+ ↓
+
+Transform into
+sharding = [[0], [2]]
+
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ | |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ ↓ |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ ↓
+
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x3 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 | |
+| 21 22 | 23 24 | 25 26 | |
+| 31 32 | 33 34 | 35 36 | |
++-------+-------+-------+ |
+| 41 42 | 43 44 | 45 46 | |
+| 51 52 | 53 54 | 55 56 | |
+| 61 62 | 63 64 | 65 66 | |
++-------+-------+-------+ ↓
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 | |
+| 21 22 23 | 41 42 43 | 61 62 63 | |
++----------+----------+----------+ |
+| 14 15 16 | 34 35 36 | 54 55 56 | |
+| 24 25 26 | 44 45 46 | 64 65 66 | |
++----------+----------+----------+ ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+shard on 2x6 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 | |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 | |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 | |
++----+----+----+----+----+----+ |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 | |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 | |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 | |
++----+----+----+----+----+----+ ↓
+
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
++----------+----------+ |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
++==========+==========+ |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
++----------+----------+ |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
+
+M x N mesh.
+K x L tensor t.
+d(m, n) the tensor on device (m, n).
+
+sharding = [[0], [1]]
+Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+
+substitute
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+
+For the inverse map we get
+t[i, j] -> d(
+ i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+ i % (K ceildiv M), j % (L ceildiv N)
+]
+
+Check:
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+ 13 floordiv (16 ceildiv 3),
+ 17 floordiv (23 ceildiv 4)
+)[
+ 13 % (16 ceildiv 3),
+ 17 % (23 ceildiv 4)
+]
+= d(
+ 13 floordiv 6,
+ 17 floordiv 6
+)[
+ 13 % 6,
+ 17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+ 2 * (16 ceildiv 3) + 1,
+ 2 * (23 ceildiv 4) + 5
+]
+= t[
+ 2 * 6 + 1,
+ 2 * 6 + 5
+]
+= t[13, 17]
+
+
+sharding = [[1], [0]]
+Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+
+substitute
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+
+For the inverse map we get
+t[i, j] -> d(
+ j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+ i % (K ceildiv N), j % (L ceildiv M)
+]
+
+Check:
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 20
+t[9, 19] = d(
+ 19 floordiv (20 ceildiv 5),
+ 9 floordiv (27 ceildiv 2)
+)[
+ 9 % (27 ceildiv 2),
+ 19 % (20 ceildiv 5)
+]
+= d(
+ 19 floordiv 4,
+ 9 floordiv 14
+)[
+ 9 % 14,
+ 19 % 4
+]
+= d(4, 0)[9, 3]
+= t[
+ 0 * (27 ceildiv 2) + 9,
+ 4 * (20 ceildiv 5) + 3
+]
+= t[
+ 0 * 14 + 9,
+ 4 * 4 + 3
+]
+= t[9, 19]
+
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
+d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+ (m * (L ceildiv M) + l) floordiv (L ceildiv M),
+ (n * (K ceildiv N) + k) floordiv (K ceildiv N)
+)[
+ (n * (K ceildiv N) + k) % (K ceildiv N),
+ (m * (L ceildiv M) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed (m, n, p, q) what is the range of (k, l, u, v)?
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+Device placement on a 2x3 mesh
+11 12 13 <- devices
+21 22 23
+
+sharding [[0], [1]]
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+transform to
+sharding [[1], [0]]
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 | |
++----+----+ |
+| 12 | 22 | |
++----+----+ |
+| 13 | 23 | |
++----+----+ ↓
+
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 11 | 12 | 13 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 21 | 22 | 23 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | |
++ 11 + 21 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 12 + 22 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 13 + 23 +
+| | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | | | |
++ 11 11 + 12 11 + 12 21 + 13 21 +
+| | | | |
++-----------------+--------+--------+-----------------+
+| 11 12 | 12 12 | 12 22 | 13 22 |
++-----------------+--------+--------+-----------------+
+| 21 12 | 22 12 | 22 22 | 23 22 |
++-----------------+--------+--------+-----------------+
+| | | | |
++ 21 13 + 22 13 + 22 23 + 23 23 +
+| | | | |
++-----------------+--------+--------+-----------------+
+
+If S and T are the source and target shard sizes along some tensor axis,
+then the cut pattern repeats with a period of lcm(S, T) = (S*T)/gcd(S, T).
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 1D mesh of size 3.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 1D mesh of size 3
+
+sharding = [[0], []]
+
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 | |
+| 21 22 23 24 25 26 | |
++-------------------+ |
+| 31 32 33 34 35 36 | |
+| 41 42 43 44 45 46 | |
++-------------------+ |
+| 51 52 53 54 55 56 | |
+| 61 62 63 64 65 66 | |
++-------------------+ ↓
+
+transform to
+sharding = [[], [0]]
+
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+
+--------------------------------------------------------------
+
+Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
+
+unsharded 4x4 tensor
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 | | |
+| 21 | 22 | | |
++----+----+ | |
+| 13 | 14 | | |
+| 23 | 24 | | |
++----+----+ ↓ |
++----+----+ |
+| 31 | 32 | |
+| 41 | 42 | |
++----+----+ |
+| 33 | 34 | |
+| 43 | 44 | |
++----+----+ ↓
+
+transform to
+sharding = [[0, 1], [2]]
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 | | |
++-------+-------+ | |
+| 21 22 | 23 24 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
++-------+-------+ |
+| 41 42 | 43 44 | |
++-------+-------+ ↓
+
+%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
+is not enough.
+
+Can be decomposed into
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
+
+--------------------------------------------------------------
+
+We can decompose each resharding into a sequence of basis reshardings.
+This is not communication-efficient: it does not minimize the amount of data
+communicated between devices.
+An efficient approach would be more complicated to implement.
+In such an approach each device would have to receive at most as much data as
+the size of its target sharded tensor.
+
+Basis:
+
+From replicate to split.
+[[]] -> [[1]]
+Extract slices without communication.
+
+From split to replicate.
+[[0]] -> [[]]
+[[0, 1]] -> [[1]]
+All-gather along mesh axis 0.
+
+Swap mesh axes order when assigned to the same tensor axis.
+[[0, 1]] -> [[1, 0]]
+Swap contents on devices with the same linear index.
+
+Move mesh axis to different tensor dimension.
+[[0], []] -> [[], [0]]
+All-to-all.
+
+Example decomposition of
+[[0], [1]] -> [[1], [0]]
+
+[[0], [1]] -> all-gather along mesh axis 1 ->
+[[0], []] -> all-to-all along mesh axis 0 ->
+[[], [0]] -> extract slice along mesh axis 1 ->
+[[1], [0]]
+
+Example decomposition of
+[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
+
+[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
+[[3, 2], [1], [0]] -> all-to-all along mesh axis 2 ->
+[[3], [1, 2], [0]] -> all-gather along mesh axis 3 ->
+[[], [1, 2], [0]] -> all-to-all along mesh axis 0 ->
+[[0], [1, 2], []]
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
new file mode 100644
index 00000000000000..f71bb9b262a380
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -0,0 +1,35 @@
+//===- Spmdization.h - Mesh Spmdization -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace mesh {
+
+// Return the shape `shape` sharded according to `sharding` on the cluster
+// `mesh`.
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+                           MeshShardingAttr sharding);
+
+// Insert resharding spmdization of the value `sourceShardValue`
+// from sharding `source` to sharding `target`.
+// `sourceShardValue` is the already sharded value according to `source`.
+// Returns the resharded value (presumably sharded according to `target`).
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+                               ShardOp source, ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue);
+
+// Presumably registers the dialects whose ops `reshard` may create, so that
+// passes using it can declare them as dependent dialects — TODO confirm.
+void reshardingRegisterDependentDialects(DialectRegistry &registry);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h
index 17a747393d26eb..22493c11a370d1 100644
--- a/mlir/include/mlir/Support/MathExtras.h
+++ b/mlir/include/mlir/Support/MathExtras.h
@@ -15,9 +15,20 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
+#include <type_traits>
namespace mlir {
+// ceilDiv for unsigned integral types: returns ceil(lhs / rhs).
+// `rhs` must be non-zero. Signed operands are handled by the int64_t overload
+// below.
+// The enable_if must name `bool` explicitly: with the default `void` type the
+// non-type template parameter `= true` is ill-formed and the template would
+// never be viable.
+template <typename T, std::enable_if_t<std::is_integral_v<T> &&
+                                           std::is_unsigned_v<T>,
+                                       bool> = true>
+T ceilDiv(T lhs, T rhs) {
+  assert(rhs != static_cast<T>(0));
+  T q = lhs / rhs;
+  T r = lhs % rhs;
+  // For unsigned operands truncation equals floor, so a non-zero remainder
+  // means we must round up by one.
+  return r == static_cast<T>(0) ? q : q + static_cast<T>(1);
+}
+
/// Returns the result of MLIR's ceildiv operation on constants. The RHS is
/// expected to be non-zero.
inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d27675564c6464..de4f58d54e8ca5 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -9,8 +9,10 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
@@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
return vec;
}
-using MeshAxis = int16_t;
-
namespace {
struct DimensionSize {
@@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//
+// Resolves `meshSymbol` to a `ClusterOp` visible from `op`. Emits an error on
+// `op` and returns failure when no such mesh exists.
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    SymbolTableCollection &symbolTable) {
+  ClusterOp found =
+      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+  if (found)
+    return found;
+  return op->emitError() << "Undefined required mesh symbol \""
+                         << meshSymbol.getValue() << "\".";
+}
+
+// Returns true when no two adjacent elements of [begin, end) compare equal.
+// Applied to a sorted range this means all elements are distinct.
+template <typename It>
+bool isUnique(It begin, It end) {
+  if (begin == end)
+    return true;
+  for (It prev = begin, cur = std::next(begin); cur != end; prev = cur, ++cur) {
+    if (*prev == *cur)
+      return false;
+  }
+  return true;
+}
+
+// Checks that `axes` contains no duplicates and that every entry is a valid
+// 0-based axis index of `mesh`. Emits an error at `loc` for the first
+// violation found (duplicates are reported before out-of-bounds axes).
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+                                    ClusterOp mesh) {
+  SmallVector<MeshAxis> uniquenessCheck(axes.begin(), axes.end());
+  llvm::sort(uniquenessCheck);
+  if (!isUnique(uniquenessCheck.begin(), uniquenessCheck.end()))
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+
+  const MeshAxis rank = mesh.getRank();
+  auto badAxis = llvm::find_if(
+      axes, [rank](MeshAxis axis) { return axis < 0 || axis >= rank; });
+  if (badAxis == axes.end())
+    return success();
+
+  return emitError(loc) << "0-based mesh axis index " << *badAxis
+                        << " is out of bounds. The referenced mesh \""
+                        << mesh.getSymName() << "\" is of rank " << rank
+                        << ".";
+}
+
bool mesh::isReductionLoop(IteratorType iType) {
return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
}
@@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
}
//===----------------------------------------------------------------------===//
-// mesh.shard op
+// mesh.cluster_shape op
+//===----------------------------------------------------------------------===//
+
+// Verifies that the referenced mesh exists, that `axes` is valid for it, and
+// that there is one result per requested axis (or per mesh axis when `axes`
+// is empty).
+LogicalResult
+ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> mesh =
+      ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh) || failed(verifyMeshAxes(getLoc(), getAxes(), *mesh)))
+    return failure();
+
+  auto axes = getAxes();
+  size_t expectedResultsCount = axes.empty() ? mesh->getRank() : axes.size();
+  size_t actualResultsCount = getResult().size();
+  if (actualResultsCount == expectedResultsCount)
+    return success();
+
+  return emitError() << "Unexpected number of results " << actualResultsCount
+                     << ". Expected " << expectedResultsCount << ".";
+}
+
+// Builds a `cluster_shape` op that queries every axis of `mesh`: one `index`
+// result per mesh axis and an empty `axes` attribute.
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           ClusterOp mesh) {
+  // A default-constructed MeshAxesAttr is a null attribute. Pass an explicitly
+  // empty one so the builder never reads array contents through null storage.
+  build(odsBuilder, odsState,
+        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+        mesh.getSymName(),
+        MeshAxesAttr::get(odsBuilder.getContext(), ArrayRef<MeshAxis>()));
+}
+
+// Builds a `cluster_shape` op that queries the given `axes` of the mesh named
+// `mesh`, with one `index` result per axis. Note: an empty `axes` produces no
+// results, which the verifier rejects for a mesh of non-zero rank.
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+  MLIRContext *ctx = odsBuilder.getContext();
+  SmallVector<Type> resultTypes(axes.size(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resultTypes, mesh, MeshAxesAttr::get(ctx, axes));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
//===----------------------------------------------------------------------===//
LogicalResult
@@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+// Equality against a generic attribute: true only if `rhs` is a non-null
+// MeshShardingAttr that compares equal to this sharding.
+bool MeshShardingAttr::operator==(Attribute rhs) const {
+  // `dyn_cast_if_present` yields a null result for a null `rhs` instead of
+  // asserting, and uses the free-function cast MLIR recommends over the
+  // member `Attribute::dyn_cast`.
+  auto rhsAsMeshShardingAttr =
+      llvm::dyn_cast_if_present<MeshShardingAttr>(rhs);
+  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+}
+
+// Structural equality. Requires the same cluster and partial axes, the same
+// partial type when partial axes are present, and split axes that agree up to
+// trailing empty entries (e.g. [[0]] equals [[0], []]).
+bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
+  bool sameClusterAndPartialAxes = getCluster() == rhs.getCluster() &&
+                                   getPartialAxes() == rhs.getPartialAxes();
+  if (!sameClusterAndPartialAxes)
+    return false;
+
+  bool hasPartialAxes = !getPartialAxes().empty();
+  if (hasPartialAxes && getPartialType() != rhs.getPartialType())
+    return false;
+
+  auto lhsSplit = getSplitAxes();
+  auto rhsSplit = rhs.getSplitAxes();
+  size_t commonSize = std::min(lhsSplit.size(), rhsSplit.size());
+  if (lhsSplit.take_front(commonSize) != rhsSplit.take_front(commonSize))
+    return false;
+
+  // Any extra trailing entries on either side must be empty.
+  auto isEmptyAxes = [](DenseI32ArrayAttr axes) { return axes.empty(); };
+  return llvm::all_of(lhsSplit.drop_front(commonSize), isEmptyAxes) &&
+         llvm::all_of(rhsSplit.drop_front(commonSize), isEmptyAxes);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.process_index op
+//===----------------------------------------------------------------------===//
+
+// Verifies the mesh symbol and the requested axes, then checks that the number
+// of results equals the number of queried axes (all mesh axes if `axes` is
+// empty).
+LogicalResult
+ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> maybeMesh =
+      ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(maybeMesh))
+    return failure();
+  ClusterOp cluster = *maybeMesh;
+  if (failed(verifyMeshAxes(getLoc(), getAxes(), cluster)))
+    return failure();
+
+  size_t expectedResultsCount =
+      getAxes().empty() ? cluster.getRank() : getAxes().size();
+  if (getResult().size() == expectedResultsCount)
+    return success();
+
+  return emitError() << "Unexpected number of results " << getResult().size()
+                     << ". Expected " << expectedResultsCount << ".";
+}
+
+// Builds a `process_index` op that queries every axis of `mesh`: one `index`
+// result per mesh axis and an empty `axes` attribute.
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           ClusterOp mesh) {
+  // A default-constructed MeshAxesAttr is a null attribute. Pass an explicitly
+  // empty one so the builder never reads array contents through null storage.
+  build(odsBuilder, odsState,
+        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+        mesh.getSymName(),
+        MeshAxesAttr::get(odsBuilder.getContext(), ArrayRef<MeshAxis>()));
+}
+
+// Builds a `process_index` op for the given `axes` of the mesh named `mesh`,
+// with one `index` result per axis. Note: an empty `axes` produces no results,
+// which the verifier rejects for a mesh of non-zero rank.
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+  MeshAxesAttr axesAttr = MeshAxesAttr::get(odsBuilder.getContext(), axes);
+  SmallVector<Type> resultTypes(axes.size(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resultTypes, mesh, axesAttr);
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
return success();
}
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTable) {
- mesh::ClusterOp mesh =
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
- if (!mesh) {
- return op->emitError() << "Undefined required mesh symbol \""
- << meshSymbol.getValue() << "\".";
- }
-
- return mesh;
-}
-
-template <typename It>
-bool isUnique(It begin, It end) {
- if (begin == end) {
- return true;
- }
- It next = std::next(begin);
- if (next == end) {
- return true;
- }
- for (; next != end; ++next, ++begin) {
- if (*begin == *next) {
- return false;
- }
- }
- return true;
-}
-
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- ClusterOp mesh) {
- SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
- llvm::sort(sorted);
- if (!isUnique(sorted.begin(), sorted.end())) {
- return emitError(loc) << "Mesh axes contains duplicate elements.";
- }
-
- MeshAxis rank = mesh.getRank();
- for (auto axis : axes) {
- if (axis >= rank || axis < 0) {
- return emitError(loc)
- << "0-based mesh axis index " << axis
- << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
- << "\" is of rank " << rank << ".";
- }
- }
-
- return success();
-}
-
template <typename Op>
static FailureOr<ClusterOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 044b8672c8c60c..7a70c047ec9dce 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMeshTransforms
Simplifications.cpp
ShardingPropagation.cpp
+ Spmdization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,11 +12,13 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRSupport
+ MLIRTensorDialect
MLIRTosaShardingInterfaceImpl
)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
new file mode 100644
index 00000000000000..de8b3a98df998a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -0,0 +1,639 @@
+//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+#include <type_traits>
+
+namespace mlir {
+namespace mesh {
+
+// Returns the per-shard extent of a tensor dimension of extent `dim` that is
+// split into `shardCount` shards. The result is dynamic when either input is
+// dynamic. Static extents must be exactly divisible by `shardCount`.
+int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  const bool isAnyDynamic =
+      ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount);
+  if (isAnyDynamic)
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+// Inverse of shardDimension: the full extent of a dimension whose per-shard
+// extent is `dim` when split into `shardCount` shards. Dynamic when either
+// input is dynamic.
+int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  const bool bothStatic =
+      !ShapedType::isDynamic(dim) && !ShapedType::isDynamic(shardCount);
+  return bothStatic ? dim * shardCount : ShapedType::kDynamic;
+}
+
+// Number of shards produced by splitting over every mesh axis in `splitAxes`,
+// i.e. the product of those axes' sizes in `meshShape`. The result is dynamic
+// as soon as one of the referenced mesh dimensions is dynamic.
+template <typename MeshShape, typename SplitAxes>
+int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
+  int64_t product = 1;
+  for (const auto &meshAxis : splitAxes) {
+    const int64_t axisSize = meshShape[meshAxis];
+    if (ShapedType::isDynamic(axisSize))
+      return ShapedType::kDynamic;
+    product *= axisSize;
+  }
+  return product;
+}
+
+// Compute the shape for the tensor on each device in the mesh.
+// `splitAxes[i]` lists the mesh axes that tensor axis `i` is split over.
+// Tensor axes without an entry in `splitAxes` keep their full extent. A
+// dynamic tensor or mesh dimension yields a dynamic shard dimension. Static
+// extents must be exactly divisible by the corresponding shard count.
+// `outShape` must already hold at least as many elements as `inShape`.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1
+// would result in a shape for each shard of ?x2x?.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+          typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+                       const SplitAxes &splitAxes, OutShape &outShape) {
+  // Start from the unsharded shape...
+  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+            llvm::adl_begin(outShape));
+  // ...and shrink every split tensor axis by its shard count.
+  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+    outShape[tensorAxis] =
+        shardDimension(inShape[tensorAxis],
+                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+  }
+}
+
+// Returns `shape` with every dimension replaced by its per-device extent under
+// `sharding` on `mesh`. The element type is preserved.
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+                           MeshShardingAttr sharding) {
+  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
+  ArrayRef<int64_t> unshardedShape = shape.getShape();
+  SmallVector<Dim> shardedShape(unshardedShape.size());
+  shardShape(unshardedShape, mesh.canonicalDimSizes(), sharding.getSplitAxes(),
+             shardedShape);
+  return shape.clone(shardedShape);
+}
+
+// True when every partial axis of the target is also a partial axis of the
+// source, i.e. the target keeps a subset of the source partial axes.
+template <typename SourceAxes, typename TargetAxes>
+static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
+                                     const TargetAxes &targetAxes) {
+  return llvm::none_of(targetAxes, [&sourceAxes](const auto &axis) {
+    return !sourceAxes.contains(axis);
+  });
+}
+
+// Reduce the source shard over the mesh axes that are partial in the source
+// sharding but not in the target sharding.
+// Returns the (possibly unchanged) value together with its sharding. Partial
+// axes that the target keeps stay partial in the returned sharding.
+// Example:
+// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
+// targetSharding = <@mesh_1d, [[]]>
+// An all-reduce over mesh axis 0 is applied to the source value, which is
+// then returned with the sharding <@mesh_1d, [[0]]>.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+handlePartialAxesDuringResharding(OpBuilder &builder,
+                                  MeshShardingAttr sourceSharding,
+                                  MeshShardingAttr targetSharding,
+                                  TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding.getPartialAxes().empty() &&
+      targetSharding.getPartialAxes().empty()) {
+    return {sourceShard, sourceSharding};
+  }
+  // A target that is still partial requires a partial source with the same
+  // reduction kind.
+  assert(targetSharding.getPartialAxes().empty() ||
+         (!sourceSharding.getPartialAxes().empty() &&
+          sourceSharding.getPartialType() == targetSharding.getPartialType()));
+  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
+  using AxisSet = llvm::SmallDenseSet<Axis>;
+  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
+                                       sourceSharding.getPartialAxes().end());
+  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
+                                       targetSharding.getPartialAxes().end());
+  // The target may only keep a subset of the source partial axes.
+  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
+                                  targetShardingPartialAxesSet));
+
+  // Mesh axes to reduce now: partial in the source, not in the target.
+  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
+  llvm::copy_if(sourceShardingPartialAxesSet,
+                std::back_inserter(allReduceMeshAxes),
+                [&targetShardingPartialAxesSet](Axis a) {
+                  return !targetShardingPartialAxesSet.contains(a);
+                });
+  if (allReduceMeshAxes.empty()) {
+    return {sourceShard, sourceSharding};
+  }
+
+  builder.setInsertionPointAfterValue(sourceShard);
+  TypedValue<ShapedType> resultValue =
+      builder
+          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
+                               sourceSharding.getCluster().getLeafReference(),
+                               allReduceMeshAxes, sourceShard,
+                               sourceSharding.getPartialType())
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
+
+  // Mesh axes that stay partial after the all-reduce: the ones the target
+  // keeps. They go into their own list so that they end up in the result
+  // sharding (and not in the already-consumed all-reduce axes).
+  llvm::SmallVector<int32_t> remainingPartialAxes;
+  llvm::copy_if(sourceShardingPartialAxesSet,
+                std::back_inserter(remainingPartialAxes),
+                [&targetShardingPartialAxesSet](Axis a) {
+                  return targetShardingPartialAxesSet.contains(a);
+                });
+  MeshShardingAttr resultSharding =
+      MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+                            sourceSharding.getSplitAxes(), remainingPartialAxes,
+                            sourceSharding.getPartialType());
+  return {resultValue, resultSharding};
+}
+
+// Builds the sharding obtained from `sourceSharding` by appending
+// `splitMeshAxis` to the split axes of tensor axis `splitTensorAxis`.
+// Missing entries up to `splitTensorAxis` are filled with empty (unsplit)
+// split-axes lists.
+static MeshShardingAttr
+targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+                              int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  SmallVector<DenseI32ArrayAttr> splitAxesPerTensorDim =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  if (static_cast<int64_t>(splitAxesPerTensorDim.size()) <= splitTensorAxis)
+    splitAxesPerTensorDim.resize(splitTensorAxis + 1,
+                                 DenseI32ArrayAttr::get(ctx, {}));
+
+  auto meshAxes =
+      llvm::to_vector(splitAxesPerTensorDim[splitTensorAxis].asArrayRef());
+  meshAxes.push_back(splitMeshAxis);
+  splitAxesPerTensorDim[splitTensorAxis] = DenseI32ArrayAttr::get(ctx, meshAxes);
+
+  return MeshShardingAttr::get(ctx, sourceSharding.getCluster(),
+                               splitAxesPerTensorDim,
+                               sourceSharding.getPartialAxes(),
+                               sourceSharding.getPartialType());
+}
+
+// Shape of `sourceShape` after splitting tensor axis `splitTensorAxis` into
+// `splitCount` shards. All other dimensions are unchanged.
+static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
+                                             int64_t splitTensorAxis,
+                                             int64_t splitCount) {
+  SmallVector<int64_t> shape(sourceShape.getShape().begin(),
+                             sourceShape.getShape().end());
+  int64_t &splitDim = shape[splitTensorAxis];
+  splitDim = shardDimension(splitDim, splitCount);
+  return sourceShape.cloneWith(shape, sourceShape.getElementType());
+}
+
+// Split one tensor axis along one more mesh axis, which is appended to that
+// tensor axis' split axes, e.g. [[0, 1]] -> [[0, 1, 2]].
+// Returns the spmdized target value with its sharding.
+//
+// The implementation extracts the tensor slice corresponding to the current
+// device's process index along `splitMeshAxis`. The new ops are inserted right
+// after the definition of `sourceShard`. A runtime `cf.assert` requires the
+// extent of `splitTensorAxis` to be exactly divisible by the size of
+// `splitMeshAxis`.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+ // Index of the current process along the mesh axis being split over.
+ Value processIndexAlongAxis =
+ builder
+ .create<ProcessIndexOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
+ .getResult()[0];
+
+ MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+ ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
+ ShapedType targetShape =
+ targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
+ mesh.canonicalDimSizes()[splitMeshAxis]);
+
+ // The mesh axis size is also materialized as an SSA value, presumably so the
+ // computation below works for dynamic mesh dimensions too.
+ Value meshAxisSize =
+ builder
+ .create<ClusterShapeOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
+ .getResult()[0];
+
+ // Assert at runtime that the split extent divides evenly across the mesh
+ // axis.
+ Value sourceAxisSize =
+ builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
+ Value sourceAxisSizeModMeshAxisSize =
+ builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
+ Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+ arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
+ builder.create<cf::AssertOp>(
+ isTargetShapeExactlyDivisible,
+ "Sharding a tensor with axis size that is not exactly divisible by the "
+ "mesh axis size is not supported.");
+ // Each shard along the split axis has extent sourceAxisSize / meshAxisSize
+ // and starts at targetAxisSize * processIndexAlongAxis.
+ Value targetAxisSize =
+ builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
+ Value axisOffset =
+ builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
+ // Offsets: 0 in every dimension except the split axis, whose offset is the
+ // single dynamic offset.
+ SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
+ staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
+ DenseI64ArrayAttr staticOffsetsAttr =
+ DenseI64ArrayAttr::get(ctx, staticOffsets);
+ SmallVector<Value> dynamicOffsets(1, axisOffset);
+
+ // Sizes: the static target shape. Every dynamic extent needs an SSA size:
+ // the computed shard extent for the split axis, the source extent otherwise.
+ DenseI64ArrayAttr staticSizesAttr =
+ DenseI64ArrayAttr::get(ctx, targetShape.getShape());
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < targetShape.getRank(); ++i) {
+ if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
+ if (i == splitTensorAxis) {
+ dynamicSizes.push_back(targetAxisSize);
+ } else {
+ Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
+ dynamicSizes.push_back(dimSize);
+ }
+ }
+ }
+
+ // Unit strides in every dimension.
+ DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
+ ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
+ TypedValue<RankedTensorType> targetShard =
+ builder
+ .create<tensor::ExtractSliceOp>(
+ targetShape, sourceShard, dynamicOffsets, dynamicSizes,
+ SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
+ staticStridesAttr)
+ .getResult();
+ return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+}
+
+// Recognizes a resharding that appends exactly one mesh axis to the split
+// axes of a single tensor axis, e.g. [[0, 1]] -> [[0, 1, 2]].
+// Tensor axes missing from the source split axes count as unsplit.
+// Insertions other than at the end, e.g. [[0, 1]] -> [[0, 2, 1]], are not
+// recognized.
+// On success returns (tensor axis, appended mesh axis) for the first match.
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+                                MeshShardingAttr targetSharding) {
+  auto sourceSplitAxes = sourceSharding.getSplitAxes();
+  auto targetSplitAxes = targetSharding.getSplitAxes();
+  for (size_t tensorAxis = 0, end = targetSplitAxes.size(); tensorAxis < end;
+       ++tensorAxis) {
+    ArrayRef<int32_t> targetMeshAxes = targetSplitAxes[tensorAxis].asArrayRef();
+    ArrayRef<int32_t> sourceMeshAxes;
+    if (tensorAxis < sourceSplitAxes.size())
+      sourceMeshAxes = sourceSplitAxes[tensorAxis].asArrayRef();
+
+    // The target must be the source plus exactly one trailing mesh axis.
+    if (targetMeshAxes.size() != sourceMeshAxes.size() + 1)
+      continue;
+    if (!llvm::equal(sourceMeshAxes, targetMeshAxes.drop_back()))
+      continue;
+
+    return std::make_tuple(tensorAxis, targetMeshAxes.back());
+  }
+  return std::nullopt;
+}
+
+// Applies splitLastAxisInResharding when the source -> target resharding
+// appends one mesh axis to a tensor axis; returns std::nullopt otherwise.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                             MeshShardingAttr sourceSharding,
+                             MeshShardingAttr targetSharding,
+                             TypedValue<ShapedType> sourceShard) {
+  auto detected =
+      detectSplitLastAxisInResharding(sourceSharding, targetSharding);
+  if (!detected)
+    return std::nullopt;
+
+  auto [tensorAxis, meshAxis] = *detected;
+  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
+                                   tensorAxis, meshAxis);
+}
+
+// Recognizes a resharding that removes the last mesh axis from the split axes
+// of a single tensor axis, e.g. [[0, 1, 2]] -> [[0, 1]].
+// Tensor axes missing from the target split axes count as unsplit.
+// On success returns (tensor axis, removed mesh axis) for the first match.
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+                                  MeshShardingAttr targetSharding) {
+  auto sourceSplitAxes = sourceSharding.getSplitAxes();
+  auto targetSplitAxes = targetSharding.getSplitAxes();
+  for (size_t tensorAxis = 0, end = sourceSplitAxes.size(); tensorAxis < end;
+       ++tensorAxis) {
+    ArrayRef<int32_t> sourceMeshAxes = sourceSplitAxes[tensorAxis].asArrayRef();
+    ArrayRef<int32_t> targetMeshAxes;
+    if (tensorAxis < targetSplitAxes.size())
+      targetMeshAxes = targetSplitAxes[tensorAxis].asArrayRef();
+
+    // The source must be the target plus exactly one trailing mesh axis.
+    if (sourceMeshAxes.size() != targetMeshAxes.size() + 1)
+      continue;
+    if (!llvm::equal(sourceMeshAxes.drop_back(), targetMeshAxes))
+      continue;
+
+    return std::make_tuple(tensorAxis, sourceMeshAxes.back());
+  }
+  return std::nullopt;
+}
+
+// Builds the sharding obtained from `sourceSharding` by dropping the last mesh
+// axis from the split axes of tensor axis `splitTensorAxis`.
+static MeshShardingAttr
+targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+                                MeshShardingAttr sourceSharding,
+                                int64_t splitTensorAxis) {
+  SmallVector<DenseI32ArrayAttr> splitAxesPerTensorDim =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  assert(static_cast<int64_t>(splitAxesPerTensorDim.size()) >
+         splitTensorAxis);
+
+  // The view points into uniqued attribute storage, so it stays valid after
+  // the vector element is overwritten.
+  ArrayRef<int32_t> meshAxes =
+      splitAxesPerTensorDim[splitTensorAxis].asArrayRef();
+  splitAxesPerTensorDim[splitTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, meshAxes.drop_back());
+
+  return MeshShardingAttr::get(ctx, sourceSharding.getCluster(),
+                               splitAxesPerTensorDim,
+                               sourceSharding.getPartialAxes(),
+                               sourceSharding.getPartialType());
+}
+
+// Shape of the all-gather result when `splitCount` shards of `sourceShape`
+// are concatenated along `splitTensorAxis`.
+static ShapedType allGatherResultShapeInUnsplitLastAxis(
+    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
+  SmallVector<int64_t> gatheredShape(sourceShape.getShape().begin(),
+                                     sourceShape.getShape().end());
+  int64_t &gatheredDim = gatheredShape[splitTensorAxis];
+  gatheredDim = unshardDimension(gatheredDim, splitCount);
+  return sourceShape.cloneWith(gatheredShape, sourceShape.getElementType());
+}
+
+// Undo the split of one tensor axis along its last mesh axis,
+// e.g. [[0, 1, 2]] -> [[0, 1]].
+// Returns the spmdized target value with its sharding.
+//
+// The shards are gathered with a `mesh.all_gather` over `splitMeshAxis` along
+// `splitTensorAxis`. The gathered value is then `tensor.cast` to the shard type
+// implied by `sourceUnshardedShape` and the target sharding. The new ops are
+// inserted right after the definition of `sourceShard`.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+                            MeshShardingAttr sourceSharding,
+                            ShapedType sourceUnshardedShape,
+                            TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  builder.setInsertionPointAfterValue(sourceShard);
+
+  // targetShardingInUnsplitLastAxis is indexed by *tensor* axis; passing the
+  // mesh axis here would unsplit the wrong tensor axis.
+  MeshShardingAttr targetSharding =
+      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
+  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
+      sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
+      splitTensorAxis);
+  Value allGatherResult = builder.create<AllGatherOp>(
+      RankedTensorType::get(allGatherResultShape.getShape(),
+                            allGatherResultShape.getElementType()),
+      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+      APInt(64, splitTensorAxis));
+  ShapedType targetShape =
+      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+  TypedValue<ShapedType> targetShard =
+      builder.create<tensor::CastOp>(targetShape, allGatherResult)
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
+  return {targetShard, targetSharding};
+}
+
+// Applies unsplitLastAxisInResharding when the source -> target resharding
+// drops the last mesh axis of a tensor axis; returns std::nullopt otherwise.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                               MeshShardingAttr sourceSharding,
+                               MeshShardingAttr targetSharding,
+                               ShapedType sourceUnshardedShape,
+                               TypedValue<ShapedType> sourceShard) {
+  auto detected =
+      detectUnsplitLastAxisInResharding(sourceSharding, targetSharding);
+  if (!detected)
+    return std::nullopt;
+
+  auto [tensorAxis, meshAxis] = *detected;
+  return unsplitLastAxisInResharding(builder, sourceSharding,
+                                     sourceUnshardedShape, sourceShard, mesh,
+                                     tensorAxis, meshAxis);
+}
+
+// Recognizes a resharding where the last split mesh axis of one tensor axis
+// becomes the last split mesh axis of a different tensor axis, e.g.
+// [[0]] -> [[], [0]]. Only moving the last axis counts.
+// NOTE(review): the match requires the remaining (non-last) mesh axes of the
+// source tensor axis and the target tensor axis to be equal. On a 1D mesh both
+// remainders are empty. A multi-axis case such as
+// [[0, 1], [2]] -> [[0], [1, 2]] is not recognized.
+// On success returns (source tensor axis, target tensor axis, mesh axis) for
+// the first match.
+static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
+detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
+                                    MeshShardingAttr targetSharding) {
+  auto sourceSplitAxes = sourceSharding.getSplitAxes();
+  auto targetSplitAxes = targetSharding.getSplitAxes();
+  for (size_t sourceTensorAxis = 0; sourceTensorAxis < sourceSplitAxes.size();
+       ++sourceTensorAxis) {
+    ArrayRef<int32_t> sourceMeshAxes =
+        sourceSplitAxes[sourceTensorAxis].asArrayRef();
+    if (sourceMeshAxes.empty())
+      continue;
+
+    for (size_t targetTensorAxis = 0;
+         targetTensorAxis < targetSplitAxes.size(); ++targetTensorAxis) {
+      if (targetTensorAxis == sourceTensorAxis)
+        continue;
+
+      ArrayRef<int32_t> targetMeshAxes =
+          targetSplitAxes[targetTensorAxis].asArrayRef();
+      if (targetMeshAxes.empty() ||
+          targetMeshAxes.back() != sourceMeshAxes.back())
+        continue;
+      if (!llvm::equal(sourceMeshAxes.drop_back(), targetMeshAxes.drop_back()))
+        continue;
+
+      return std::make_tuple(sourceTensorAxis, targetTensorAxis,
+                             sourceMeshAxes.back());
+    }
+  }
+  return std::nullopt;
+}
+
+// Builds the sharding obtained from `sourceSharding` by taking the last mesh
+// axis off tensor axis `sourceTensorAxis` and appending it to tensor axis
+// `targetTensorAxis`. Missing entries up to `targetTensorAxis` are filled with
+// empty (unsplit) split-axes lists.
+static MeshShardingAttr
+targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+                             int64_t sourceTensorAxis,
+                             int64_t targetTensorAxis) {
+  SmallVector<DenseI32ArrayAttr> splitAxesPerTensorDim =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  if (static_cast<int64_t>(splitAxesPerTensorDim.size()) <= targetTensorAxis)
+    splitAxesPerTensorDim.resize(targetTensorAxis + 1,
+                                 DenseI32ArrayAttr::get(ctx, {}));
+
+  // Remove the last mesh axis from the source tensor axis. The view refers to
+  // uniqued attribute storage and stays valid after the overwrite.
+  ArrayRef<int32_t> sourceMeshAxes =
+      splitAxesPerTensorDim[sourceTensorAxis].asArrayRef();
+  assert(!sourceMeshAxes.empty());
+  auto movedMeshAxis = sourceMeshAxes.back();
+  splitAxesPerTensorDim[sourceTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, sourceMeshAxes.drop_back());
+
+  // ...and append it to the target tensor axis.
+  auto targetMeshAxes =
+      llvm::to_vector(splitAxesPerTensorDim[targetTensorAxis].asArrayRef());
+  targetMeshAxes.push_back(movedMeshAxis);
+  splitAxesPerTensorDim[targetTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, targetMeshAxes);
+
+  return MeshShardingAttr::get(ctx, sourceSharding.getCluster(),
+                               splitAxesPerTensorDim,
+                               sourceSharding.getPartialAxes(),
+                               sourceSharding.getPartialType());
+}
+
+// Shape of the all-to-all result that un-shards `sourceTensorAxis` and shards
+// `targetTensorAxis`, both by a factor of `splitCount`.
+static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
+                                                    int64_t splitCount,
+                                                    int64_t sourceTensorAxis,
+                                                    int64_t targetTensorAxis) {
+  SmallVector<int64_t> resultShape(sourceShape.getShape().begin(),
+                                   sourceShape.getShape().end());
+  int64_t &gatheredDim = resultShape[sourceTensorAxis];
+  gatheredDim = unshardDimension(gatheredDim, splitCount);
+  int64_t &scatteredDim = resultShape[targetTensorAxis];
+  scatteredDim = shardDimension(scatteredDim, splitCount);
+  return sourceShape.cloneWith(resultShape, sourceShape.getElementType());
+}
+
+// Move the last split mesh axis of `sourceTensorAxis` to the end of the split
+// axes of `targetTensorAxis`, e.g. [[0]] -> [[], [0]].
+// Returns the spmdized target value with its sharding.
+//
+// Emitted as one `mesh.all_to_all` over `meshAxis`, whose result un-shards
+// `sourceTensorAxis` and shards `targetTensorAxis`. That result is then
+// `tensor.cast` to the shard type implied by `sourceUnshardedShape` and the
+// new sharding. The new ops are inserted right after the definition of
+// `sourceShard`.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis, MeshAxis meshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+ ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
+ ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
+ sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
+ sourceTensorAxis, targetTensorAxis);
+ // NOTE(review): the two APInt operands presumably are the split axis
+ // (targetTensorAxis) and the concat axis (sourceTensorAxis), in that order.
+ // Verify against the AllToAllOp builder signature.
+ Value allToAllResult = builder.create<AllToAllOp>(
+ RankedTensorType::get(allToAllResultShape.getShape(),
+ allToAllResultShape.getElementType()),
+ mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+ APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
+ // Cast the all-to-all result to the shard type expected for the target
+ // sharding.
+ ShapedType targetShape =
+ shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ TypedValue<ShapedType> targetShard =
+ builder.create<tensor::CastOp>(targetShape, allToAllResult)
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+ return {targetShard, targetSharding};
+}
+
+// Applies moveLastSplitAxisInResharding when the source -> target resharding
+// moves the last split mesh axis between tensor axes; returns std::nullopt
+// otherwise.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                                 MeshShardingAttr sourceSharding,
+                                 MeshShardingAttr targetSharding,
+                                 ShapedType sourceUnshardedShape,
+                                 TypedValue<ShapedType> sourceShard) {
+  auto detected =
+      detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding);
+  if (!detected)
+    return std::nullopt;
+
+  auto [sourceTensorAxis, targetTensorAxis, meshAxis] = *detected;
+  return moveLastSplitAxisInResharding(builder, mesh, sourceSharding,
+                                       sourceUnshardedShape, sourceShard,
+                                       sourceTensorAxis, targetTensorAxis,
+                                       meshAxis);
+}
+
+// Handles only resharding on a 1D mesh.
+// Currently the sharded tensor axes must be exactly divisible by the single
+// mesh axis size.
+//
+// Partial axes that the target does not keep are all-reduced first. If the
+// sharding still differs from `targetSharding`, exactly one of these patterns
+// is applied: moving the last split mesh axis to another tensor axis
+// (all-to-all), splitting a tensor axis along one more mesh axis (slice
+// extraction), or unsplitting a tensor axis (all-gather).
+static TypedValue<ShapedType>
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                MeshShardingAttr sourceSharding,
+                MeshShardingAttr targetSharding,
+                TypedValue<ShapedType> sourceUnshardedValue,
+                TypedValue<ShapedType> sourceShard) {
+  assert(sourceShard.getType() ==
+         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+  ShapedType targetShardType =
+      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+  assert(sourceShard.getType().getRank() == targetShardType.getRank());
+  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+
+  auto [reducedSourceShard, reducedSourceSharding] =
+      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
+                                        sourceShard);
+
+  if (reducedSourceSharding == targetSharding) {
+    return reducedSourceShard;
+  }
+
+  TypedValue<ShapedType> targetShard;
+  MeshShardingAttr actualTargetSharding;
+  if (auto tryRes = tryMoveLastSplitAxisInResharding(
+          builder, mesh, reducedSourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else if (auto tryRes = trySplitLastAxisInResharding(
+                 builder, mesh, reducedSourceSharding, targetSharding,
+                 reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+                 builder, mesh, reducedSourceSharding, targetSharding,
+                 sourceUnshardedValue.getType(), reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else {
+    // The patterns above are assumed to cover every 1D-mesh resharding that
+    // reaches this point. llvm_unreachable, unlike assert(false), also stops
+    // release builds from silently returning a null value.
+    llvm_unreachable("Did not find any pattern to apply.");
+  }
+
+  assert(actualTargetSharding == targetSharding);
+  assert(targetShard.getType() == targetShardType);
+  return targetShard;
+}
+
+// Reshards `sourceShard` (the local shard of `sourceUnshardedValue` under
+// `sourceSharding`) into the local shard under `targetSharding`.
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                               MeshShardingAttr sourceSharding,
+                               MeshShardingAttr targetSharding,
+                               TypedValue<ShapedType> sourceUnshardedValue,
+                               TypedValue<ShapedType> sourceShard) {
+  // Only 1D meshes are handled for now. The general case is considerably
+  // harder if the data transferred between devices is to be kept minimal.
+  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+                         sourceUnshardedValue, sourceShard);
+}
+
+// Reshards `sourceShardValue` from the sharding of `source` to the sharding of
+// `target`. `source` must not be annotate_for_users, `target` must be, and
+// `target` must consume the result of `source`. New ops get the location of
+// `target`.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+                               ShardOp source, ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue) {
+  assert(!source.getAnnotateForUsers());
+  assert(target.getAnnotateForUsers());
+  assert(source.getResult() == target.getOperand());
+
+  auto unshardedSource = source.getSrc().cast<TypedValue<ShapedType>>();
+  ImplicitLocOpBuilder locBuilder(target->getLoc(), builder);
+  return reshard(locBuilder, mesh, source.getShard(), target.getShard(),
+                 unshardedSource, sourceShardValue);
+}
+
+// Registers the dialects whose ops the resharding code can create: arith,
+// mesh, tensor, and cf (the latter for the runtime divisibility assertion).
+void reshardingRegisterDependentDialects(DialectRegistry &registry) {
+  registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
+                  cf::ControlFlowDialect>();
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 03994f8f011e1f..3ee578a37235e2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,102 @@ func.func @mesh_axis_negtive_in_partial(
// -----
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+// Mesh axis 2 does not exist on the rank-2 mesh @mesh0.
+func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+// Mesh axis 0 is listed twice.
+func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+// One mesh axis is requested, so exactly one result is expected.
+func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+// Without an axes list, one result per mesh dimension (3 here) is expected.
+func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.cluster_shape @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+// The referenced mesh symbol is not defined.
+func.func @cluster_shape_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+// Mesh axis 2 does not exist on the rank-2 mesh @mesh0.
+func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+// Mesh axis 0 is listed twice.
+func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+// One mesh axis is requested, so exactly one result is expected.
+func.func @process_index_wrong_number_of_results() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+// Without an axes list, one result per mesh dimension (3 here) is expected.
+func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.process_index on @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+// The referenced mesh symbol is not defined.
+func.func @process_index_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
func.func @all_reduce_invalid_mesh_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309d18f156..a7c3b3dbab9c13 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+// mesh.cluster_shape with an explicit axes list: one index result per listed
+// mesh axis, printed back with the same axes list.
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// mesh.cluster_shape without an axes list: one result per mesh dimension
+// (three here; @mesh0 is presumably rank 3, defined outside this hunk).
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// An explicit empty axes list behaves like omitting it and is printed without
+// the axes clause.
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// mesh.process_index with an explicit axes list: one index result per listed
+// mesh axis, printed back with the same axes list.
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// mesh.process_index without an axes list: one result per mesh dimension
+// (three here; @mesh0 is presumably rank 3, defined outside this hunk).
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// An explicit empty axes list behaves like omitting it and is printed without
+// the axes clause.
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+
// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
new file mode 100644
index 00000000000000..0ba0d76c09a74d
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
+mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %input: tensor<2xf32>
+) -> tensor<2xf32> {
+  // Neither sharding splits any tensor axis and both use the same mesh, so
+  // resharding is a no-op and the argument itself is returned.
+  %src_sharded = mesh.shard %input to <@mesh_1d, [[]]> : tensor<2xf32>
+  %tgt_sharded = mesh.shard %src_sharded to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %tgt_sharded : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+  %input: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+  // Replicated -> split tensor axis 1 over mesh axis 0. Each process asserts
+  // that the axis size (14) is divisible by the mesh axis size, then extracts
+  // its slice at offset (14 / mesh axis size) * process index.
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  %replicated = mesh.shard %input to <@mesh_1d, [[]]> : tensor<3x14xf32>
+  %split = mesh.shard %replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+  return %split : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+  %input: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+  // Dynamic variant: the size of split tensor axis 0 is read with tensor.dim,
+  // and the unsplit dynamic axis 2 keeps its full (queried) size in the slice.
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[TWO:.*]] = arith.constant 2 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+  // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
+  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
+  // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+  %replicated = mesh.shard %input to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
+  %split = mesh.shard %replicated to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+  // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+  return %split : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // The split moves from tensor axis 0 to tensor axis 1 on the same mesh axis;
+  // this lowers to a single mesh.all_to_all between the two shard shapes.
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+  %split_axis0 = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %split_axis1 = mesh.shard %split_axis0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %split_axis1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_mesh
+func.func @move_split_axis_dynamic_mesh(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // Same move as @move_split_axis, but the mesh axis size is dynamic, so the
+  // shard shapes are dynamic and a tensor.cast restores the static axis 0.
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+  %split_axis0 = mesh.shard %input to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %split_axis1 = mesh.shard %split_axis0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %split_axis1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %input: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // The split tensor axis is dynamic; the all_to_all consumes the argument
+  // directly and only the result is cast back to the unsharded type.
+  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+  %split_axis0 = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %split_axis1 = mesh.shard %split_axis0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[RES]] : tensor<?x14xf32>
+  return %split_axis1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // Split -> replicated: the shards are gathered along tensor axis 0 with
+  // mesh.all_gather, whose result already has the full static type.
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+  %split = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %replicated = mesh.shard %split to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %replicated : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %input: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // Split -> replicated along a dynamic tensor axis: a single all_gather on
+  // the argument, with no casts needed around it.
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  %split = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %replicated = mesh.shard %split to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+  return %replicated : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
+func.func @unshard_static_axis_on_dynamic_mesh_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // With a dynamic mesh axis size the shard and gathered types are dynamic;
+  // a tensor.cast restores the static unsharded type afterwards.
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+  %split = mesh.shard %input to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %replicated = mesh.shard %split to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %replicated : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @partial_axis
+func.func @partial_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // Partial (sum over mesh axis 0) -> replicated: resolved with a single
+  // mesh.all_reduce on the argument, whose result is returned directly.
+  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 16b50bb878a074..f14d282857a1e0 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTestSimplifications
+ TestReshardingSpmdization.cpp
TestSimplifications.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
new file mode 100644
index 00000000000000..940010fd94888d
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -0,0 +1,124 @@
+//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+/// Test-only pattern. For a `mesh.shard` op without `annotate_for_users`, it
+/// finds the users that are `annotate_for_users` `mesh.shard` ops on the same
+/// mesh and replaces each such user's result with resharding IR built by
+/// `reshard`. Shard values are bridged to and from the unsharded tensor types
+/// with `builtin.unrealized_conversion_cast`.
+struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+  using OpRewritePattern<ShardOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getAnnotateForUsers()) {
+      return failure();
+    }
+
+    mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op, op.getShard().getCluster());
+
+    // Without a resolved mesh, the shape computations below would receive a
+    // null ClusterOp (and a target with an equally unresolved mesh would
+    // compare equal to it), so do not match.
+    if (!mesh) {
+      return failure();
+    }
+
+    // Collect the resharding targets up front: `annotate_for_users` shard ops
+    // that refer to the same mesh as `op`. This also avoids walking the use
+    // list of `op` while new IR is being created below.
+    SmallVector<ShardOp> targetShardOps;
+    for (Operation *user : op->getUsers()) {
+      auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
+      if (targetShardOp && targetShardOp.getAnnotateForUsers() &&
+          symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+              targetShardOp, targetShardOp.getShard().getCluster()) == mesh) {
+        targetShardOps.push_back(targetShardOp);
+      }
+    }
+    if (targetShardOps.empty()) {
+      return failure();
+    }
+
+    // The local shard of the source value does not depend on the target, so
+    // materialize it once rather than once per target user.
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    ShapedType sourceShardShape =
+        shardShapedType(op.getResult().getType(), mesh, op.getShard());
+    TypedValue<ShapedType> sourceShard = llvm::cast<TypedValue<ShapedType>>(
+        builder
+            .create<UnrealizedConversionCastOp>(sourceShardShape,
+                                                op.getOperand())
+            ->getResult(0));
+
+    for (ShardOp targetShardOp : targetShardOps) {
+      TypedValue<ShapedType> targetShard =
+          reshard(builder, mesh, op, targetShardOp, sourceShard);
+      // Cast the target shard back to the unsharded type that the users of
+      // `targetShardOp` expect.
+      Value newTargetUnsharded =
+          builder
+              .create<UnrealizedConversionCastOp>(
+                  targetShardOp.getResult().getType(), targetShard)
+              ->getResult(0);
+      rewriter.replaceAllUsesWith(targetShardOp.getResult(),
+                                  newTargetUnsharded);
+    }
+
+    return success();
+  }
+
+private:
+  mutable SymbolTableCollection symbolTable;
+};
+
+/// Test pass that greedily applies `TestMeshReshardingRewritePattern` to the
+/// whole module. Exposed as `-test-mesh-resharding-spmdization`.
+struct TestMeshReshardingPass
+    : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+    if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
+                                            std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Dialects required by the resharding helpers (presumably those whose ops
+    // `reshard` may emit), plus the builtin dialect for the
+    // `UnrealizedConversionCastOp`s the pattern creates.
+    reshardingRegisterDependentDialects(registry);
+    registry.insert<BuiltinDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-mesh-resharding-spmdization";
+  }
+
+  StringRef getDescription() const final {
+    return "Test Mesh dialect resharding spmdization.";
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Makes `-test-mesh-resharding-spmdization` available to mlir-opt; the
+/// registration happens as a side effect of constructing PassRegistration.
+void registerTestMeshReshardingSpmdizationPass() {
+  PassRegistration<TestMeshReshardingPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eedade691c6c39..bea2449f44dcde 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -119,6 +119,7 @@ void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshSimplificationsPass();
+void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -237,6 +238,7 @@ void registerTestPasses() {
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMeshSimplificationsPass();
+ mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
>From 1e0f1eb0d12f60d12654401b1e6304e58c8d964a Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 21 Dec 2023 13:47:33 -0800
Subject: [PATCH 2/5] Remove unused ceilDiv function
---
mlir/include/mlir/Support/MathExtras.h | 11 -----------
1 file changed, 11 deletions(-)
diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h
index 22493c11a370d1..17a747393d26eb 100644
--- a/mlir/include/mlir/Support/MathExtras.h
+++ b/mlir/include/mlir/Support/MathExtras.h
@@ -15,20 +15,9 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
-#include <type_traits>
namespace mlir {
-// ceilDiv for unsigned integral.
-template <typename T, std::enable_if_t<std::is_integral_v<T> &&
- !std::is_unsigned_v<T>> = true>
-T ceilDiv(T lhs, T rhs) {
- assert(rhs != static_cast<T>(0));
- T q = lhs / rhs;
- T r = lhs % rhs;
- return r == static_cast<T>(0) ? q : q + static_cast<T>(1);
-}
-
/// Returns the result of MLIR's ceildiv operation on constants. The RHS is
/// expected to be non-zero.
inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
>From 95c5e52a5271222cb462c137edcc3419840ac6a2 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 21 Dec 2023 15:10:18 -0800
Subject: [PATCH 3/5] Remove mutable state from
TestMeshReshardingRewritePattern
---
mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 940010fd94888d..6fecbd48f15387 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -36,6 +36,7 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
return failure();
}
+ SymbolTableCollection symbolTable;
mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
op, op.getShard().getCluster());
@@ -85,9 +86,6 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
return success();
}
-
-private:
- mutable SymbolTableCollection symbolTable;
};
struct TestMeshReshardingPass
>From 9e72161754aace136cf4f99fe340a8cee45d004e Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <sogartary at yahoo.com>
Date: Tue, 2 Jan 2024 09:28:23 -0800
Subject: [PATCH 4/5] Update
mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
Co-authored-by: Chengji Yao <yaochengji at hotmail.com>
---
.../mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
index 181f07177e0af9..ba31e7e7ab28a1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -1,4 +1,4 @@
-Reshadring Spmdization Examples
+Resharding Spmdization Examples
--------------------------------------------------------------
>From 517b50123acfa1ea502acea050a3b50ad87c8f94 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <sogartary at yahoo.com>
Date: Tue, 2 Jan 2024 09:29:15 -0800
Subject: [PATCH 5/5] Update
mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
Co-authored-by: Chengji Yao <yaochengji at hotmail.com>
---
.../mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
index ba31e7e7ab28a1..9299bc68d75338 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -401,7 +401,7 @@ d2(
We want to copy the the data between devices in slices/tiles.
What are the source/target tile coordinates?
-Fro a fixed (m, n, p, q) what is the range of (k, l, u, v)?
+For a fixed (m, n, p, q) what is the range of (k, l, u, v)?
TODO
--------------------------------------------------------------
More information about the Mlir-commits
mailing list